A review of the foundational paper on Variational Autoencoders
Variational Auto-Encoders (VAEs) are powerful generative models that combine the ideas of variational inference with deep learning. Introduced by Kingma and Welling in 2013, VAEs provide a principled way to perform both inference and learning in deep latent variable models.
The key innovation of VAEs lies in their ability to learn complex probability distributions over high-dimensional data spaces by optimizing a variational lower bound using stochastic gradient descent. This allows them to generate new samples, perform dimensionality reduction, and learn useful representations of data.
Variational inference is a method from Bayesian statistics that approximates complex posterior distributions with simpler ones. Given a latent variable model \(p(x,z)\), we want to approximate the true posterior \(p(z|x)\) with a simpler distribution \(q_\phi(z|x)\).
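Concretely, "closeness" is measured by the Kullback-Leibler divergence from the approximation to the true posterior:
\[D_{KL}(q_\phi(z|x)||p(z|x)) = \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{q_\phi(z|x)}{p(z|x)}\right] \geq 0.\]
Minimizing this over \(\phi\) directly is intractable, because \(p(z|x) = p(x,z)/p(x)\) involves the unknown marginal \(p(x)\); the variational lower bound introduced below sidesteps this.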
The EM algorithm is a classical approach for learning latent variable models. It consists of two steps:
E-step: Compute the expected complete log-likelihood with respect to the posterior: \(Q(\theta|\theta^{\text{old}}) = \mathbb{E}_{p(z|x,\theta^{\text{old}})}[\log p(x,z|\theta)]\)
M-step: Maximize this expectation with respect to \(\theta\): \(\theta^{\text{new}} = \arg\max_\theta Q(\theta|\theta^{\text{old}})\)
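As a concrete illustration (not from the paper), here is a minimal EM sketch for a two-component 1-D Gaussian mixture, a case where both steps are available in closed form; the initialization and the synthetic data are arbitrary choices for the example:

```python
import numpy as np

def em_gmm_1d(x, n_iter=100):
    # Initialize theta = (pi, mu_1, mu_2, var_1, var_2)
    pi = 0.5
    mu = np.array([x.min(), x.max()])
    var = np.array([x.var(), x.var()])
    for _ in range(n_iter):
        # E-step: responsibilities r_i = p(z_i = 1 | x_i, theta_old) under current parameters
        p1 = pi * np.exp(-(x - mu[0]) ** 2 / (2 * var[0])) / np.sqrt(2 * np.pi * var[0])
        p2 = (1 - pi) * np.exp(-(x - mu[1]) ** 2 / (2 * var[1])) / np.sqrt(2 * np.pi * var[1])
        r = p1 / (p1 + p2)
        # M-step: closed-form maximization of the expected complete log-likelihood
        pi = r.mean()
        mu = np.array([np.sum(r * x) / r.sum(), np.sum((1 - r) * x) / (1 - r).sum()])
        var = np.array([np.sum(r * (x - mu[0]) ** 2) / r.sum(),
                        np.sum((1 - r) * (x - mu[1]) ** 2) / (1 - r).sum()])
    return pi, mu, var

rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(-2.0, 1.0, 500), rng.normal(3.0, 0.5, 500)])
print(em_gmm_1d(data))
```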
A key limitation of EM is that the E-step requires the exact posterior \(p(z|x,\theta)\), which is intractable for models with complex likelihoods such as neural networks; EM also requires a full pass through the dataset per iteration, which is costly for large datasets.
Monte Carlo EM (MCEM) attempts to address intractability by using sampling:
E-step: Draw samples from posterior using MCMC: \(z^{(l)} \sim p(z|x,\theta^{\text{old}})\)
\[Q(\theta|\theta^{\text{old}}) \approx \frac{1}{L}\sum_{l=1}^L \log p(x,z^{(l)}|\theta)\]
M-step: Same as standard EM
The main limitation of MCEM is that it requires an expensive sampling loop per datapoint and per EM iteration, which does not scale to large datasets.
Traditional auto-encoders learn to compress data into a lower-dimensional representation by minimizing reconstruction error. However, they lack a probabilistic interpretation and cannot generate new samples.
Consider a dataset \(X = \{x^{(i)}\}_{i=1}^N\) consisting of \(N\) i.i.d. samples. We assume the data are generated by a random process involving an unobserved continuous random variable \(z\):
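\[z^{(i)} \sim p_\theta(z), \qquad x^{(i)} \sim p_\theta(x|z^{(i)})\]
That is, a latent code \(z^{(i)}\) is first drawn from a prior, and the observation \(x^{(i)}\) is then drawn from the conditional likelihood; neither the latent values nor the true parameters are observed.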
Key assumptions: the prior \(p_\theta(z)\) and likelihood \(p_\theta(x|z)\) come from parametric families whose PDFs are differentiable almost everywhere with respect to both \(\theta\) and \(z\), and no simplifying assumptions are made about the marginal likelihood \(p_\theta(x)\) or the true posterior \(p_\theta(z|x)\).
Important challenges addressed: intractability (the marginal likelihood \(p_\theta(x) = \int p_\theta(z)\,p_\theta(x|z)\,dz\) has no closed form, so the true posterior is intractable and standard EM or mean-field variational Bayes cannot be applied) and scale (the dataset is too large for batch optimization or for per-datapoint sampling loops such as MCEM).
The paper aims to solve three related problems: efficient approximate maximum-likelihood or MAP estimation of the parameters \(\theta\), efficient approximate posterior inference of \(z\) given an observed \(x\), and efficient approximate marginal inference of \(x\).
For a single datapoint, the marginal likelihood can be rewritten as:
\[\log p_\theta(x^{(i)}) = D_{KL}(q_\phi(z|x^{(i)})||p_\theta(z|x^{(i)})) + \mathcal{L}(\theta, \phi; x^{(i)})\]
where the first term is the (non-negative) KL divergence of the approximate posterior from the true posterior, and \(\mathcal{L}(\theta, \phi; x^{(i)})\) is the variational lower bound (ELBO) on the marginal likelihood of datapoint \(i\).
The lower bound can be written in two instructive ways:
As an expectation of a complete log-likelihood: \(\mathcal{L}(\theta, \phi; x^{(i)}) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x,z) - \log q_\phi(z|x)]\)
As reconstruction error minus KL divergence: \(\mathcal{L}(\theta, \phi; x^{(i)}) = -D_{KL}(q_\phi(z|x^{(i)})||p_\theta(z)) + \mathbb{E}_{q_\phi(z|x^{(i)})}[\log p_\theta(x^{(i)}|z)]\)
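Both forms follow from a short manipulation of the marginal likelihood of a single datapoint:
\[\begin{aligned}
\log p_\theta(x^{(i)}) &= \mathbb{E}_{q_\phi(z|x^{(i)})}\big[\log p_\theta(x^{(i)})\big] \\
&= \mathbb{E}_{q_\phi(z|x^{(i)})}\left[\log \frac{p_\theta(x^{(i)}, z)}{p_\theta(z|x^{(i)})}\right] \\
&= \mathbb{E}_{q_\phi(z|x^{(i)})}\left[\log \frac{p_\theta(x^{(i)}, z)}{q_\phi(z|x^{(i)})}\right] + D_{KL}\big(q_\phi(z|x^{(i)})||p_\theta(z|x^{(i)})\big),
\end{aligned}\]
where the first term is \(\mathcal{L}(\theta,\phi;x^{(i)})\) in its first form; expanding \(\log p_\theta(x,z) = \log p_\theta(z) + \log p_\theta(x|z)\) and grouping the prior with \(-\log q_\phi\) into a KL term gives the second form.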
A key challenge is that the naive Monte Carlo (score-function) gradient estimator \(\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] \approx \frac{1}{L}\sum_{l=1}^L f(z^{(l)})\,\nabla_\phi\log q_\phi(z^{(l)})\) exhibits very high variance, making it impractical for this purpose. This motivates the reparameterization trick.
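As a minimal numerical illustration (not from the paper), take \(f(z) = z^2\) with \(z \sim \mathcal{N}(\mu, \sigma^2)\): both the score-function estimator and the reparameterized estimator are unbiased for \(\nabla_\mu \mathbb{E}[f(z)] = 2\mu\), but their single-sample variances differ:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, n = 1.0, 1.0, 100_000

# Samples z = mu + sigma * eps, eps ~ N(0, 1); exact gradient of E[z^2] w.r.t. mu is 2*mu
eps = rng.standard_normal(n)
z = mu + sigma * eps

# Score-function (REINFORCE) estimator: f(z) * d/dmu log N(z; mu, sigma^2)
score_grad = z**2 * (z - mu) / sigma**2

# Reparameterization estimator: d/dmu f(mu + sigma * eps) = 2 * z
reparam_grad = 2 * z

print("true gradient:", 2 * mu)
print(f"score-function:  mean {score_grad.mean():.3f}, variance {score_grad.var():.3f}")
print(f"reparameterized: mean {reparam_grad.mean():.3f}, variance {reparam_grad.var():.3f}")
```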
The VAE optimizes the evidence lower bound (ELBO):
\[\mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)] - D_{KL}(q_{\phi}(z|x)||p(z))\]
where:
\(p_{\theta}(x|z)\) is the decoder (generative model)
\(q_{\phi}(z|x)\) is the encoder (inference model)
A key innovation of VAEs is the reparameterization trick, which enables backpropagation through random sampling. Instead of directly sampling from \(q_{\phi}(z|x)\), we sample from a simple distribution \(\epsilon \sim \mathcal{N}(0,I)\) and transform it:
\[z = \mu_{\phi}(x) + \sigma_{\phi}(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0,I)\]
The Stochastic Gradient Variational Bayes (SGVB) estimator comes in two forms:
Generic estimator: \(\tilde{\mathcal{L}}^A(\theta, \phi; x^{(i)}) = \frac{1}{L}\sum_{l=1}^L \left[\log p_\theta(x^{(i)}, z^{(i,l)}) - \log q_\phi(z^{(i,l)}|x^{(i)})\right]\)
where \(z^{(i,l)} = g_\phi(\epsilon^{(l)}, x^{(i)})\) and \(\epsilon^{(l)} \sim p(\epsilon)\)
For cases where the KL divergence is analytically tractable: \(\tilde{\mathcal{L}}^B(\theta, \phi; x^{(i)}) = -D_{KL}(q_\phi(z|x^{(i)})||p_\theta(z)) + \frac{1}{L}\sum_{l=1}^L \log p_\theta(x^{(i)}|z^{(i,l)})\)
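These per-datapoint estimates are combined into a minibatch estimator of the full-dataset bound, which is what the AEVB algorithm below optimizes; \(M\) datapoints are drawn from the dataset of \(N\) points, and the paper notes that \(L\) can be set to 1 as long as \(M\) is large enough (e.g. \(M = 100\)):
\[\mathcal{L}(\theta, \phi; X) \approx \tilde{\mathcal{L}}^M(\theta, \phi; X^M) = \frac{N}{M}\sum_{i=1}^M \tilde{\mathcal{L}}(\theta, \phi; x^{(i)})\]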
The complete Auto-Encoding Variational Bayes (AEVB) algorithm:
def AEVB_algorithm(X):
    theta, phi = initialize_parameters()
    while not converged(theta, phi):
        # Draw a random minibatch of M datapoints (M = 100 in the paper)
        X_M = random_minibatch(X, M=100)
        # Draw random samples from the noise distribution p(eps)
        eps = sample_noise(M=100)
        # Gradients of the minibatch estimator grad_{theta,phi} L^M(theta, phi; X_M, eps)
        g = compute_gradients(theta, phi, X_M, eps)
        # Update parameters using the gradients, e.g. with SGD or Adagrad
        theta, phi = update_parameters(theta, phi, g)
    return theta, phi
For Gaussian encoder and prior, the KL divergence term has the analytical solution:
\[-D_{KL}(q_\phi(z|x)||p_\theta(z)) = \frac{1}{2}\sum_{j=1}^J\left(1 + \log(\sigma_j^2) - \mu_j^2 - \sigma_j^2\right)\]
Plugging this into the SGVB estimator gives an objective that can be optimized efficiently with stochastic gradient descent:
\[\mathcal{L} \approx \frac{1}{L}\sum_{l=1}^L \log p_{\theta}(x|z^{(l)}) - D_{KL}(q_{\phi}(z|x)||p(z))\]
where \(z^{(l)}\) are samples drawn using the reparameterization trick.
Both the encoder and the decoder are implemented as MLPs with a single tanh hidden layer; a sketch in NumPy with the weight matrices and biases made explicit as arguments:
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def encoder(x, W3, b3, W4, b4, W5, b5):
    # Recognition MLP: x -> h -> (mu, log sigma^2) of the Gaussian q_phi(z|x)
    h = np.tanh(W3 @ x + b3)
    mu = W4 @ h + b4
    log_sigma2 = W5 @ h + b5
    return mu, log_sigma2

def decoder(z, W1, b1, W2, b2):
    # Generative MLP: z -> h -> parameters of p_theta(x|z)
    h = np.tanh(W1 @ z + b1)
    x_recon = sigmoid(W2 @ h + b2)  # Bernoulli means, for binary data
    # For continuous data, output a Gaussian instead:
    #   mu_x = W_mu @ h + b_mu;  log_sigma2_x = W_sig @ h + b_sig
    return x_recon
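Putting the pieces together, here is a minimal sketch of one forward pass with the reparameterization trick; the layer sizes, random initialization, and dummy input are illustrative assumptions, not values from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)
D, H, J = 784, 500, 20                                # data, hidden, and latent dimensions (assumed)
W3, b3 = 0.01 * rng.standard_normal((H, D)), np.zeros(H)
W4, b4 = 0.01 * rng.standard_normal((J, H)), np.zeros(J)
W5, b5 = 0.01 * rng.standard_normal((J, H)), np.zeros(J)
W1, b1 = 0.01 * rng.standard_normal((H, J)), np.zeros(H)
W2, b2 = 0.01 * rng.standard_normal((D, H)), np.zeros(D)

x = rng.random(D)                                     # dummy "pixel" vector in [0, 1)
mu, log_sigma2 = encoder(x, W3, b3, W4, b4, W5, b5)   # parameters of q_phi(z|x)
eps = rng.standard_normal(J)                          # eps ~ N(0, I)
z = mu + np.exp(0.5 * log_sigma2) * eps               # reparameterized sample
x_recon = decoder(z, W1, b1, W2, b2)                  # Bernoulli means of p_theta(x|z)
```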
For continuous latent variables, we use a Gaussian approximate posterior with diagonal covariance:
\[\log q_\phi(z|x^{(i)}) = \log \mathcal{N}(z; \mu^{(i)}, \sigma^{2(i)}I)\]
where \(\mu^{(i)}\) and \(\sigma^{(i)}\) are outputs of the encoding MLP.
The prior over the latent variables is chosen to be a centered isotropic multivariate Gaussian:
\[p_\theta(z) = \mathcal{N}(z; 0, I)\]
The VAE thus consists of a probabilistic encoder \(q_\phi(z|x)\) (the recognition model), a probabilistic decoder \(p_\theta(x|z)\) (the generative model), and this fixed prior, with \(\theta\) and \(\phi\) learned jointly by maximizing the ELBO.
For experiments, the paper trains these models on the MNIST and Frey Face datasets, comparing AEVB against the wake-sleep algorithm in terms of the optimized lower bound, and against both wake-sleep and MCEM in terms of estimated marginal likelihood for low-dimensional latent spaces.
VAEs combine the strengths of variational inference and deep learning, enabling both probabilistic inference and generation. Key limitations include the restrictive diagonal-Gaussian approximate posterior, the gap between the ELBO and the true log-likelihood, and a tendency to produce blurry samples.
Recent work has addressed these limitations through more expressive posterior and prior families (e.g. normalizing flows), tighter bounds (e.g. importance weighting), and more powerful decoders.
For both prior \(p_\theta(z) = \mathcal{N}(0, I)\) and posterior approximation \(q_\phi(z|x)\) being Gaussian, the KL term can be computed analytically:
For a univariate Gaussian: \(\int q_\phi(z)\log p(z)\,dz = \int \mathcal{N}(z; \mu, \sigma^2)\log \mathcal{N}(z; 0, 1)\,dz = -\frac{1}{2}\log(2\pi) - \frac{1}{2}(\mu^2 + \sigma^2)\)
For the entropy term: \(\int q_\phi(z)\log q_\phi(z)\,dz = \int \mathcal{N}(z; \mu, \sigma^2)\log \mathcal{N}(z; \mu, \sigma^2)\,dz = -\frac{1}{2}\log(2\pi) - \frac{1}{2}(1 + \log \sigma^2)\)
Therefore: \(-D_{KL}(q_\phi(z|x)||p_\theta(z)) = \frac{1}{2}\sum_{j=1}^J(1 + \log((\sigma_j)^2) - (\mu_j)^2 - (\sigma_j)^2)\)
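As a quick sanity check (a sketch, not from the paper), the closed form can be compared against a Monte Carlo estimate of the same KL divergence; the posterior parameters below are arbitrary illustrative values:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = np.array([0.5, -1.0]), np.array([0.8, 1.5])   # illustrative 2-D posterior parameters

# Closed form: D_KL = -0.5 * sum(1 + log sigma_j^2 - mu_j^2 - sigma_j^2)
kl_closed = -0.5 * np.sum(1 + np.log(sigma**2) - mu**2 - sigma**2)

# Monte Carlo: E_q[log q(z) - log p(z)] with q = N(mu, diag(sigma^2)), p = N(0, I)
z = mu + sigma * rng.standard_normal((1_000_000, 2))
log_q = np.sum(-0.5 * np.log(2 * np.pi * sigma**2) - (z - mu)**2 / (2 * sigma**2), axis=1)
log_p = np.sum(-0.5 * np.log(2 * np.pi) - z**2 / 2, axis=1)
kl_mc = np.mean(log_q - log_p)

print(f"closed form: {kl_closed:.4f}   Monte Carlo: {kl_mc:.4f}")
```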
For low-dimensional latent spaces (fewer than about 5 dimensions), the marginal likelihood can be estimated with an importance-sampling estimator based on posterior samples:
\[p_\theta(x^{(i)}) \approx \left(\frac{1}{L}\sum_{l=1}^L \frac{q(z^{(l)})}{p_\theta(z^{(l)})\,p_\theta(x^{(i)}|z^{(l)})}\right)^{-1}\]
where \(z^{(l)} \sim p_\theta(z|x^{(i)})\) are samples from the posterior (obtained with MCMC) and \(q(z)\) is a density fitted to those samples.
Derivation:
\[\frac{1}{p_\theta(x^{(i)})} = \frac{\int q(z)\,dz}{p_\theta(x^{(i)})} = \int \frac{q(z)\,p_\theta(x^{(i)}, z)}{p_\theta(x^{(i)}, z)\,p_\theta(x^{(i)})}\,dz = \int p_\theta(z|x^{(i)})\,\frac{q(z)}{p_\theta(x^{(i)}, z)}\,dz \approx \frac{1}{L}\sum_{l=1}^L \frac{q(z^{(l)})}{p_\theta(z^{(l)})\,p_\theta(x^{(i)}|z^{(l)})}\]
For comparison with the VAE, the MCEM baseline draws posterior samples with Hybrid Monte Carlo in its E-step, and the marginal likelihoods of all methods are computed with the sampling-based estimator described above.
Key equations in VAEs include:
ELBO: \(\log p(x) \geq \mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)] - D_{KL}(q_{\phi}(z|x)||p(z))\)
Reparameterization: \(z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0,I)\)
KL divergence for the Gaussian case: \(D_{KL}(q_{\phi}(z|x)||p(z)) = -\frac{1}{2}\sum_{j=1}^J(1 + \log(\sigma_j^2) - \mu_j^2 - \sigma_j^2)\)
Here’s a simple PyTorch sketch of the VAE’s forward pass, written as a method of a module that is assumed to define encode and decode networks:
import torch
import torch.nn.functional as F

def forward(self, x):
    # Encode: parameters (mean, log-variance) of the Gaussian q_phi(z|x)
    mu, logvar = self.encode(x)
    # Reparameterize: z = mu + sigma * eps, with eps ~ N(0, I)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std
    # Decode: Bernoulli parameters of p_theta(x|z)
    recon_x = self.decode(z)
    # Negative ELBO = reconstruction loss + KL(q_phi(z|x) || N(0, I))
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_x, mu, logvar, recon_loss + kl_loss
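Here is a sketch of how this forward pass might sit inside a training loop; the `VAE` class, `train_loader`, and the Adam optimizer are illustrative assumptions (the paper itself used SGD/Adagrad):

```python
import torch

model = VAE()                                    # hypothetical nn.Module exposing encode/decode/forward as above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):                          # arbitrary number of epochs
    for x, _ in train_loader:                    # assumed DataLoader yielding (image, label) batches
        x = x.view(x.size(0), -1)                # flatten images to vectors
        optimizer.zero_grad()
        recon_x, mu, logvar, loss = model(x)     # loss = negative ELBO for the minibatch
        loss.backward()
        optimizer.step()
```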
Visualizations of the latent space and reconstructions can be created using Plotly, for example latent-space interpolations or t-SNE projections of the encoded representations.
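For instance, with a 2-D latent space, a minimal Plotly sketch of the encoded test set might look as follows, assuming `mu` is an N×2 array of posterior means from the encoder and `labels` holds the corresponding class labels:

```python
import plotly.express as px

# mu: (N, 2) array of posterior means; labels: array of class labels (both assumed to exist)
fig = px.scatter(
    x=mu[:, 0], y=mu[:, 1], color=labels.astype(str),
    labels={"x": "z1", "y": "z2", "color": "class"},
    title="VAE latent space (posterior means)",
)
fig.show()
```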
See the original paper: Kingma, D. P., & Welling, M. (2013). Auto-Encoding Variational Bayes. arXiv:1312.6114.