Monday, June 17, 2019

Importance Weighted Autoencoders

Key Idea:

The original VAE optimizes the variational lower bound:
$\mathcal{L}(x)= E_{q(h|x)}[\log \frac{p(x,h)}{q(h|x)}]$
This paper introduces a tighter bound, the k-sample importance weighting estimate of the log-likelihood:
$\mathcal{L}_k(x)= E_{h_1, ..., h_k \sim q(h|x)}[\log \frac{1}{k} \sum_{i=1}^k \frac{p(x,h_i)}{q(h_i|x)}]$

Furthermore, the authors claim:

  • $\log p(x) \ge \mathcal{L}_{k+1} \ge \mathcal{L}_k$
  • $\log p(x)= \lim_{k \to \infty} \mathcal{L}_k$ if $\frac{p(h,x)}{q(h|x)}$ is bounded
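The two claims above can be checked numerically on a toy model where $\log p(x)$ is known in closed form. The sketch below is only an illustration under assumed choices that do not come from the paper: a scalar latent $h \sim \mathcal{N}(0,1)$, a likelihood $x|h \sim \mathcal{N}(h,1)$ (so $x \sim \mathcal{N}(0,2)$ exactly), and the prior used as a deliberately poor proposal $q(h|x)$. The helper names `log_normal_pdf` and `estimate_bound` are hypothetical.

```python
import math
import random


def log_normal_pdf(x, mean, var):
    """Log density of a univariate Gaussian N(mean, var) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * (x - mean) ** 2 / var


def estimate_bound(x, k, n_rep, rng):
    """Monte Carlo estimate of L_k(x) for the assumed toy model
    h ~ N(0, 1), x | h ~ N(h, 1), with proposal q(h | x) = N(0, 1)."""
    total = 0.0
    for _ in range(n_rep):
        hs = [rng.gauss(0.0, 1.0) for _ in range(k)]
        # w = p(x, h) / q(h | x); prior and proposal coincide here, so w = p(x | h)
        weights = [math.exp(log_normal_pdf(x, h, 1.0)) for h in hs]
        total += math.log(sum(weights) / k)
    return total / n_rep


rng = random.Random(0)
x = 1.0
log_px = log_normal_pdf(x, 0.0, 2.0)  # exact marginal: x ~ N(0, 2)
bounds = {k: estimate_bound(x, k, 4000, rng) for k in (1, 5, 50)}
for k, value in bounds.items():
    print(f"L_{k} ~ {value:.3f}   (log p(x) = {log_px:.3f})")
```

With these settings the estimates should increase with $k$ and approach $\log p(x)$ from below, up to Monte Carlo noise.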

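1. First, by Jensen's inequality, with $w_i = \frac{p(\mathbf{x}, \mathbf{h}_i)}{q(\mathbf{h}_i | \mathbf{x})}$: $\mathcal{L}_{k}=\mathbb{E}\left[\log \frac{1}{k} \sum_{i=1}^{k} w_{i}\right] \leq \log \mathbb{E}\left[\frac{1}{k} \sum_{i=1}^{k} w_{i}\right]=\log p(\mathbf{x})$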

2. Let $I \subset \{1, 2, ..., k \}$ be a subset drawn uniformly at random with $|I|=m$, $m \le k$. Then for any numbers $a_1, \ldots, a_k$:
$\mathbb{E}_{I=\left\{i_{1}, \ldots, i_{m}\right\} } \left[\frac{a_{i_{1}}+\ldots+a_{i_{m}}}{m}\right]= \frac{a_{1}+\ldots+a_{k}}{k}$
$$ \begin{align} \mathcal{L}_{k} &=\mathbb{E}_{\mathbf{h}_{1}, \ldots, \mathbf{h}_{k}}\left[\log \frac{1}{k} \sum_{i=1}^{k} \frac{p\left(\mathbf{x}, \mathbf{h}_{i}\right)}{q\left(\mathbf{h}_{i} | \mathbf{x}\right)}\right] \\ &=\mathbb{E}_{\mathbf{h}_{1}, \ldots, \mathbf{h}_{k}}\left[\log \mathbb{E}_{I=\left\{i_{1}, \ldots, i_{m}\right\}}\left[\frac{1}{m} \sum_{j=1}^{m} \frac{p\left(\mathbf{x}, \mathbf{h}_{i_{j}}\right)}{q\left(\mathbf{h}_{i_{j}} | \mathbf{x}\right)}\right]\right] \\ &\ge \mathbb{E}_{\mathbf{h}_{1}, \ldots, \mathbf{h}_{k}}\left[\mathbb{E}_{I=\left\{i_{1}, \ldots, i_{m}\right\}}\left[\log \frac{1}{m} \sum_{j=1}^{m} \frac{p\left(\mathbf{x}, \mathbf{h}_{i_{j}}\right)}{q\left(\mathbf{h}_{i_{j}} | \mathbf{x}\right)}\right]\right] \textrm{(Jensen's inequality) }\\ &=\mathbb{E}_{\mathbf{h}_{1}, \ldots, \mathbf{h}_{m}}\left[\log \frac{1}{m} \sum_{i=1}^{m} \frac{p\left(\mathbf{x}, \mathbf{h}_{i}\right)}{q\left(\mathbf{h}_{i} | \mathbf{x}\right)}\right]\\ &=\mathcal{L}_{m} \textrm{($k \ge m$)} \end{align} $$

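where $I=\{i_1, \ldots, i_m\} \subset \{ 1, 2, ..., k \}$. The second-to-last equality holds because the $\mathbf{h}_i$ are i.i.d., so $(\mathbf{h}_{i_1}, \ldots, \mathbf{h}_{i_m})$ has the same distribution as $(\mathbf{h}_1, \ldots, \mathbf{h}_m)$.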
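The subset-averaging identity used in step 2 can be verified by brute force. The following sketch enumerates every $m$-subset of a small list and uses exact rational arithmetic; the values in `a` are arbitrary and chosen only for illustration.

```python
from fractions import Fraction
from itertools import combinations


def mean(values):
    """Arithmetic mean of a non-empty list."""
    return sum(values) / len(values)


a = [Fraction(v) for v in (3, 1, 4, 1, 5, 9)]  # arbitrary example values, k = 6
k = len(a)
overall_mean = mean(a)

for m in range(1, k + 1):
    subsets = list(combinations(range(k), m))
    subset_means = [mean([a[i] for i in idx]) for idx in subsets]
    # Uniform expectation over all m-subsets of {0, ..., k-1}
    expected = mean(subset_means)
    assert expected == overall_mean

print("identity holds for every m; overall mean =", overall_mean)
```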

3. Consider the random variable $M_{k}=\frac{1}{k} \sum_{i=1}^{k} \frac{p\left(\mathbf{x}, \mathbf{h}_{i}\right)}{q\left(\mathbf{h}_{i} | \mathbf{x}\right)}$. If $p(\mathbf{h}, \mathbf{x}) / q(\mathbf{h} | \mathbf{x})$ is bounded, it follows from the strong law of large numbers that $M_k$ converges to $E_{q\left(\mathbf{h}_{i} | \mathbf{x}\right)}\left[\frac{p\left(\mathbf{x}, \mathbf{h}_{i}\right)}{q\left(\mathbf{h}_{i} | \mathbf{x}\right)}\right]= p(\mathbf{x})$ almost surely. Since $\log$ is continuous, $\log M_k$ converges to $\log p(\mathbf{x})$ almost surely, and hence $\mathcal{L}_k = \mathbb{E}[\log M_k]$ converges to $\log p(\mathbf{x})$ as $k \to \infty$.

In the paper [1], the authors use the mean absolute deviation rather than the variance to argue that the estimate cannot fluctuate much. Specifically, they prove that the mean absolute deviation of the sample-based estimate of $\mathcal{L}_k$ is no more than $2+2\delta$, where $\delta=\log p(x)- \mathcal{L}_k$ is the gap between the log-likelihood and the bound.

Training Procedure:

\begin{align} \nabla_{\boldsymbol{\theta}} \mathcal{L}_{k}(\mathbf{x}) &=\nabla_{\boldsymbol{\theta}} \mathbb{E}_{\mathbf{h}_{1}, \ldots, \mathbf{h}_{k}}\left[\log \frac{1}{k} \sum_{i=1}^{k} w_{i}\right] \\ &=\nabla_{\boldsymbol{\theta}} \mathbb{E}_{\boldsymbol{\epsilon}_{1}, \ldots, \epsilon_{k}}\left[\log \frac{1}{k} \sum_{i=1}^{k} w\left(\mathbf{x}, \mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right), \boldsymbol{\theta}\right)\right] \\ &=\mathbb{E}_{\epsilon_{1}, \ldots, \epsilon_{k}}\left[\nabla_{\boldsymbol{\theta}} \log \frac{1}{k} \sum_{i=1}^{k} w\left(\mathbf{x}, \mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right), \boldsymbol{\theta}\right)\right] \\ &=\mathbb{E}_{\epsilon_{1}, \ldots, \epsilon_{k}}\left[\sum_{i=1}^{k} \widetilde{w}_{i} \nabla_{\boldsymbol{\theta}} \log w\left(\mathbf{x}, \mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right), \boldsymbol{\theta}\right)\right] \end{align}

where $\epsilon_{1}, \dots, \epsilon_{k}$ are auxiliary variables, $\epsilon_i \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$;
$w_i=w\left(\mathbf{x}, \mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right), \boldsymbol{\theta}\right)$ are the importance weights expressed as a deterministic function;

$\widetilde{w}_{i}=\frac {w_{i}}  { \sum_{j=1}^{k} w_{j} } $ are the normalized importance weights.

In the gradient-based algorithm, we draw $k$ samples from the recognition network (i.e., draw k sets of auxiliary variables) and use the Monte Carlo estimate:
$\sum_{i=1}^{k} \widetilde{w}_{i} \nabla_{\boldsymbol{\theta}} \log w\left(\mathbf{x}, \mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right), \boldsymbol{\theta}\right)$.
If $k=1$ then the above degenerates into the standard VAE update rule.
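The last step of the derivation is the chain rule: $\nabla \log \sum_i w_i = \sum_i \nabla w_i / \sum_j w_j = \sum_i \widetilde{w}_i \nabla \log w_i$. A minimal numerical check follows, assuming a hypothetical scalar parameter `theta` and toy weights $w_i = \exp(a_i \theta + b_i)$, for which $\nabla_\theta \log w_i = a_i$. None of these names come from the paper.

```python
import math


def log_mean_w(theta, a, b):
    """log((1/k) * sum_i w_i) for the hypothetical weights w_i = exp(a_i * theta + b_i)."""
    k = len(a)
    return math.log(sum(math.exp(ai * theta + bi) for ai, bi in zip(a, b)) / k)


def weighted_gradient(theta, a, b):
    """sum_i w~_i * d/dtheta log w_i, where d/dtheta log w_i = a_i for this toy family."""
    w = [math.exp(ai * theta + bi) for ai, bi in zip(a, b)]
    total = sum(w)
    return sum((wi / total) * ai for wi, ai in zip(w, a))


a = [0.5, -1.0, 2.0]
b = [0.1, 0.3, -0.7]
theta = 0.4
eps = 1e-6

# Central finite difference of the left-hand side
finite_diff = (log_mean_w(theta + eps, a, b) - log_mean_w(theta - eps, a, b)) / (2 * eps)
# Right-hand side: normalized-weight average of the per-sample gradients
analytic = weighted_gradient(theta, a, b)
print(finite_diff, analytic)
```

The two printed numbers should agree to within finite-difference error.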


The implementation part is rather interesting, too.
We analyze the PyTorch implementation by Xinqiang Ding [2]. First, two classes are defined:

  • IWAE_1 for one-stochastic-layer IWAE
  • IWAE_2 for two-stochastic-layer IWAE 
IWAE_1 and IWAE_2 are roughly similar. We focus on IWAE_1:

import torch
import torch.nn as nn
from torch.autograd import Variable

# BasicBlock is defined elsewhere in [2]; it maps x to the mean and
# standard deviation of the Gaussian q(h1 | x).
class IWAE_1(nn.Module):
    def __init__(self, dim_h1, dim_image_vars):
        super(IWAE_1, self).__init__()
        self.dim_h1 = dim_h1
        self.dim_image_vars = dim_image_vars

        ## encoder
        self.encoder_h1 = BasicBlock(dim_image_vars, 200, dim_h1)
        ## decoder
        self.decoder_x =  nn.Sequential(nn.Linear(dim_h1, 200),
                                        nn.Tanh(),
                                        nn.Linear(200, 200),
                                        nn.Tanh(),
                                        nn.Linear(200, dim_image_vars),
                                        nn.Sigmoid())  # keeps p in (0, 1) for the Bernoulli log-likelihood

    def encoder(self, x):
        mu_h1, sigma_h1 = self.encoder_h1(x)
        # reparameterization: h1 = mu + sigma * eps with eps ~ N(0, I)
        eps = Variable(sigma_h1.data.new(sigma_h1.size()).normal_())
        h1 = mu_h1 + sigma_h1 * eps
        return h1, mu_h1, sigma_h1, eps

    def decoder(self, h1):
        p = self.decoder_x(h1)
        return p

    def forward(self, x):
        h1, mu_h1, sigma_h1, eps = self.encoder(x)
        p = self.decoder(h1)
        return (h1, mu_h1, sigma_h1, eps), (p)

The forward() method invokes both encoder() and decoder(), which makes intuitive sense. encoder() returns a sample h1 from the Gaussian approximate posterior, the mean and standard deviation of that Gaussian, and the auxiliary variable eps. decoder() directly invokes:

nn.Sequential(nn.Linear(dim_h1, 200),
              nn.Tanh(),
              nn.Linear(200, 200),
              nn.Tanh(),
              nn.Linear(200, dim_image_vars),
              nn.Sigmoid())
This is perhaps why PyTorch is so popular: it is a very intuitive framework.

The loss function is worth noticing:

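    def train_loss(self, inputs):
        # dim 0 of inputs is assumed to index the k importance samples
        h1, mu_h1, sigma_h1, eps = self.encoder(inputs)
        # log q(h1|x) up to an additive constant; (h1-mu_h1)/sigma_h1 equals eps
        #log_Qh1Gx = torch.sum(-0.5*((h1-mu_h1)/sigma_h1)**2 - torch.log(sigma_h1), -1)
        log_Qh1Gx = torch.sum(-0.5*(eps)**2 - torch.log(sigma_h1), -1)
        p = self.decoder(h1)
        # log p(h1) under the standard normal prior, up to the same constant
        log_Ph1 = torch.sum(-0.5*h1**2, -1)
        # Bernoulli log-likelihood log p(x|h1)
        log_PxGh1 = torch.sum(inputs*torch.log(p) + (1-inputs)*torch.log(1-p), -1)

        # log w_i = log p(x, h_i) - log q(h_i|x)
        log_weight = log_Ph1 + log_PxGh1 - log_Qh1Gx
        # shift by the max over the k samples before exponentiating
        log_weight = log_weight - torch.max(log_weight, 0)[0]
        weight = torch.exp(log_weight)
        # normalized importance weights
        weight = weight / torch.sum(weight, 0)
        # treat the weights as constants: no gradient flows through them
        weight = Variable(weight.data, requires_grad = False)
        loss = -torch.mean(torch.sum(weight * (log_Ph1 + log_PxGh1 - log_Qh1Gx), 0))
        return loss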
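Notice the loss. One might expect the loss to be just $-\mathcal{L}_k$, with Adam left to optimize it. But that is not what the code computes. It builds a weighted sum of the log-weights with the normalized weights held fixed, so that its gradient matches the update rule above. Each $\nabla_{\boldsymbol{\theta}} \log w$ term decomposes as: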
\begin{align} \nabla_{\boldsymbol{\theta}} \log w\left(\mathbf{x}, \mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right), \boldsymbol{\theta}\right) &=\nabla_{\boldsymbol{\theta}} \log p\left(\mathbf{x}, \mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right) | \boldsymbol{\theta}\right)-\nabla_{\boldsymbol{\theta}} \log q\left(\mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right) | \mathbf{x}, \boldsymbol{\theta}\right) \\ &=\nabla_{\boldsymbol{\theta}} \log p\left(\mathbf{x}| \mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right) , \boldsymbol{\theta}\right) + \nabla_{\boldsymbol{\theta}} \log p\left( \mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right) | \boldsymbol{\theta}\right) -\nabla_{\boldsymbol{\theta}} \log q\left(\mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right) | \mathbf{x}, \boldsymbol{\theta}\right) \\ &= \nabla_{\boldsymbol{\theta}} \: \textrm{log_PxGh1}  + \nabla_{\boldsymbol{\theta}} \: \textrm{log_Ph1} -  \nabla_{\boldsymbol{\theta}} \: \textrm{log_Qh1Gx} \end{align}
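The three terms map onto the per-sample log-densities computed in train_loss. Note that the code drops the $-\frac{1}{2}\log 2\pi$ constant from both log_Ph1 and log_Qh1Gx; the sketch below suggests why that is harmless, since the constants cancel in the log-weight. It assumes a scalar latent, a single binary pixel, and a fixed decoder output `p`. All of these are hypothetical simplifications, and the function names are illustrative only.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def log_weight_full(x, h, mu, sigma, p):
    """log w with every Gaussian normalizing constant kept."""
    log_prior = -0.5 * LOG_2PI - 0.5 * h ** 2
    log_q = -0.5 * LOG_2PI - math.log(sigma) - 0.5 * ((h - mu) / sigma) ** 2
    log_lik = x * math.log(p) + (1 - x) * math.log(1 - p)
    return log_lik + log_prior - log_q


def log_weight_as_in_code(x, eps, mu, sigma, p):
    """Same quantity written the way train_loss does: constants dropped, eps used directly."""
    h = mu + sigma * eps
    log_Ph1 = -0.5 * h ** 2
    log_Qh1Gx = -0.5 * eps ** 2 - math.log(sigma)
    log_PxGh1 = x * math.log(p) + (1 - x) * math.log(1 - p)
    return log_PxGh1 + log_Ph1 - log_Qh1Gx


# Hypothetical values for one latent dimension and one pixel
x, eps, mu, sigma, p = 1.0, 0.3, 0.2, 0.5, 0.7
h = mu + sigma * eps
full = log_weight_full(x, h, mu, sigma, p)
short = log_weight_as_in_code(x, eps, mu, sigma, p)
print(full, short)
```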


One might want to use version B of the ELBO estimator from the VAE paper:

\begin{align} \mathcal{L}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{x}^{(i)}\right) &\approx \widetilde{\mathcal{L}}^{B}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{x}^{(i)}\right) \\ &=\color{blue} {-D_{K L}\left(q_{\phi}\left(\mathbf{z} | \mathbf{x}^{(i)}\right) \| p_{\boldsymbol{\theta}}(\mathbf{z})\right)+\frac{1}{L} \sum_{l=1}^{L}\left(\log p_{\boldsymbol{\theta}}\left(\mathbf{x}^{(i)} | \mathbf{z}^{(i, l)}\right)\right)} \end{align}
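But this trick cannot be reapplied in the $k>1$ case, where $k$ is the number of importance samples (not the mini-batch size): the log of an average of weights no longer splits into an analytic KL term plus a reconstruction term.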
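For reference, the KL term in version B has a closed form when $q$ is Gaussian and the prior is standard normal, which is what makes the split attractive at $k=1$. The sketch below compares that closed form against a Monte Carlo estimate of $E_q[\log q(z) - \log p(z)]$ for a scalar latent. The values of `mu` and `sigma` and both function names are assumptions made for illustration.

```python
import math
import random


def analytic_kl(mu, sigma):
    """KL(N(mu, sigma^2) || N(0, 1)) in closed form."""
    return 0.5 * (mu ** 2 + sigma ** 2 - 1.0) - math.log(sigma)


def monte_carlo_kl(mu, sigma, n_samples, rng):
    """Estimate the same KL as E_q[log q(z) - log p(z)] using z = mu + sigma * eps."""
    total = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps
        log_q = -math.log(sigma) - 0.5 * eps ** 2  # -0.5*log(2*pi) omitted
        log_p = -0.5 * z ** 2                      # same constant omitted, so it cancels
        total += log_q - log_p
    return total / n_samples


rng = random.Random(0)
mu, sigma = 0.5, 0.8  # hypothetical posterior parameters
kl_exact = analytic_kl(mu, sigma)
kl_mc = monte_carlo_kl(mu, sigma, 20000, rng)
print(kl_exact, kl_mc)
```

The analytic value carries no sampling noise, which is the appeal of version B; the IWAE code instead estimates everything by sampling.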

log_weight = log_weight - torch.max(log_weight, 0)[0]

What is this line for? It is a standard way of dealing with numerical instability. We borrow the example from the CS231n lecture notes [3]:

"When you’re writing code for computing the Softmax function in practice, the intermediate terms $e^{f_{y_i}}$ and $\sum_j e^{f_j}$ may be very large due to the exponentials. Dividing large numbers can be numerically unstable, so it is important to use a normalization trick. Notice that if we multiply the top and bottom of the fraction by a constant C and push it into the sum, we get the following (mathematically equivalent) expression:

$ \frac{e^{f_{y_i}}}{\sum_j e^{f_j}} = \frac{Ce^{f_{y_i}}}{C\sum_j e^{f_j}} = \frac{e^{f_{y_i} + \log C}}{\sum_j e^{f_j + \log C}} $

We are free to choose the value of $C$. This will not change any of the results, but we can use this value to improve the numerical stability of the computation. A common choice for $C$ is to set $\log C=-\max_j f_j$. This simply states that we should shift the values inside the vector $f$ so that the highest value is zero. In code:"
import numpy as np

f = np.array([123, 456, 789]) # example with 3 classes and each having large scores
p = np.exp(f) / np.sum(np.exp(f)) # Bad: Numeric problem, potential blowup

# instead: first shift the values of f so that the highest number is 0:
f -= np.max(f) # f becomes [-666, -333, 0]
p = np.exp(f) / np.sum(np.exp(f)) # safe to do, gives the correct answer

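At first I was suspicious of this line of code. The log-weights are typically large negative numbers, so would subtracting the biggest one push every term toward negative infinity? It does the opposite: the maximum is itself negative, so subtracting it shifts every entry upward until the largest is exactly 0. Without the shift, torch.exp of very negative log-weights can underflow to 0 and the normalization becomes 0/0. With the shift, at least one weight equals $e^0 = 1$, and the normalized weights are mathematically unchanged.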
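A quick way to see what the shift does in the negative regime is to try it on very negative log-weights. The sketch below uses plain Python floats and made-up values in `log_w`; the function names are hypothetical and the code is not taken from the implementation.

```python
import math


def normalize_naive(log_w):
    """Exponentiate directly, then normalize."""
    w = [math.exp(v) for v in log_w]
    total = sum(w)
    return [wi / total for wi in w]


def normalize_shifted(log_w):
    """Subtract the maximum first, so the largest exponent is exactly 0."""
    shift = max(log_w)
    w = [math.exp(v - shift) for v in log_w]
    total = sum(w)
    return [wi / total for wi in w]


log_w = [-800.0, -802.0, -805.0]  # hypothetical, very negative log-weights

try:
    naive = normalize_naive(log_w)
except ZeroDivisionError:
    naive = None  # every exp() underflowed to 0.0, so the sum is 0.0

stable = normalize_shifted(log_w)
print(naive, stable)
```

After the shift the entries are 0, -2 and -5, so nothing underflows and the normalized weights are well defined.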

In the paper, $\tilde{w}_i =\frac{w_i}{\sum_{j=1}^k w_j}$ are called the normalized importance weights. Notice that in the update formula:
$\mathbb{E}_{\epsilon_{1}, \ldots, \epsilon_{k}}\left[\sum_{i=1}^{k} \widetilde{w}_{i} \nabla_{\boldsymbol{\theta}} \log w\left(\mathbf{x}, \mathbf{h}\left(\mathbf{x}, \boldsymbol{\epsilon}_{i}, \boldsymbol{\theta}\right), \boldsymbol{\theta}\right)\right]$

the normalized importance weights are treated as constants: we don't take the gradient with respect to them. In the PyTorch code, this detail appears as:

weight = torch.exp(log_weight)
weight = weight / torch.sum(weight, 0)
weight = Variable(weight.data, requires_grad = False)

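Taking weight.data and wrapping it with requires_grad = False makes PyTorch treat weight as a constant during backpropagation, which is exactly what we want. In current PyTorch versions, weight.detach() has the same effect.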


  1. Burda, Y., Grosse, R., & Salakhutdinov, R. (2016). Importance Weighted Autoencoders.
  2. PyTorch Implementation by Xinqiang Ding
  3. CS231n course notes
