What is it?
This paper proposes a variational approach for training directed graphical models with continuous latent variables and intractable posteriors/marginals. The idea is to reparameterize the latent variables so that they can be written as a deterministic mapping followed by a stochastic perturbation. This allows Monte Carlo estimators of the variational lower bound to be differentiated with respect to the variational parameters.
— Anonymous OpenReview reviewer [6]
Problem Setting:
- A large dataset $X=\{x^{(i)} \}_{i=1}^N$ consisting of $N$ i.i.d. samples of some continuous or discrete variable $x$
- The data are assumed to be generated from an unobserved latent variable $z$ living in a simpler space:
$z \sim p_{\theta^{*}} (z)$
$x|z \sim p_{\theta^{*}} (x|z) $
- Furthermore, we assume the prior $p_{\theta^{*}} (z)$ and likelihood $ p_{\theta^{*}} (x|z) $ come from parametric families of distributions $p_{\theta}(z)$ and $p_{\theta} (x|z) $, whose PDFs (probability density functions) are differentiable almost everywhere w.r.t. both $\theta$ and $z$ (a toy sketch of this generative process follows below)
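To make the setting concrete, here is a tiny numpy sketch of such a generative process for a hypothetical one-dimensional linear-Gaussian model (my own choice of prior and likelihood, purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# A toy instance of the assumed generative process (my own choice of distributions):
#   z     ~ p_{theta*}(z)    = N(0, 1)
#   x | z ~ p_{theta*}(x|z)  = N(theta* * z, 0.5^2)
theta_true = 2.0
N = 1000
z = rng.normal(0.0, 1.0, size=N)        # latent variables (never observed)
x = rng.normal(theta_true * z, 0.5)     # the observed dataset X = {x^(i)}
```

Only $x$ is observed; both $z$ and $\theta^{*}$ are unknown, which is what makes the tasks below non-trivial.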
What do we want to achieve?
- Learn (approximate) ML/MAP estimates of the parameters $\theta$
- Approximate the posterior $p_{\theta} (z|x)$ given $\theta$ and an observed value $x$, which is useful for coding or data-representation tasks
- Approximate marginal inference of the variable $x$: $p_{\theta}(x)=\int p_{\theta} (x|z)p_{\theta}(z)dz$
Difficulties we are facing:
- Intractability: the marginal likelihood $p_{\theta}(x)=\int p_{\theta} (x|z)p_{\theta}(z)dz$ and the posterior $p_{\theta}(z|x)=\frac{ p_{\theta} (x|z)p_{\theta}(z)} {p_{\theta}(x)}$ are both intractable. Therefore:
- The standard EM algorithm cannot be used
- Mean-field VB is often also infeasible (it requires closed-form solutions to certain expectations of the joint PDF) [2]
- A large dataset: sampling-based methods, e.g., MCMC or Monte Carlo EM, are still possible but too slow [2]
- Pure MAP estimation will lead to over-fitting when the dimension of $z$ is high [2]
Auto-encoding Variational Bayes
Idea:
- Learn a probabilistic encoder $q_{\phi}(z|x)$, a neural net parameterized by $\phi$, to approximate the intractable posterior $p_{\theta}(z|x)$
- Construct an estimator of the variational lower bound and jointly optimize $\phi$ and $\theta$
[Figure: Graphical model under consideration. Solid lines denote the generative model $p_{\theta}(z) p_{\theta}(x|z)$, while dashed lines denote the variational approximation $q_{\phi}(z|x)$ to the intractable posterior $p_{\theta}(z|x)$. From [1]'s slides.]
The variational bound:
$\log p_{\theta}(x^{(1)}, ..., x^{(N)})=\sum_{i=1}^N \log p_{\theta} (x^{(i)})$
for each $i$:
\begin{align}
\log p_{\theta}(x^{(i)})&=E_{z \sim q_{\phi}(z|x^{(i)}) } [\log p_{\theta}(x^{(i)})] \\
&=E_z [\log \frac{p_{\theta}(x^{(i)} |z) p_{\theta}(z)} {p_{\theta}(z | x^{(i)} ) }] \: (\textrm{Bayes' Rule}) \\
&=E_z[\log \frac{p_{\theta}(x^{(i)} |z) p_{\theta}(z)} {p_{\theta}(z | x^{(i)} ) } \frac{ q_{\phi} (z| x^{(i)}) } {q_{\phi} (z| x^{(i)}) } ] \\
&=E_z [ \log p_{\theta} ( x^{(i)} | z)] - E_z[ \log \frac{q_{\phi} (z| x^{(i)}) }{p_{\theta}(z) }] + E_z [ \log \frac{q_{\phi} (z| x^{(i)}) } {p_{\theta}(z | x^{(i)}) }] \\
&=\color{red} { E_z[ \log p_{\theta} ( x^{(i)} | z)] } - \color{green}{ D_{KL}( q_{\phi} (z| x^{(i)}) || p_{\theta}(z) ) } +\color{blue} { D_{KL}( q_{\phi} (z| x^{(i)}) || p_{\theta}(z | x^{(i)})) }\\
&= \mathcal{L}(\theta, \phi; x^{(i)}) + D_{KL}( q_{\phi} (z| x^{(i)}) || p_{\theta}(z | x^{(i)})) \\
&\ge \mathcal{L}(\theta, \phi; x^{(i)}) \: \textrm{(Variational lower bound)}
\end{align}
In the red term, $\log p_{\theta} ( x^{(i)} | z)$ is the log-probability of $x^{(i)}$ given the $z$ we have sampled. It is larger when the latent encoding $z|x^{(i)}$ is more faithful to the data. Therefore, the red term is also called the
reconstruction error/loss.
The green term is the KL divergence between the probabilistic encoder $q_{\phi} (z| x^{(i)})$ and the prior $p_{\theta}(z)$. It encourages the encoder to learn an interesting latent code rather than simply an identity mapping. This term is called the
regularization term.
The blue term is the KL divergence between the approximate posterior and the true, intractable posterior $p_{\theta}(z | x^{(i)})$. Since it cannot be computed, we simply use the fact that it is non-negative.
The red term minus the green term yields the classical
variational lower bound:
\begin{align}
\mathcal{L}(\theta, \phi; x^{(i)}) &= E_z[ \log p_{\theta} ( x^{(i)} | z)] - D_{KL}( q_{\phi} (z| x^{(i)}) || p_{\theta}(z) ) \\
&=E_z[ \log p_{\theta} ( x^{(i)} ,z) - \log q_{\phi} (z| x^{(i)}) ]
\end{align}
Possible ways to deal with this bound:
Define:
$f_{\theta, \phi}(z)= \log p_{\theta} ( x^{(i)} ,z) - \log q_{\phi} (z| x^{(i)}) $
$$
\begin{align}
\mathcal{L}(\theta, \phi; x^{(i)}) &= E_{q_{\phi}(z|x^{(i)})}[f_{\theta, \phi}(z)] \\
&\approx \frac{1}{L} \sum_{l=1}^L f_{\theta, \phi}(z^{(l)}) \:\textrm{(Monte Carlo estimator of this bound)}
\end{align}
$$
where $z^{(l)} \sim q_{\phi}(z|x)$
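As a sanity check, here is a minimal numpy sketch of this Monte Carlo estimator on a toy model where everything is Gaussian (my own choice of model and numbers, not from the paper); in this special case $\log p_{\theta}(x)$ is available in closed form, so we can verify that the estimate indeed stays below it:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(x, mean, var):
    """Log density of a univariate Gaussian N(mean, var)."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

# Toy model (my own choice):
#   prior      p(z)   = N(0, 1)
#   likelihood p(x|z) = N(z, 1)
#   encoder    q(z|x) = N(mu, sigma^2), with mu, sigma picked by hand
x, mu, sigma, L = 1.5, 0.7, 0.8, 10_000

z = rng.normal(mu, sigma, size=L)                            # z^(l) ~ q(z|x)
log_p_xz = log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0)   # log p(x, z^(l))
log_q = log_normal(z, mu, sigma ** 2)                        # log q(z^(l)|x)
elbo_hat = np.mean(log_p_xz - log_q)

# In this fully Gaussian toy model the marginal is tractable (x ~ N(0, 2)),
# so we can check that the estimated bound indeed lies below log p(x).
print(elbo_hat, log_normal(x, 0.0, 2.0))
```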
We also need to compute the gradient:
$ \nabla_{\theta, \phi} E_{q_\phi(z)} \left [ \log p_\theta(x,z) - \log q_\phi(z|x) \right] $
For $\nabla_{\theta}$ things are not too hard:
$ \nabla_{\theta} E_{q_\phi(z)} \left [ f_{\theta, \phi}(z) \right]=E_{q_\phi(z)} \left[ \nabla_{\theta} \log p_\theta(x,z) \right] $
For $\nabla_{\phi}$ things become a bit more complicated:
\begin{align}
\nabla_{\phi} E_{q_\phi(z)} \left [ f_{\theta, \phi}(z) \right] &= \nabla_{\phi} E_{q_\phi(z)} \left [ \log p_\theta(x,z) - \log q_\phi(z|x) \right] \\
&= \nabla_{\phi} \int{\log p_\theta(x,z) q_\phi(z|x) dz} -\nabla_{\phi} \int{q_\phi(z|x) \log q_\phi(z|x) dz }\: \textrm{(assuming continuous latent variables)} \\
&=\int{\log p_\theta(x,z) \nabla_{\phi} q_\phi(z|x) dz} - \int{(\log q_\phi(z|x) +1) \nabla_{\phi}
q_\phi(z|x) dz } \\
&= \int{ (\log p_\theta(x,z) - \log q_\phi(z|x) ) \nabla_{\phi} q_\phi(z|x) dz }
\end{align}
where we use the fact that
$$
\begin{align}
\int{ \nabla_{\phi} q_\phi(z|x) dz } &=\nabla_{\phi}\int{ q_\phi(z|x) dz } \\
&= \nabla_{\phi} 1 \\
&= 0
\end{align}
$$
Next, we use the log-derivative identity
$\nabla_{\phi} q_\phi(z|x) = q_\phi(z|x) \nabla_{\phi} \log{q_\phi(z|x) }$
(Why? Because $\nabla_{\phi} \log{q_\phi(z|x) }= \frac{ \nabla_{\phi} q_\phi(z|x) } {q_\phi(z|x) }$, so multiplying both sides by $q_\phi(z|x)$ gives the identity.)
So we have:
\begin{align}
\nabla_{\phi} E_{q_\phi(z)} \left [ f_{\theta, \phi}(z) \right] &= \nabla_{\phi} E_{q_\phi(z)} \left [ \log p_\theta(x,z) - \log q_\phi(z|x) \right] \\
&=\int{ (\log p_\theta(x,z) - \log q_\phi(z|x) ) \nabla_{\phi} q_\phi(z|x) dz } \\
&=\int{ (\log p_\theta(x,z) - \log q_\phi(z|x) ) q_\phi(z|x) \nabla_{\phi} \log{q_\phi(z|x) } dz }\\
&=\color{red } { E_{q_\phi(z|x)} [ (\log p_\theta(x,z) - \log q_\phi(z|x) ) \nabla_{\phi}\log{q_\phi(z|x) } ] }
\end{align}
A possible (naive) estimator is then:
\begin{align}
\nabla_{\phi} E_{q_\phi(z)} \left [ f_{\theta, \phi}(z) \right]
&=\color{red } { E_{q_\phi(z|x)} [ (\log p_\theta(x,z) - \log q_\phi(z|x) ) \nabla_{\phi}\log{q_\phi(z|x) } ] } \\
&\approx \color{blue}{\frac{1}{L} \sum_{l=1}^L f_{\theta, \phi}(z^{(l)}) \nabla_{\phi} \log q_\phi(z^{(l)} |x)}
\end{align}
where $z^{(l)} \sim q_{\phi}(z|x)$
However, this score-function estimator exhibits very high variance (see, e.g., [5]).
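For intuition, the following numpy sketch implements this estimator for the same kind of toy fully-Gaussian model as above (again my own choice of model and numbers) and prints the spread of the gradient estimate over many repetitions; the problem typically becomes much worse for higher-dimensional, more expressive models:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(x, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

# Toy model (my own choice): p(z)=N(0,1), p(x|z)=N(z,1), q(z|x)=N(mu, sigma^2).
# We estimate dL/dmu with the estimator above and repeat the estimate many times.
x, mu, sigma, L = 1.5, 0.7, 0.8, 16

def score_function_grad():
    z = rng.normal(mu, sigma, size=L)                        # z^(l) ~ q(z|x)
    f = (log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0)
         - log_normal(z, mu, sigma ** 2))                    # f_{theta,phi}(z^(l))
    score_mu = (z - mu) / sigma ** 2                         # d/dmu log q(z^(l)|x)
    return np.mean(f * score_mu)

grads = np.array([score_function_grad() for _ in range(2000)])
print("gradient estimate: mean %.3f, std %.3f" % (grads.mean(), grads.std()))
```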
Re-parameterization trick:
The core of the re-parameterization trick is to rewrite a sample from $q_{\phi}(z|x)$ as a deterministic mapping $g_{\phi}$ of both $x$ and an auxiliary noise variable $\epsilon$ sampled from some distribution $p(\epsilon)$.
i.e.:
$\widetilde{\mathbf{z}}=g_{\phi}(\epsilon, \mathbf{x})$ with $\boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon})$
[Figure: Re-parameterization moves the randomness into $\epsilon$, which is reflected in the sampling process. Picture from [1].]
Now a Monte Carlo estimate of the expectation of some function $f(z)$ w.r.t. $q_{\phi}(z|x)$ can be formed as follows:
$\mathbb{E}_{q_{\phi}\left(\mathbf{z} | \mathbf{x}^{(i)}\right)}[f(\mathbf{z})]=\mathbb{E}_{p(\epsilon)}\left[f\left(g_{\phi}\left(\boldsymbol{\epsilon}, \mathbf{x}^{(i)}\right)\right)\right] \simeq \frac{1}{L} \sum_{l=1}^{L} f\left(g_{\phi}\left(\boldsymbol{\epsilon}^{(l)}, \mathbf{x}^{(i)}\right)\right)$
where
$\boldsymbol{\epsilon}^{(l)} \sim p(\boldsymbol{\epsilon})$.
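A minimal PyTorch sketch of this, for the same kind of toy Gaussian model as above (my own choice of model; the mapping $z=\mu+\sigma\epsilon$ is the paper's canonical Gaussian example), shows the key payoff: the Monte Carlo estimate is now a deterministic, differentiable function of the variational parameters, so gradients can simply be backpropagated through the sampling step.

```python
import torch
from torch.distributions import Normal

# Toy model (my own choice): p(z) = N(0, 1), p(x|z) = N(z, 1),
# q_phi(z|x) = N(mu, sigma^2) reparameterized as z = mu + sigma * eps, eps ~ N(0, 1).
x = torch.tensor(1.5)
mu = torch.tensor(0.7, requires_grad=True)
log_sigma = torch.tensor(-0.2, requires_grad=True)

L = 64
eps = torch.randn(L)                         # eps^(l) ~ p(eps)
z = mu + torch.exp(log_sigma) * eps          # z^(l) = g_phi(eps^(l), x)

# f(z) = log p(x|z) + log p(z) - log q(z|x), evaluated at the reparameterized samples
f = (Normal(z, 1.0).log_prob(x)
     + Normal(0.0, 1.0).log_prob(z)
     - Normal(mu, torch.exp(log_sigma)).log_prob(z))

estimate = f.mean()                          # (1/L) sum_l f(g_phi(eps^(l), x))
estimate.backward()                          # gradients flow through the sampling step
print(estimate.item(), mu.grad.item(), log_sigma.grad.item())
```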
The variational lower bound can now be rewritten and estimated in two ways, yielding two versions of the Stochastic Gradient Variational Bayes (SGVB) estimator:
1. Reparameterizing the full bound directly gives the generic estimator:
$$\begin{align}
\mathcal{L}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{x}^{(i)}\right) &\approx
\widetilde{\mathcal{L}}^{A}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{x}^{(i)}\right) \\
&=\color{blue} { \frac{1}{L} \sum_{l=1}^{L} \left( \log p_{\boldsymbol{\theta}}\left(\mathbf{x}^{(i)}, \mathbf{z}^{(i, l)}\right)-\log q_{\boldsymbol{\phi}}\left(\mathbf{z}^{(i, l)} | \mathbf{x}^{(i)}\right) \right) }
\end{align}
$$
where $\mathbf{z}^{(i, l)}=g_{\phi}\left(\boldsymbol{\epsilon}^{(i, l)}, \mathbf{x}^{(i)}\right)$
and $\boldsymbol{\epsilon}^{(i, l)} \sim p(\boldsymbol{\epsilon})$
2. Often the KL-divergence $D_{K L}\left(q_{\phi}\left(\mathbf{z} | \mathbf{x}^{(i)}\right) \| p_{\boldsymbol{\theta}}(\mathbf{z})\right)$ can be integrated analytically (in particular when both distributions are Gaussian), so that only the expected reconstruction error requires sampling, leading to the second version of the estimator:
$$\begin{align} \mathcal{L}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{x}^{(i)}\right) &\approx \widetilde{\mathcal{L}}^{B}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{x}^{(i)}\right) \\
&=\color{blue} {-D_{K L}\left(q_{\phi}\left(\mathbf{z} | \mathbf{x}^{(i)}\right) \| p_{\boldsymbol{\theta}}(\mathbf{z})\right)+\frac{1}{L} \sum_{l=1}^{L}\left(\log p_{\boldsymbol{\theta}}\left(\mathbf{x}^{(i)} | \mathbf{z}^{(i, l)}\right)\right)}
\end{align} $$
where $\mathbf{z}^{(i, l)}=g_{\phi}\left(\boldsymbol{\epsilon}^{(i, l)}, \mathbf{x}^{(i)}\right)$
and $\boldsymbol{\epsilon}^{(i, l)} \sim p(\boldsymbol{\epsilon})$
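For the common case $q_{\phi}(\mathbf{z}|\mathbf{x}^{(i)})=\mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^{2}\mathbf{I})$ and $p_{\theta}(\mathbf{z})=\mathcal{N}(\mathbf{0}, \mathbf{I})$, the KL term has the closed form (Appendix B of [7])
$$-D_{K L}\left(q_{\phi}(\mathbf{z}|\mathbf{x}^{(i)}) \| p_{\boldsymbol{\theta}}(\mathbf{z})\right)=\frac{1}{2} \sum_{j=1}^{J}\left(1+\log \sigma_{j}^{2}-\mu_{j}^{2}-\sigma_{j}^{2}\right)$$
The sketch below is a minimal PyTorch version of $\widetilde{\mathcal{L}}^{B}$ with $L=1$, a Gaussian encoder, and a Bernoulli (binary-pixel) decoder; the class name ToyVAE, the layer sizes, and the activations are my own choices, not the exact architecture from the paper.

```python
import torch
import torch.nn.functional as F
from torch import nn

class ToyVAE(nn.Module):
    """Minimal sketch of estimator B; layer sizes/activations are arbitrary choices."""
    def __init__(self, x_dim=784, h_dim=400, z_dim=20):
        super().__init__()
        self.enc = nn.Linear(x_dim, h_dim)
        self.enc_mu = nn.Linear(h_dim, z_dim)
        self.enc_logvar = nn.Linear(h_dim, z_dim)
        self.dec = nn.Sequential(nn.Linear(z_dim, h_dim), nn.Tanh(),
                                 nn.Linear(h_dim, x_dim))

    def forward(self, x):
        # Probabilistic encoder q_phi(z|x) = N(mu, diag(sigma^2))
        h = torch.tanh(self.enc(x))
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        # Reparameterization: z = g_phi(eps, x) = mu + sigma * eps, eps ~ N(0, I)
        eps = torch.randn_like(mu)
        z = mu + torch.exp(0.5 * logvar) * eps
        # Bernoulli decoder p_theta(x|z): log p(x|z) = -BCE(logits, x)
        logits = self.dec(z)
        log_px_z = -F.binary_cross_entropy_with_logits(
            logits, x, reduction="none").sum(-1)
        # Analytic -KL(q || p) term from the closed form above
        neg_kl = 0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(-1)
        # Estimator B with L = 1, averaged over the minibatch
        return (log_px_z + neg_kl).mean()
```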
When given multiple datapoints from the dataset $X=\{x^{(i)}\}_{i=1}^N$, we can construct an estimator of the full-dataset bound based on mini-batches:
$\mathcal{L}(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{X}) \simeq \widetilde{\mathcal{L}}^{M}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{X}^{M}\right)=\frac{N}{M} \sum_{i=1}^{M} \widetilde{\mathcal{L}}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{x}^{(i)}\right)$
where $\mathbf{X}^{M}=\left\{\mathbf{x}^{(i)}\right\}_{i=1}^{M}$ is a randomly drawn mini-batch of $M$ datapoints.
The gradient $\nabla_{\boldsymbol{\theta}, \boldsymbol{\phi}} \widetilde{\mathcal{L}}^{M}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{X}^{M}\right)$ can then be taken and plugged into a stochastic optimizer such as SGD or Adagrad.
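A rough training-loop sketch for $\widetilde{\mathcal{L}}^{M}$, reusing the hypothetical ToyVAE from above on a synthetic binarized dataset (the paper reports that $L=1$ sample per datapoint suffices as long as the minibatch size $M$ is large enough, e.g. $M=100$; the paper's Algorithm 1 mentions SGD or Adagrad, and Adam here is just a convenient stand-in):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical setup: a synthetic binarized dataset standing in for real data,
# reusing the ToyVAE sketch above. N = dataset size, M = minibatch size.
N, M = 10_000, 100
data = (torch.rand(N, 784) > 0.5).float()
loader = DataLoader(TensorDataset(data), batch_size=M, shuffle=True)

model = ToyVAE()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for (x_batch,) in loader:
    elbo_mean = model(x_batch)    # (1/M) * sum_i  L~(theta, phi; x^(i))
    loss = -N * elbo_mean         # minimize -L~^M = -(N/M) * sum_i L~(theta, phi; x^(i))
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Note that the factor $N$ only rescales the objective (and hence the effective step size), which is why many implementations simply minimize the negative per-datapoint bound instead.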
The intuition behind the re-parameterization trick is that we want to find some $p(\epsilon)$ and a mapping $g_{\phi}$ such that
$q_{\phi}(\mathbf{z} | \mathbf{x}) \prod_{i} d z_{i}=p(\epsilon) \prod_{i} d \epsilon_{i}$
therefore, $\int q_{\phi}(\mathbf{z} | \mathbf{x}) f(\mathbf{z}) d \mathbf{z}=\int p(\boldsymbol{\epsilon}) f(\mathbf{z}) d \boldsymbol{\epsilon}=\int p(\boldsymbol{\epsilon}) f\left(g_{\phi}(\boldsymbol{\epsilon}, \mathbf{x})\right) d \boldsymbol{\epsilon}$
and so a differentiable estimator can be constructed: $\int q_{\phi}(\mathbf{z} | \mathbf{x}) f(\mathbf{z}) d \mathbf{z} \simeq \frac{1}{L} \sum_{l=1}^{L} f\left(g_{\phi}\left(\boldsymbol{\epsilon}^{(l)}, \mathbf{x}\right)\right)$ where $\boldsymbol{\epsilon}^{(l)} \sim p(\boldsymbol{\epsilon})$.
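The simplest concrete instance is the univariate Gaussian case used as the running example in [7]: with $q_{\phi}(z|x)=\mathcal{N}(\mu, \sigma^{2})$, take $z=g_{\phi}(\epsilon, x)=\mu+\sigma \epsilon$ and $\epsilon \sim \mathcal{N}(0,1)$, so that $dz=\sigma \, d\epsilon$ and
$$
q_{\phi}(z | x)\, dz=\frac{1}{\sqrt{2 \pi}\, \sigma} e^{-\frac{(z-\mu)^{2}}{2 \sigma^{2}}}\, \sigma\, d \epsilon=\frac{1}{\sqrt{2 \pi}} e^{-\epsilon^{2} / 2}\, d \epsilon=p(\epsilon)\, d \epsilon,
$$
exactly as required.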
Concrete recipes for finding such reparameterizations (tractable inverse CDFs, location-scale families, and compositions of tractable transformations) are given in the original paper [7].
The authors do not explain why re-parameterization helps reduce the variance of the gradient estimator; a possible explanation can be found in the appendix of [8].
Ref:
1. Kingma, D. P. NIPS 2015 workshop slides.
2. Kingma, D. P. Stochastic Gradient VB: Intractable posterior distributions? Gradients to the rescue! (slides).
3. CS228 notes on variational autoencoders: https://ermongroup.github.io/cs228-notes/extras/vae/
4. CS231n Lecture 11 slides.
5. Mnih, A., & Gregor, K. (2014). Neural Variational Inference and Learning in Belief Networks.
6. OpenReview page for the original paper.
7. Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes.
8. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic Backpropagation and Approximate Inference in Deep Generative Models.