Variational Autoencoders (VAEs) learn compressed representations of data while enabling generation of new samples. Unlike traditional autoencoders that learn deterministic mappings, VAEs model the latent space as probability distributions, typically Gaussian with learned means and variances. This probabilistic framework provides theoretical guarantees and smooth interpolation between data points.
The encoder network q_φ(z|x) maps input x to distribution parameters μ and σ². For image inputs like 28×28 MNIST digits, convolutional layers progressively downsample: Conv2d(1→32, 3×3) → Conv2d(32→64, 3×3) → Conv2d(64→128, 3×3), with 2×2 max pooling between layers. The flattened features (128×4×4 = 2048) pass through dense layers: Linear(2048→512) → ReLU → Linear(512→2×latent_dim). The output splits into μ and log(σ²) vectors, each of dimension latent_dim (typically 20-256). Using log-variance instead of variance directly ensures numerical stability and allows unconstrained optimization.
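A minimal PyTorch sketch of such an encoder follows. Stride-2 convolutions stand in for the conv-plus-pooling pairs, and the padding values are assumptions chosen so the final feature map is 128×4×4 = 2048, matching the dense layers above.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Maps an image x to the parameters (mu, log_var) of q(z|x)."""
    def __init__(self, latent_dim=20):
        super().__init__()
        # Stride-2 convolutions replace conv + 2x2 pooling pairs;
        # padding=1 is an assumption so that 28x28 -> 14x14 -> 7x7 -> 4x4.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 512), nn.ReLU(),
            nn.Linear(512, 2 * latent_dim),  # concatenated mu and log-variance
        )

    def forward(self, x):
        h = self.conv(x).flatten(start_dim=1)      # (batch, 2048)
        mu, log_var = self.fc(h).chunk(2, dim=-1)  # split into mu and log(sigma^2)
        return mu, log_var
```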
The decoder network p_θ(x|z) reconstructs data from sampled latent codes. The architecture mirrors the encoder in reverse: Linear(latent_dim→512) → ReLU → Linear(512→2048) → Reshape(128,4,4) → ConvTranspose2d(128→64) → ConvTranspose2d(64→32) → ConvTranspose2d(32→1). Stride-2 transposed convolutions upsample the spatial dimensions. The final layer uses a sigmoid activation for normalized pixel values, or no activation for unbounded outputs.
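A matching decoder sketch; the padding and output_padding values are assumptions chosen so the spatial dimensions return to 28×28:

```python
import torch.nn as nn

class Decoder(nn.Module):
    """Reconstructs x from a latent code z, i.e. the mean of p(x|z)."""
    def __init__(self, latent_dim=20):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 512), nn.ReLU(),
            nn.Linear(512, 128 * 4 * 4), nn.ReLU(),
        )
        # Stride-2 transposed convolutions: 4x4 -> 7x7 -> 14x14 -> 28x28.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # pixel values normalized to [0, 1]
        )

    def forward(self, z):
        h = self.fc(z).view(-1, 128, 4, 4)
        return self.deconv(h)
```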
The Reparameterization Trick
Sampling z ~ N(μ, σ²) is non-differentiable, blocking gradient flow. The reparameterization trick reformulates sampling as z = μ + σ ⊙ ε, where ε ~ N(0, I) and ⊙ denotes element-wise multiplication. This moves the stochasticity to the input, making the operation differentiable with respect to μ and σ. During the forward pass: ε = torch.randn_like(μ), then z = μ + torch.exp(0.5 × log_var) × ε. The 0.5 factor converts log-variance to log-standard-deviation.
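This translates directly into a small PyTorch helper:

```python
import torch

def reparameterize(mu, log_var):
    """z = mu + sigma * eps with eps ~ N(0, I), differentiable w.r.t. mu and log_var."""
    eps = torch.randn_like(mu)         # stochasticity enters as an input
    sigma = torch.exp(0.5 * log_var)   # exp(0.5 * log sigma^2) = sigma
    return mu + sigma * eps
```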
Gradients flow through the deterministic transformation while ε provides necessary stochasticity. Without this trick, we'd need high-variance score function estimators like REINFORCE, requiring 100-1000x more samples for convergence.
Evidence Lower Bound (ELBO)
VAEs maximize the log-likelihood log p(x) indirectly through the ELBO: log p(x) ≥ E_q[log p(x|z)] - KL[q(z|x) || p(z)]. The first term is the (negative) reconstruction loss; the second is the KL divergence between the approximate posterior and the prior. For Gaussian distributions the KL divergence has a closed form: KL = -0.5 × Σ(1 + log(σ²) - μ² - σ²).
The full loss function becomes: L = -E_q[log p(x|z)] + β × KL[q(z|x) || p(z)]. For binary data, reconstruction uses binary cross-entropy: -Σ(x_i log(x̂_i) + (1-x_i) log(1-x̂_i)). For continuous data, use MSE, ||x - x̂||², or assume a Gaussian likelihood with fixed variance. The β weight controls regularization strength; β=1 recovers the standard VAE, while β>1 encourages disentangled representations (β-VAE).
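Putting the two terms together, a minimal negative-ELBO loss for binary data could look like the following; averaging over the batch is a common convention rather than something mandated above:

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_hat, mu, log_var, beta=1.0):
    """Negative ELBO: reconstruction term plus beta-weighted KL to the N(0, I) prior."""
    batch = x.size(0)
    # Binary cross-entropy summed over pixels, averaged over the batch.
    recon = F.binary_cross_entropy(x_hat, x, reduction='sum') / batch
    # Closed-form Gaussian KL: -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2).
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch
    return recon + beta * kl
```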
Training Dynamics and Hyperparameters
Standard training uses the Adam optimizer with a learning rate of 1e-3 to 1e-4. Batch sizes typically range from 32 to 256. KL annealing gradually increases β from 0 to 1 over 10-100 epochs, preventing posterior collapse, where q(z|x) ≈ p(z) and the model ignores the latent codes. Linear warmup: β = min(1, epoch / warmup_epochs). Cyclical annealing alternates between β=0 and β=1 to escape local minima.
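Both schedules are easy to express as functions of the epoch index. The cyclical version below is one common variant that ramps β up over the first half of each cycle; the exact shape is an assumption:

```python
def linear_beta(epoch, warmup_epochs=20):
    """Linear KL warmup: beta ramps from 0 to 1 over warmup_epochs."""
    return min(1.0, epoch / warmup_epochs)

def cyclical_beta(epoch, cycle_len=20, ramp_frac=0.5):
    """Cyclical annealing: beta ramps 0 -> 1 over the first ramp_frac of each
    cycle, then stays at 1 for the remainder of the cycle."""
    pos = (epoch % cycle_len) / cycle_len
    return min(1.0, pos / ramp_frac)
```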
The free bits technique prevents individual KL dimensions from shrinking below a threshold λ (typically 0.125): KL_i = max(λ, KL_i). This maintains a minimum information capacity per latent dimension. Without such techniques, many latent dimensions collapse to the prior, wasting model capacity.
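A sketch of the free-bits clamp applied to the closed-form Gaussian KL:

```python
import torch

def free_bits_kl(mu, log_var, lam=0.125):
    """Sum of per-dimension KLs, each clamped from below at lam nats ('free bits')."""
    # KL per latent dimension, averaged over the batch.
    kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean(dim=0)
    # Clamp each dimension's KL at the threshold, then sum over dimensions.
    return torch.clamp(kl_per_dim, min=lam).sum()
```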
Architecture Variants and Scaling
Hierarchical VAEs like Ladder VAEs use multiple latent layers: z_L → z_{L-1} → ... → z_1 → x. Each level captures a different scale of abstraction. Bottom-up inference computes q(z_i|z_{i-1}, x) using both the previous latents and the input features. Top-down generation samples p(z_i|z_{i+1}) and then p(x|z_1). Skip connections between encoder and decoder improve gradient flow.
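A minimal sketch of the top-down generative chain for a two-level hierarchy; the dimensions and MLP sizes are illustrative assumptions, and it reuses the reparameterize helper from above:

```python
import torch
import torch.nn as nn

class TwoLevelGenerator(nn.Module):
    """Top-down generation for a two-level hierarchy: z2 -> z1 -> x."""
    def __init__(self, d2=16, d1=32):
        super().__init__()
        self.d2 = d2
        # p(z1 | z2) parameterized by an MLP that outputs (mu, log_var) of z1.
        self.p_z1 = nn.Sequential(nn.Linear(d2, 128), nn.ReLU(),
                                  nn.Linear(128, 2 * d1))

    def sample_latents(self, n):
        z2 = torch.randn(n, self.d2)                # top level: z2 ~ N(0, I)
        mu1, log_var1 = self.p_z1(z2).chunk(2, -1)  # parameters of p(z1 | z2)
        z1 = reparameterize(mu1, log_var1)          # sample the lower level
        return z1                                   # feed z1 to a decoder p(x | z1)
```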
VQ-VAE replaces continuous latents with a discrete codebook. The encoder output is quantized to the nearest codebook vector: z_q = argmin_k ||z_e - e_k||². A straight-through estimator copies gradients around the non-differentiable quantization. The codebook is updated via exponential moving average or gradient descent. Discrete latents enable autoregressive priors such as PixelCNN or transformers.
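A sketch of nearest-neighbor quantization with the straight-through estimator (codebook updates omitted):

```python
import torch

def vector_quantize(z_e, codebook):
    """Nearest-codebook quantization with a straight-through gradient.
    z_e: (batch, dim) encoder outputs; codebook: (K, dim) embedding vectors."""
    dists = torch.cdist(z_e, codebook) ** 2   # squared distance to every code
    indices = dists.argmin(dim=1)             # index of the nearest codebook vector
    z_q = codebook[indices]
    # Straight-through estimator: the forward pass uses z_q, the backward pass
    # copies gradients from z_q to z_e as if quantization were the identity.
    z_q = z_e + (z_q - z_e).detach()
    return z_q, indices
```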
For high-resolution images, progressive growing trains on increasing resolutions: 4×4 → 8×8 → ... → 256×256. Each stage adds new layers while keeping previous ones frozen initially. This stabilizes training and reduces computational requirements during early epochs.
Disentanglement and Interpretability
Disentangled representations encode independent factors of variation in separate latent dimensions. Factor-VAE adds a total correlation penalty: L = ELBO - γ × TC(z), where TC(z) = KL[q(z) || Π_i q(z_i)]. Estimating TC requires the density-ratio trick, with an adversarial discriminator distinguishing samples from q(z) from samples from Π_i q(z_i).
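Samples from the product of marginals Π_i q(z_i) are typically obtained by permuting each latent dimension independently across the batch, roughly as follows:

```python
import torch

def permute_dims(z):
    """Shuffle each latent dimension independently across the batch, giving
    approximate samples from the product of marginals."""
    return torch.stack(
        [z[torch.randperm(z.size(0)), i] for i in range(z.size(1))], dim=1)
```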
Evaluation metrics include the Mutual Information Gap (MIG) and the SAP score. MIG measures the normalized gap between the top two mutual informations: MIG = (1/K) Σ_k (I(z_j*; v_k) - max_{j≠j*} I(z_j; v_k)) / H(v_k). Higher scores indicate better disentanglement. SAP measures, for each factor, the gap in prediction accuracy between the two most predictive latent dimensions.
Posterior Collapse and Mitigation
Posterior collapse occurs when KL[q(z|x) || p(z)] → 0, making the model ignore the latent variables and rely solely on the decoder's implicit capacity. Diagnosis: monitor the KL divergence per dimension and the number of active units (dimensions with KL > 0.01). Collapsed models show near-zero KL and poor sample diversity.
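A small diagnostic that counts active units from a batch of posterior parameters, using the KL > 0.01 criterion above:

```python
import torch

def active_units(mu, log_var, threshold=0.01):
    """Count latent dimensions whose batch-averaged KL to the prior exceeds threshold."""
    kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean(dim=0)
    return (kl_per_dim > threshold).sum().item(), kl_per_dim
```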
Mitigation strategies beyond annealing include: skip connections from encoder to decoder, which reduce the reconstruction burden on the latents; a minimum desired KL per batch using max(target_KL, actual_KL); and structured priors such as AR-VAE, where p(z) = Π_i p(z_i|z_{<i}). Aggressive dropout (0.5) in the decoder forces reliance on latent information.
Implementation Details
Numerical stability requires careful handling of log-probabilities. The log-sum-exp trick for mixture posteriors: log Σ_i exp(x_i) = max_i x_i + log Σ_i exp(x_i - max_i x_i). Gradient clipping at norm 5-10 prevents instabilities from high-variance gradients early in training.
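PyTorch provides this as torch.logsumexp; a hand-rolled version of the same identity, for illustration:

```python
import torch

def log_sum_exp(x, dim=-1):
    """Numerically stable log-sum-exp; equivalent to torch.logsumexp(x, dim)."""
    m = x.max(dim=dim, keepdim=True).values
    return (m + (x - m).exp().sum(dim=dim, keepdim=True).log()).squeeze(dim)
```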
Memory optimizations for large models include gradient checkpointing (which can cut activation memory by roughly an order of magnitude), mixed-precision training with automatic loss scaling, and gradient accumulation for effective batch sizes beyond GPU memory. Conv2d layers use bias=False when followed by BatchNorm.
Standard VAE on MNIST achieves reconstruction error ~70 (binary cross-entropy) with 20-dimensional latent space after 50 epochs. CelebA models with 128-dimensional latents reach FID scores of 45-50. Training typically converges within 100-300 epochs depending on dataset complexity and architecture depth.