Paper Review: VAE - Auto-Encoding Variational Bayes

A review of the foundational paper on Variational Autoencoders

Introduction

Variational Auto-Encoders (VAEs) are powerful generative models that combine the ideas of variational inference with deep learning. Introduced by Kingma and Welling in 2013, VAEs provide a principled way to perform both inference and learning in deep latent variable models.

The key innovation of VAEs lies in their ability to learn complex probability distributions over high-dimensional data spaces by optimizing a variational lower bound using stochastic gradient descent. This allows them to generate new samples, perform dimensionality reduction, and learn useful representations of data.

Background

Variational Inference

Variational inference is a method from Bayesian statistics that approximates complex posterior distributions with simpler ones. Given a latent variable model \(p(x,z)\), we want to approximate the true posterior

\(p_\theta(z|x)\) with a simpler, parameterized distribution \(q_{\phi}(z|x)\).
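Concretely, the quality of the approximation is measured by the KL divergence from the approximate to the true posterior, which variational inference seeks to minimize over the variational parameters \(\phi\):

\[D_{KL}(q_\phi(z|x)\,\|\,p_\theta(z|x)) = \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{q_\phi(z|x)}{p_\theta(z|x)}\right]\]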

Expectation-Maximization Algorithm

The EM algorithm is a classical approach for learning latent variable models. It consists of two steps:

  1. E-step: Compute the expected complete log-likelihood with respect to the posterior: \(Q(\theta|\theta^{\text{old}}) = \mathbb{E}_{p(z|x,\theta^{\text{old}})}[\log p(x,z|\theta)]\)

  2. M-step: Maximize this expectation with respect to \(\theta\): \(\theta^{\text{new}} = \arg\max_\theta Q(\theta|\theta^{\text{old}})\)

Key limitations of EM:

  • The E-step requires the exact posterior \(p(z|x,\theta)\), which is intractable for models with complex (e.g. neural network) likelihoods
  • It is a batch algorithm: every iteration requires a pass over the entire dataset, which is too costly for large datasets

Monte Carlo EM

Monte Carlo EM (MCEM) attempts to address intractability by using sampling:

  1. E-step: Draw samples from posterior using MCMC: \(z^{(l)} \sim p(z|x,\theta^{\text{old}})\)

    \[Q(\theta|\theta^{\text{old}}) \approx \frac{1}{L}\sum_{l=1}^L \log p(x,z^{(l)}|\theta)\]
  2. M-step: Same as standard EM

Limitations of MCEM:

  • MCMC sampling must be run for every datapoint in each E-step, which is computationally expensive
  • The sampler can be slow to mix, and the overall procedure scales poorly to large datasets

Auto-Encoders

Traditional auto-encoders learn to compress data into a lower-dimensional representation by minimizing reconstruction error. However, they lack a probabilistic interpretation and cannot generate new samples.

Problem Setting and Method

Problem Scenario

Consider a dataset \(X = \{x^{(i)}\}_{i=1}^N\) consisting of \(N\) i.i.d. samples. We assume the data are generated by a random process involving an unobserved continuous random variable \(z\):

  1. A value \(z^{(i)}\) is generated from a prior distribution \(p_{\theta^*}(z)\)
  2. A value \(x^{(i)}\) is generated from a conditional distribution \(p_{\theta^*}(x|z)\)

Key assumptions:

  • The prior \(p_{\theta^*}(z)\) and likelihood \(p_{\theta^*}(x|z)\) come from parametric families whose densities are differentiable almost everywhere w.r.t. both \(\theta\) and \(z\)
  • The true parameters \(\theta^*\) and the latent values \(z^{(i)}\) are unknown
  • No simplifying assumptions are made about the marginal likelihood or the posterior; the method is designed for the general intractable case

Important challenges addressed:

  1. Intractability: The marginal likelihood \(p_\theta(x) = \int p_\theta(z)p_\theta(x|z)dz\) is intractable
  2. Large datasets: Batch optimization is too costly; need efficient minibatch methods

The paper aims to solve:

  1. Efficient approximate ML/MAP estimation for \(\theta\)
  2. Efficient approximate posterior inference of \(z\) given \(x\)
  3. Efficient approximate marginal inference of \(x\)

The Variational Bound

For a single datapoint, the marginal likelihood can be rewritten as:

\[\log p_\theta(x^{(i)}) = D_{KL}(q_\phi(z|x^{(i)})||p_\theta(z|x^{(i)})) + \mathcal{L}(\theta, \phi; x^{(i)})\]

where:

  • the first term is the KL divergence between the approximate posterior \(q_\phi(z|x^{(i)})\) and the true posterior \(p_\theta(z|x^{(i)})\), which is non-negative
  • \(\mathcal{L}(\theta, \phi; x^{(i)})\) is the variational lower bound (ELBO) on the marginal log-likelihood of datapoint \(i\), so \(\log p_\theta(x^{(i)}) \geq \mathcal{L}(\theta, \phi; x^{(i)})\)

The lower bound can be written in two instructive ways:

  1. As an expectation of a complete log-likelihood: \(\mathcal{L}(\theta, \phi; x^{(i)}) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x,z) - \log q_\phi(z|x)]\)

  2. As reconstruction error minus KL divergence: \(\mathcal{L}(\theta, \phi; x^{(i)}) = -D_{KL}(q_\phi(z|x^{(i)})||p_\theta(z)) + \mathbb{E}_{q_\phi(z|x^{(i)})}[\log p_\theta(x^{(i)}|z)]\)

A key challenge is that the naive Monte Carlo (score-function) gradient estimator \(\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] \approx \frac{1}{L}\sum_{l=1}^L f(z^{(l)})\nabla_{\phi}\log q_\phi(z^{(l)})\) exhibits very high variance. This motivates the reparameterization trick.
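A minimal sketch (not from the paper) of why this matters: for the toy objective \(\mathbb{E}_{\mathcal{N}(z;\mu,\sigma^2)}[z^2]\), whose exact gradient w.r.t. \(\mu\) is \(2\mu\), the score-function estimator has far higher variance than the reparameterized one. The function and parameter values below are arbitrary illustrative choices.

import numpy as np

rng = np.random.default_rng(0)
mu, sigma, L = 1.0, 0.5, 100_000

f = lambda z: z ** 2                     # toy integrand; true gradient d/dmu E[f] = 2*mu

# Score-function (naive) estimator: f(z) * d/dmu log q(z)
z = rng.normal(mu, sigma, size=L)
score_grads = f(z) * (z - mu) / sigma ** 2

# Reparameterized estimator: d/dmu f(mu + sigma * eps)
eps = rng.normal(0.0, 1.0, size=L)
reparam_grads = 2 * (mu + sigma * eps)

print("true gradient:", 2 * mu)
print("score-function:  mean %.3f, variance %.3f" % (score_grads.mean(), score_grads.var()))
print("reparameterized: mean %.3f, variance %.3f" % (reparam_grads.mean(), reparam_grads.var()))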

The VAE optimizes the evidence lower bound (ELBO):

\[\mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)] - D_{KL}(q_{\phi}(z|x)||p(z))\]

where:

  • \(q_{\phi}(z|x)\) is the encoder (recognition model) with variational parameters \(\phi\)
  • \(p_{\theta}(x|z)\) is the decoder (generative model) with parameters \(\theta\), and \(p(z)\) is the prior over latent variables
  • the first term rewards accurate reconstruction of \(x\), while the KL term regularizes the approximate posterior towards the prior

The Reparameterization Trick

A key innovation of VAEs is the reparameterization trick, which enables backpropagation through random sampling. Instead of directly sampling from \(q_{\phi}(z|x)\), we sample from a simple distribution \(\epsilon \sim \mathcal{N}(0,I)\) and transform it:

\[z = \mu_{\phi}(x) + \sigma_{\phi}(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0,I)\]
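In code, the trick is a one-line transformation. A minimal PyTorch sketch, assuming the encoder outputs the mean and log-variance of \(q_\phi(z|x)\):

import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); gradients flow through mu and logvar
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + std * eps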

SGVB Estimator and AEVB Algorithm

The Stochastic Gradient Variational Bayes (SGVB) estimator comes in two forms:

  1. Generic estimator: \(\tilde{\mathcal{L}}^A(\theta, \phi; x^{(i)}) = \frac{1}{L}\sum_{l=1}^L \left[\log p_\theta(x^{(i)}, z^{(i,l)}) - \log q_\phi(z^{(i,l)}|x^{(i)})\right]\)

    where \(z^{(i,l)} = g_\phi(\epsilon^{(l)}, x^{(i)})\) and \(\epsilon^{(l)} \sim p(\epsilon)\)

  2. For cases where the KL divergence is analytically tractable: \(\tilde{\mathcal{L}}^B(\theta, \phi; x^{(i)}) = -D_{KL}(q_\phi(z|x^{(i)})||p_\theta(z)) + \frac{1}{L}\sum_{l=1}^L \log p_\theta(x^{(i)}|z^{(i,l)})\)

The complete Auto-Encoding Variational Bayes (AEVB) algorithm:

def AEVB_algorithm(X):
    θ, φ = initialize_parameters()
    
    while not converged:
        # Get minibatch
        XM = random_minibatch(X, M=100)
        
        # Get random samples from noise distribution
        ε = sample_noise(p(ε))
        
        # Compute gradients of minibatch estimator
        g = compute_gradients(𝓛M(θ, φ; XM, ε), wrt=(θ, φ))  # ∇_{θ,φ} of the minibatch estimator
        
        # Update parameters using e.g. SGD or Adagrad
        θ, φ = update_parameters(θ, φ, g)
    
    return θ, φ

For Gaussian encoder and prior, the KL divergence term has the analytical solution:

\[-D_{KL} = \frac{1}{2}\sum_{j=1}^J(1 + \log((\sigma_j)^2) - (\mu_j)^2 - (\sigma_j)^2)\]

The Stochastic Gradient Variational Bayes (SGVB) estimator allows efficient optimization of the ELBO using stochastic gradient descent:

\[\mathcal{L} \approx \frac{1}{L}\sum_{l=1}^L \log p_{\theta}(x|z^{(l)}) - D_{KL}(q_{\phi}(z|x)||p(z))\]

where \(z^{(l)}\) are samples drawn using the reparameterization trick.
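A hedged PyTorch sketch of this single-datapoint estimator (the second SGVB form), assuming a Bernoulli decoder; decode, mu, and logvar are placeholders supplied by the model:

import torch
import torch.nn.functional as F

def sgvb_estimate(x, mu, logvar, decode, L=1):
    # Monte Carlo estimate of E_q[log p(x|z)] using L reparameterized samples
    recon = 0.0
    for _ in range(L):
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        recon = recon - F.binary_cross_entropy(decode(z), x, reduction='sum')
    recon = recon / L
    # Analytic -KL(q(z|x) || p(z)) for the Gaussian case
    neg_kl = 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + neg_kl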

Architecture and Implementation Details

Neural Network Architecture

In the paper's experiments, both the encoder and decoder are implemented as MLPs; the \(W_i\) and \(b_i\) below denote weight matrices and bias vectors, following the paper's notation:

  1. Encoder (Recognition Model) \(q_\phi(z|x)\):
    def encoder(x):
        h = tanh(W3 @ x + b3)
        μ = W4 @ h + b4
        log_σ2 = W5 @ h + b5
        return μ, log_σ2
    
  2. Decoder (Generative Model) \(p_\theta(x|z)\):
    def decoder(z):
        h = tanh(W1 @ z + b1)
        x_recon = sigmoid(W2 @ h + b2)  # for binary data
        # or for continuous data:
        # μ = W2 @ h + b2
        # log_σ2 = W3 @ h + b3
        return x_recon
    

Approximate Posterior

For continuous latent variables, we use a Gaussian approximate posterior with diagonal covariance:

\[\log q_\phi(z|x^{(i)}) = \log \mathcal{N}(z; \mu^{(i)}, \sigma^{2(i)}I)\]

where \(\mu^{(i)}\) and \(\sigma^{(i)}\) are outputs of the encoding MLP.
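For illustration, this diagonal-Gaussian posterior can be built directly with torch.distributions (placeholder encoder outputs; Independent sums the per-dimension log-densities so log_prob returns a scalar):

import torch
from torch.distributions import Normal, Independent

mu, logvar = torch.zeros(20), torch.zeros(20)        # placeholder encoder outputs
q = Independent(Normal(mu, torch.exp(0.5 * logvar)), 1)
z = q.rsample()                                      # reparameterized sample
log_q_z = q.log_prob(z)                              # scalar log q(z|x)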

Prior

The prior over the latent variables is chosen to be a centered isotropic multivariate Gaussian:

\[p_\theta(z) = \mathcal{N}(z; 0, I)\]

The VAE consists of:

  1. An encoder network that outputs the parameters of \(q_{\phi}(z|x)\) (typically the mean and log-variance of a Gaussian)
  2. A sampling layer implementing the reparameterization trick
  3. A decoder network that reconstructs the input from the latent representation
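Putting these pieces together, here is a minimal PyTorch module sketch with the three components above (layer sizes are illustrative choices, not the paper's exact configuration); its encode and decode methods are what the forward pass shown later assumes:

import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, x_dim=784, h_dim=500, z_dim=20):
        super().__init__()
        # Encoder (recognition model) q_phi(z|x)
        self.enc_h = nn.Linear(x_dim, h_dim)
        self.enc_mu = nn.Linear(h_dim, z_dim)
        self.enc_logvar = nn.Linear(h_dim, z_dim)
        # Decoder (generative model) p_theta(x|z)
        self.dec_h = nn.Linear(z_dim, h_dim)
        self.dec_out = nn.Linear(h_dim, x_dim)

    def encode(self, x):
        h = torch.tanh(self.enc_h(x))
        return self.enc_mu(h), self.enc_logvar(h)

    def decode(self, z):
        h = torch.tanh(self.dec_h(z))
        return torch.sigmoid(self.dec_out(h))  # Bernoulli decoder for binary data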

Comparison and Experimental Results

Performance Comparison

  1. VAE vs Wake-Sleep:
    • Faster convergence for VAE
    • Better final solutions across all latent dimensions
    • More stable training dynamics
  2. VAE vs MCEM:
    • VAE scales better to large datasets
    • MCEM requires expensive sampling per datapoint
    • VAE achieves better marginal likelihood estimates
  3. Effects of Latent Space Dimensionality:
    • Higher dimensions don’t lead to overfitting
    • Regularizing effect of KL term
    • Smooth latent space interpolations

Implementation Details

For experiments:

  • Datasets: MNIST (binary images) and the Frey Face dataset (continuous-valued images)
  • MLP encoder and decoder with a single hidden layer (500 units for MNIST)
  • Minibatches of size \(M = 100\), with \(L = 1\) sample per datapoint
  • Comparison against the wake-sleep algorithm using encoders/decoders of the same architecture

Discussion

VAEs combine the strengths of variational inference and deep learning, enabling both probabilistic inference and generation. Key limitations include:

  • The factorized Gaussian approximate posterior limits the flexibility of inference
  • The ELBO is only a lower bound; the gap to the true log-likelihood depends on how well \(q_\phi\) matches the true posterior
  • Generated samples and reconstructions of images tend to be blurry

Recent work has addressed these limitations through:

  • More expressive posteriors (e.g. normalizing flows)
  • Tighter objectives such as importance-weighted bounds
  • Hierarchical and deeper latent variable models with more powerful decoders

Future Directions

  1. Hierarchical Models:
    • Deep generative architectures
    • Complex posterior approximations
  2. Time-Series Models:
    • Dynamic Bayesian networks
    • Sequential latent variables
  3. Supervised Learning:
    • Latent variable classifiers
    • Semi-supervised learning
  4. Advanced Inference:
    • Flow-based posteriors
    • Importance weighted bounds

Technical Details and Derivations

Detailed KL Divergence Computation for Gaussian Case

For both prior \(p_\theta(z) = \mathcal{N}(0, I)\) and posterior approximation \(q_\phi(z|x)\) being Gaussian, the KL term can be computed analytically:

  1. For a univariate Gaussian: \(\int q_\phi(z)\log p(z)dz = \int \mathcal{N}(z; \mu, \sigma^2)\log \mathcal{N}(z; 0, 1)dz = -\frac{1}{2}\log(2\pi) - \frac{1}{2}(\mu^2 + \sigma^2)\)

  2. For \(\int q_\phi(z)\log q_\phi(z)dz\) (the negative entropy): \(\int \mathcal{N}(z; \mu, \sigma^2)\log \mathcal{N}(z; \mu, \sigma^2)dz = -\frac{1}{2}\log(2\pi) - \frac{1}{2}(1 + \log \sigma^2)\)

  3. Therefore: \(-D_{KL}(q_\phi(z|x)||p_\theta(z)) = \frac{1}{2}\sum_{j=1}^J(1 + \log((\sigma_j)^2) - (\mu_j)^2 - (\sigma_j)^2)\)
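A quick numerical sanity check (illustrative values, not from the paper): the closed-form expression should agree with a Monte Carlo estimate of \(\mathbb{E}_q[\log p(z) - \log q(z)]\):

import torch

mu = torch.tensor([0.3, -1.2])
logvar = torch.tensor([0.1, -0.5])
std = torch.exp(0.5 * logvar)

analytic = 0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())

z = mu + std * torch.randn(100_000, 2)               # reparameterized samples from q
log_p = torch.distributions.Normal(0.0, 1.0).log_prob(z).sum(-1)
log_q = torch.distributions.Normal(mu, std).log_prob(z).sum(-1)
mc = (log_p - log_q).mean()

print(analytic.item(), mc.item())                    # should agree closely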

Marginal Likelihood Estimator

For low-dimensional latent spaces (\(< 5\) dimensions), we can estimate the marginal likelihood using:

  1. Sample \(L\) values \(\{z^{(l)}\}\) from posterior using HMC
  2. Fit a density estimator \(q(z)\) to these samples
  3. Compute the estimator:
\[p_\theta(x^{(i)}) \approx \left(\frac{1}{L}\sum_{l=1}^L \frac{q(z^{(l)})}{p_\theta(z^{(l)})p_\theta(x^{(i)}|z^{(l)})}\right)^{-1}\]

where \(z^{(l)} \sim p_\theta(z|x^{(i)})\)

Derivation:

$$
\begin{aligned}
\frac{1}{p_\theta(x^{(i)})} &= \frac{\int q(z)\,dz}{p_\theta(x^{(i)})} = \int q(z)\,\frac{p_\theta(x^{(i)}, z)}{p_\theta(x^{(i)}, z)\,p_\theta(x^{(i)})}\,dz \\
&= \int p_\theta(z|x^{(i)})\,\frac{q(z)}{p_\theta(x^{(i)}, z)}\,dz \\
&\approx \frac{1}{L}\sum_{l=1}^L \frac{q(z^{(l)})}{p_\theta(z^{(l)})\,p_\theta(x^{(i)}|z^{(l)})}
\end{aligned}
$$
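A hedged PyTorch sketch of this estimator, computed in log-space for numerical stability; log_q, log_prior, and log_lik are placeholder callables (e.g. a density fitted to HMC samples and the model's own densities):

import torch

def log_marginal_estimate(z, x, log_q, log_prior, log_lik):
    # log of ( (1/L) * sum_l q(z_l) / (p(z_l) p(x|z_l)) ), then negated
    log_ratios = log_q(z) - log_prior(z) - log_lik(x, z)        # shape (L,)
    L = torch.tensor(float(z.shape[0]))
    log_mean_ratio = torch.logsumexp(log_ratios, dim=0) - torch.log(L)
    return -log_mean_ratio                                      # estimate of log p(x)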

Monte Carlo EM Details

For comparison with VAE, the MCEM procedure uses:

  • An E-step that draws posterior samples per datapoint with Hybrid Monte Carlo (HMC)
  • An M-step that updates \(\theta\) by gradient ascent on the Monte Carlo estimate of the expected complete log-likelihood

The marginal likelihood estimation uses:

  • The sampling-based estimator described above, with posterior samples drawn via HMC and a density estimator \(q(z)\) fitted to them

Equations

Key equations in VAEs include:

  1. ELBO: \(\log p(x) \geq \mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)] - D_{KL}(q_{\phi}(z|x)||p(z))\)

  2. Reparameterization: \(z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0,I)\)

  3. KL divergence for the Gaussian case: \(D_{KL}(q_\phi(z|x)\|p(z)) = -\frac{1}{2}\sum_{j=1}^J(1 + \log(\sigma_j^2) - \mu_j^2 - \sigma_j^2)\)

Code Blocks

Here’s a simple PyTorch implementation of the VAE’s forward pass:

# Assumes `import torch` and `import torch.nn.functional as F`, plus `encode`
# and `decode` methods on the module (e.g. the VAE sketch shown earlier).
def forward(self, x):
    # Encode
    mu, logvar = self.encode(x)
    
    # Reparameterize
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std
    
    # Decode
    recon_x = self.decode(z)
    
    # Loss
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return recon_x, mu, logvar, recon_loss + kl_loss
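For completeness, an illustrative training loop around this forward pass, assuming it is attached to the VAE module sketched earlier and that train_loader yields binarized MNIST batches; the Adam optimizer is a modern convenience here, whereas the paper itself used SGD/Adagrad:

model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for x, _ in train_loader:                  # e.g. binarized MNIST batches
        x = x.view(x.size(0), -1)              # flatten images to vectors
        recon_x, mu, logvar, loss = model(x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()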

Interactive Plots

Visualizations of the latent space and reconstructions can be created using Plotly. For example, latent space interpolations or t-SNE visualizations of the encoded representations.

References

See the original papers:

  • Kingma, D. P., and Welling, M. (2014). Auto-Encoding Variational Bayes. ICLR 2014. arXiv:1312.6114.
  • Hinton, G. E., Dayan, P., Frey, B. J., and Neal, R. M. (1995). The wake-sleep algorithm for unsupervised neural networks. Science, 268(5214), 1158–1161.