Generative models are a class of machine learning models that learn to create new data samples that resemble the training data. Unlike discriminative models that learn to classify or predict labels, generative models learn the underlying probability distribution of the data itself.
Generative Adversarial Networks (GANs), introduced by Ian Goodfellow and colleagues in 2014, revolutionized generative modeling by framing it as a competitive game between two neural networks.
Before GANs, generative models like Variational Autoencoders (VAEs) struggled to produce sharp, realistic images. GANs changed this by introducing an adversarial training process that pushes both networks to improve simultaneously.
Think of a GAN as a game between a counterfeiter and a detective:
As the detective gets better at spotting fakes, the counterfeiter must improve their technique. As the counterfeiter produces more convincing fakes, the detective must become more discerning. This back-and-forth competition drives both to excellence.
Generator (G):
Discriminator (D):
Generator Architecture: Noise → Image
The generator transforms a low-dimensional random noise vector into a high-dimensional data sample (e.g., an image). This is an upsampling process.
Discriminator Architecture: Image → Probability
The discriminator is essentially a binary classifier that outputs a single probability value indicating whether the input is real or fake.
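The two shape contracts above (noise in, image out; image in, probability out) can be sketched with two very small fully connected networks. This is only an illustration: the class names `TinyGenerator` and `TinyDiscriminator`, the 64-dimensional noise vector, and the flattened 28×28 image size are assumptions made for the example, not part of the lesson's architecture.

```python
import torch
import torch.nn as nn

LATENT_DIM = 64      # assumed noise dimensionality
IMG_DIM = 28 * 28    # assumed flattened image size


class TinyGenerator(nn.Module):
    """Hypothetical generator: noise vector -> flattened image in [-1, 1]."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM, 256),
            nn.ReLU(),
            nn.Linear(256, IMG_DIM),
            nn.Tanh(),  # squashes pixel values into [-1, 1]
        )

    def forward(self, z):
        return self.net(z)


class TinyDiscriminator(nn.Module):
    """Hypothetical discriminator: flattened image -> probability of 'real'."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(IMG_DIM, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # single probability per input
        )

    def forward(self, x):
        return self.net(x)


torch.manual_seed(0)
G, D = TinyGenerator(), TinyDiscriminator()
z = torch.randn(8, LATENT_DIM)   # batch of 8 noise vectors
fake = G(z)                      # shape (8, 784)
prob = D(fake)                   # shape (8, 1), values in (0, 1)
print(fake.shape, prob.shape)
```

Real image generators usually upsample with transposed convolutions rather than linear layers; the linear version is used here only because it keeps the shapes easy to follow.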
The GAN training process is formalized as a minimax game where the generator tries to minimize what the discriminator tries to maximize.
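For reference, the minimax objective as it is usually written in the GAN literature is shown below. The notation is the conventional one and is an assumption about what this lesson intended: p_data is the data distribution, p_z the noise prior, D(x) the discriminator's probability that x is real, and G(z) the generated sample.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```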
Breaking down the objective:
In practice, the minimax objective is implemented using Binary Cross-Entropy (BCE) loss, which measures the difference between predicted and actual binary labels.
For the Discriminator:
# Real samples (label = 1)
loss_real = BCE(D(x_real), 1)
# Fake samples (label = 0)
loss_fake = BCE(D(G(z)), 0)
# Total discriminator loss
loss_D = loss_real + loss_fake
For the Generator:
# Generator wants D to output 1 for fake samples
loss_G = BCE(D(G(z)), 1)
GANs use alternating optimization: train the discriminator for one or more steps, then train the generator for one step, and repeat.
# Simplified vanilla GAN training loop
for epoch in range(num_epochs):
    for real_images, _ in dataloader:
        batch_size = real_images.size(0)
        # ==================
        # Train Discriminator
        # ==================
        optimizer_D.zero_grad()
        # Real images
        real_labels = torch.ones(batch_size, 1)
        real_output = discriminator(real_images)
        loss_D_real = criterion(real_output, real_labels)
        # Fake images
        noise = torch.randn(batch_size, latent_dim)
        fake_images = generator(noise)
        fake_labels = torch.zeros(batch_size, 1)
        fake_output = discriminator(fake_images.detach())
        loss_D_fake = criterion(fake_output, fake_labels)
        # Total discriminator loss
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()
        # ==================
        # Train Generator
        # ==================
        optimizer_G.zero_grad()
        # Generator wants D to output 1 for fake images
        noise = torch.randn(batch_size, latent_dim)
        fake_images = generator(noise)
        fake_output = discriminator(fake_images)
        real_labels = torch.ones(batch_size, 1)
        loss_G = criterion(fake_output, real_labels)
        loss_G.backward()
        optimizer_G.step()
Despite their success, vanilla GANs are notoriously difficult to train. Several fundamental problems arise from the adversarial training dynamics.
What it is:
Mode collapse occurs when the generator learns to produce only a limited variety of samples, ignoring much of the data distribution. Instead of generating diverse outputs, it "collapses" to producing a few safe samples that fool the discriminator.
Why it happens:
When the discriminator becomes too strong:
If the discriminator becomes very good at distinguishing real from fake, it outputs values very close to 0 for generated samples (and very close to 1 for real ones). In that regime the gradient of log(1 - D(G(z))) with respect to the generator's parameters vanishes, leaving the generator with almost no learning signal.
This is known as the vanishing gradient problem: the generator stops learning because the discriminator is so confident that its samples are fake.
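The effect can be checked numerically. The sketch below assumes a single discriminator logit of -10 for a fake sample (so D(G(z)) is about 0.00005) and compares the gradient of the saturating loss log(1 - D(G(z))) with the gradient of the non-saturating form -log(D(G(z))), which is what BCE against the label 1 computes. The logit value is an arbitrary choice for illustration.

```python
import torch

# Discriminator logit on a fake sample. A very negative logit means D is
# confident the sample is fake, i.e. D(G(z)) = sigmoid(logit) is close to 0.
logit = torch.tensor(-10.0, requires_grad=True)

# Saturating generator loss from the minimax objective: log(1 - D(G(z)))
saturating = torch.log(1 - torch.sigmoid(logit))
(grad_saturating,) = torch.autograd.grad(saturating, logit)

# Non-saturating alternative: -log(D(G(z))), i.e. BCE against label 1
non_saturating = -torch.log(torch.sigmoid(logit))
(grad_non_saturating,) = torch.autograd.grad(non_saturating, logit)

print(f"saturating gradient:     {grad_saturating.item():.2e}")
print(f"non-saturating gradient: {grad_non_saturating.item():.2e}")
```

Analytically, the first gradient equals -D(G(z)) and the second equals -(1 - D(G(z))), so the saturating form shrinks toward zero exactly when the generator most needs a signal.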
Loss of learning signal:
Unlike typical neural network training, where a single loss trends downward as training progresses, GAN training involves two competing objectives that may never settle into an equilibrium.
Oscillating behavior:
| Epoch | G Loss | D Loss | Observation |
|---|---|---|---|
| 1 | 2.45 | 0.89 | D too strong |
| 2 | 1.12 | 1.23 | Better balance |
| 3 | 3.78 | 0.45 | Oscillation |
| 4 | 0.67 | 2.01 | G too strong |
| 5 | 2.89 | 0.91 | Instability |
BCE loss doesn't correlate with quality:
The discriminator and generator losses provide little information about the actual quality of generated samples. You can have:
Diagnostic challenges:
The instability of vanilla GANs is commonly traced to the Jensen-Shannon (JS) divergence, which the BCE loss implicitly minimizes when the discriminator is optimal. When the real and generated distributions have little or no overlap, the JS divergence saturates at a constant (log 2), so it provides no useful gradient to the generator.
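A small numerical check makes the saturation visible. The sketch below uses hand-made discrete histograms on a 20-bin grid (an assumption chosen for simplicity; the helper names `js_divergence` and `shifted_block` are hypothetical) and shows that the JS divergence is the same no matter how far apart two non-overlapping distributions are.

```python
import numpy as np


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # terms with a == 0 contribute 0 to the sum
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def shifted_block(shift, size=20):
    """Uniform mass on two adjacent bins starting at index `shift`."""
    d = np.zeros(size)
    d[shift:shift + 2] = 0.5
    return d


p = shifted_block(0)
# Every shift below leaves the two supports disjoint
js_values = [js_divergence(p, shifted_block(s)) for s in (2, 5, 10, 15)]
print(js_values, np.log(2))
```

Because the value does not change as the generated distribution moves closer to the real one, a generator trained against this quantity gets no indication of which direction to move.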
The core problem:
Intuitive explanation (moving piles of earth):
Imagine you have two piles of earth with different shapes. The Wasserstein distance measures the minimum amount of "work" needed to transform one pile into the other, where work = amount of earth × distance moved.
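In one dimension the earth-moving picture can be computed directly: for two histograms on the same evenly spaced grid, the Wasserstein-1 distance is the summed absolute difference of their cumulative distributions times the bin width. The sketch below relies on that 1-D identity; the 20-bin grid and the helper names `wasserstein_1d` and `pile` are assumptions made for the example.

```python
import numpy as np


def wasserstein_1d(p, q, bin_width=1.0):
    """Wasserstein-1 distance between two histograms on the same 1-D grid."""
    cdf_gap = np.cumsum(np.asarray(p, dtype=float)) - np.cumsum(np.asarray(q, dtype=float))
    return float(np.sum(np.abs(cdf_gap)) * bin_width)


def pile(start, size=20):
    """A 'pile of earth': uniform mass on two adjacent bins starting at `start`."""
    d = np.zeros(size)
    d[start:start + 2] = 0.5
    return d


# Move the second pile further and further away from the first
distances = {s: wasserstein_1d(pile(0), pile(s)) for s in (2, 5, 10, 15)}
print(distances)
```

Unlike a divergence that only asks whether the piles overlap, this distance grows in proportion to how far the earth has to travel, which is what gives the generator a usable gradient.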
Mathematical definition:
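The standard form of the Wasserstein-1 (earth mover's) distance is given below as a reference. The symbols are the conventional ones and are assumed here: P_r is the real distribution, P_g the generated distribution, and Π(P_r, P_g) the set of joint distributions γ whose marginals are P_r and P_g (each γ is one possible "transport plan").

```latex
W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)}
  \mathbb{E}_{(x, y) \sim \gamma}\big[\, \lVert x - y \rVert \,\big]
```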
Where:
Why Wasserstein distance is better:
WGAN replaces the discriminator with a "critic" that outputs raw scores instead of probabilities.
| Aspect | Vanilla GAN Discriminator | WGAN Critic |
|---|---|---|
| Output activation | Sigmoid (0 to 1) | None (any real number) |
| Output interpretation | Probability of being real | Raw score (higher = more real) |
| Training objective | Maximize classification accuracy | Maximize separation between real and fake scores |
| Loss function | Binary Cross-Entropy | Wasserstein loss |
Critic architecture (no sigmoid):
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
            # NO SIGMOID - outputs raw scores
        )
    def forward(self, x):
        return self.model(x)
Computing the Wasserstein distance directly is intractable. The Kantorovich-Rubinstein duality theorem provides a practical way to compute it using a critic function.
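In its usual statement, the Kantorovich-Rubinstein duality rewrites the distance as a supremum over 1-Lipschitz functions f, which is the role the critic plays. The notation (P_r for real data, P_g for generated data, ‖f‖_L for the Lipschitz constant of f) is assumed from common convention.

```latex
W(P_r, P_g) = \sup_{\lVert f \rVert_L \le 1}
  \; \mathbb{E}_{x \sim P_r}\big[f(x)\big] - \mathbb{E}_{x \sim P_g}\big[f(x)\big]
```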
Where:
A function f is 1-Lipschitz if for all x₁ and x₂: |f(x₁) − f(x₂)| ≤ ||x₁ − x₂||.
This means the output can never change by more than the input does; wherever f is differentiable, its gradient norm is at most 1.
Weight clipping (original WGAN approach):
The original WGAN paper enforced the Lipschitz constraint by clipping critic weights to a small range [-c, c] after each update.
# Weight clipping (original WGAN)
for param in critic.parameters():
    param.data.clamp_(-0.01, 0.01)  # Clip to [-0.01, 0.01]
Problems with weight clipping:
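Critic loss (maximize): E[f(x)] − E[f(G(z))], the mean critic score on real samples minus the mean critic score on generated samples.
Generator loss (minimize): −E[f(G(z))], the negated mean critic score on generated samples.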
PyTorch implementation:
# WGAN losses (without gradient penalty yet)
# Generate fakes first so both losses can use them
fake_images = generator(noise)
# Critic loss (fakes are detached so only the critic receives gradients)
critic_real = critic(real_images).mean()
critic_fake = critic(fake_images.detach()).mean()
loss_C = -(critic_real - critic_fake)  # Maximize → minimize negative
# Generator loss
loss_G = -critic(fake_images).mean()  # Maximize critic score on fakes
WGAN-GP improves upon WGAN by replacing weight clipping with a gradient penalty: a soft constraint that penalizes the critic whenever the norm of its gradient deviates from 1, which encourages the critic to stay 1-Lipschitz.
The idea:
A 1-Lipschitz function must have gradients with norm at most 1 everywhere. Instead of clipping weights, we add a penalty term that encourages ||∇f(x)||₂ = 1.
Where:
Interpolated samples:
We compute the gradient penalty on random interpolations between real and fake samples, not on real/fake data directly.
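Written out, the interpolation and the penalty term take the following form in the usual WGAN-GP formulation. The symbols are assumed from common convention: x is a real sample, x̃ = G(z) a generated one, ε a uniform random weight, D the critic, and λ the penalty coefficient.

```latex
\hat{x} = \epsilon\, x + (1 - \epsilon)\, \tilde{x},
  \qquad \epsilon \sim U[0, 1]

L_{\text{critic}} =
  \mathbb{E}_{\tilde{x}}\big[D(\tilde{x})\big]
  - \mathbb{E}_{x}\big[D(x)\big]
  + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]
```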
Why interpolations?
Full gradient penalty implementation (PyTorch):
def compute_gradient_penalty(critic, real_images, fake_images, device):
    """
    Compute gradient penalty for WGAN-GP
    Args:
        critic: Critic network
        real_images: Batch of real images
        fake_images: Batch of generated images
        device: 'cuda' or 'cpu'
    Returns:
        gradient_penalty: Scalar penalty value
    """
    batch_size = real_images.size(0)
    # Random weight term for interpolation
    epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
    epsilon = epsilon.expand_as(real_images)
    # Interpolated samples
    interpolated = epsilon * real_images + (1 - epsilon) * fake_images
    interpolated.requires_grad_(True)
    # Critic scores for interpolated samples
    critic_interpolated = critic(interpolated)
    # Compute gradients of critic scores w.r.t. interpolated samples
    gradients = torch.autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    # Flatten gradients
    gradients = gradients.view(batch_size, -1)
    # Compute gradient norm
    gradient_norm = gradients.norm(2, dim=1)
    # Penalty for deviation from norm = 1
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()
    return gradient_penalty
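One way to sanity-check a gradient penalty is to feed it a critic whose gradient norm is known in advance. The sketch below uses a compact variant of the function for flattened (B, D) inputs together with a hypothetical linear critic f(x) = w · x, whose input gradient is w everywhere. The names `gradient_penalty` and `linear_critic` and the 4-dimensional inputs are assumptions made for the test, not part of the lesson's code.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake):
    """Compact variant of the WGAN-GP penalty for flattened (B, D) inputs."""
    eps = torch.rand(real.size(0), 1)
    mixed = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(mixed)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


def linear_critic(weight_norm, dim=4):
    """Hypothetical critic f(x) = w . x with ||w||_2 equal to weight_norm."""
    layer = nn.Linear(dim, 1, bias=False)
    with torch.no_grad():
        layer.weight.zero_()
        layer.weight[0, 0] = weight_norm
    return layer


torch.manual_seed(0)
real, fake = torch.randn(16, 4), torch.randn(16, 4)

gp_unit = gradient_penalty(linear_critic(1.0), real, fake)   # gradient norm 1
gp_steep = gradient_penalty(linear_critic(3.0), real, fake)  # gradient norm 3
print(gp_unit.item(), gp_steep.item())
```

A critic with gradient norm exactly 1 should incur no penalty, while one with norm 3 should incur (3 − 1)² = 4. If the penalty tensor did not require gradients, it could not influence the critic's weights.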
WGAN-GP Training Procedure:
Key hyperparameters:
| Parameter | Standard Value | Purpose |
|---|---|---|
| λ (lambda) | 10 | Gradient penalty coefficient |
| n_critic | 5 | Critic updates per generator update |
| Learning rate | 1e-4 (0.0001) | Lower than vanilla GAN for stability |
| β₁ (Adam) | 0.5 | Decay rate for the first-moment (momentum) estimate |
| β₂ (Adam) | 0.9 | Decay rate for the second-moment (squared-gradient) estimate |
# WGAN-GP Training Loop
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)
        # ==================
        # Train Critic (n_critic times)
        # ==================
        for _ in range(n_critic):
            optimizer_C.zero_grad()
            # Generate fake images; detach so the critic step does not
            # backpropagate into the generator
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise).detach()
            # Critic scores
            critic_real = critic(real_images).mean()
            critic_fake = critic(fake_images).mean()
            # Gradient penalty
            gp = compute_gradient_penalty(critic, real_images,
                                          fake_images, device)
            # Total critic loss
            loss_C = critic_fake - critic_real + lambda_gp * gp
            loss_C.backward()
            optimizer_C.step()
        # ==================
        # Train Generator
        # ==================
        optimizer_G.zero_grad()
        noise = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(noise)
        # Generator loss
        loss_G = -critic(fake_images).mean()
        loss_G.backward()
        optimizer_G.step()
        # Log Wasserstein distance estimate (from the last critic step)
        wasserstein_distance = (critic_real - critic_fake).item()
Smoother loss curves:
WGAN-GP typically trains far more stably than a vanilla GAN: its losses tend to evolve smoothly rather than oscillating from batch to batch.
Meaningful Wasserstein distance metric:
Better sample quality:
Traditional image-to-image translation methods (like pix2pix) require paired training examples: input image A and corresponding output image B. CycleGAN solves the harder problem of translation without paired data.
The challenge:
The key innovation of CycleGAN is cycle consistency: if we translate from domain X to Y and back to X, we should get back the original image.
Forward cycle: X → Y → X
If we map image x from domain X to domain Y using generator G, then map it back using generator F, we should recover x.
Backward cycle: Y → X → Y
Similarly, mapping y from Y to X and back should recover y.
Cycle consistency loss:
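Combining the forward and backward cycles described above gives the cycle consistency loss in its usual form. The L1 norm matches the `L1` reconstruction loss used in the training code; the distribution names p_data(x) and p_data(y) are assumed notation for the two image domains.

```latex
\mathcal{L}_{\text{cyc}}(G, F) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big]
  + \mathbb{E}_{y \sim p_{\text{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]
```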
Two generators:
Two discriminators:
The CycleGAN objective combines adversarial losses (to make translations realistic) with cycle consistency losses (to preserve content).
Adversarial losses:
Full objective:
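A common way to write the combined objective is shown below. The term names are assumed notation: L_GAN(G, D_Y, X, Y) is the adversarial loss for the X → Y generator G against discriminator D_Y, L_GAN(F, D_X, Y, X) the same for the Y → X generator F, and L_cyc the cycle consistency loss.

```latex
\mathcal{L}(G, F, D_X, D_Y) =
  \mathcal{L}_{\text{GAN}}(G, D_Y, X, Y)
  + \mathcal{L}_{\text{GAN}}(F, D_X, Y, X)
  + \lambda\, \mathcal{L}_{\text{cyc}}(G, F)
```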
Where λ controls the relative importance of cycle consistency (typically λ = 10).
# CycleGAN Training (simplified)
# Assumed setup: G maps X -> Y, F maps Y -> X, L1 is an L1 reconstruction
# loss (e.g. nn.L1Loss()), and optimizer_G holds the parameters of both G and F.
# Note: the adversarial terms below use Wasserstein-style critic scores to keep
# the example short; the original CycleGAN paper uses a least-squares GAN loss.
# Critic-style losses like these are unbounded without a Lipschitz constraint.
for epoch in range(num_epochs):
    for real_X, real_Y in dataloader:  # unpaired batches from each domain
        # ==================
        # Train Generators
        # ==================
        optimizer_G.zero_grad()
        # Forward cycle: X -> Y -> X
        fake_Y = G(real_X)
        reconstructed_X = F(fake_Y)
        loss_cycle_X = L1(reconstructed_X, real_X)
        # Backward cycle: Y -> X -> Y
        fake_X = F(real_Y)
        reconstructed_Y = G(fake_X)
        loss_cycle_Y = L1(reconstructed_Y, real_Y)
        # Adversarial losses
        loss_G_adv = -D_Y(fake_Y).mean()
        loss_F_adv = -D_X(fake_X).mean()
        # Total generator loss
        loss_G = (loss_G_adv + loss_F_adv +
                  lambda_cyc * (loss_cycle_X + loss_cycle_Y))
        loss_G.backward()
        optimizer_G.step()
        # ==================
        # Train Discriminators
        # ==================
        # D_Y discriminates Y domain
        optimizer_D_Y.zero_grad()
        loss_D_Y_real = D_Y(real_Y).mean()
        loss_D_Y_fake = D_Y(fake_Y.detach()).mean()
        loss_D_Y = loss_D_Y_fake - loss_D_Y_real
        loss_D_Y.backward()
        optimizer_D_Y.step()
        # D_X discriminates X domain
        optimizer_D_X.zero_grad()
        loss_D_X_real = D_X(real_X).mean()
        loss_D_X_fake = D_X(fake_X.detach()).mean()
        loss_D_X = loss_D_X_fake - loss_D_X_real
        loss_D_X.backward()
        optimizer_D_X.step()
Style transfer:
Object transfiguration:
Domain adaptation:
Evaluating GANs is challenging because we care about both sample quality (do images look realistic?) and sample diversity (do we cover the full distribution?). No single metric captures both perfectly.
What it measures:
Inception Score uses a pre-trained Inception network to evaluate generated images based on two criteria:
Where:
Interpretation:
Limitations:
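As a concrete illustration of the Inception Score described above, the standard definition IS = exp(E_x[KL(p(y|x) ‖ p(y))]) can be evaluated on hand-made class-probability matrices. In practice the rows would be Inception softmax outputs for generated images; the three matrices below are synthetic stand-ins invented for this example, and `inception_score` is a hypothetical helper name.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """IS = exp(mean over samples of KL(p(y|x) || p(y))) for an (N, K) matrix."""
    probs = np.asarray(probs, dtype=float)
    marginal = probs.mean(axis=0, keepdims=True)  # p(y), averaged over samples
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))


num_classes = 10
# Confident and diverse: one-hot predictions spread evenly over all classes
diverse = np.eye(num_classes)[np.arange(100) % num_classes]
# Confident but collapsed: every sample is predicted as class 0
collapsed = np.eye(num_classes)[np.zeros(100, dtype=int)]
# Diverse but unconfident: a uniform prediction for every sample
blurry = np.full((100, num_classes), 1.0 / num_classes)

for name, probs in [("diverse", diverse), ("collapsed", collapsed), ("blurry", blurry)]:
    print(name, round(inception_score(probs), 3))
```

The score reaches its maximum (the number of classes) only when predictions are both confident and evenly spread across classes, and falls to 1 when either property is missing.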
What it measures:
FID compares the distribution of generated images to real images by looking at their features in the Inception network's feature space.
How it works:
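The distance is usually written as follows, which is also what the `calculate_fid` function in this lesson computes. The subscripts are assumed notation: (μ_r, Σ_r) are the mean and covariance of the real-image features, and (μ_g, Σ_g) those of the generated-image features.

```latex
\text{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\Big(\Sigma_r + \Sigma_g - 2\,\big(\Sigma_r \Sigma_g\big)^{1/2}\Big)
```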
Where:
Interpretation:
Advantages over IS:
Limitations:
| Aspect | Inception Score (IS) | Fréchet Inception Distance (FID) |
|---|---|---|
| What it measures | Quality + diversity of classes | Distance to real distribution |
| Better value | Higher is better | Lower is better |
| Uses real data | No (only generator samples) | Yes (compares to real) |
| Detects mode collapse | Poorly | Well |
| Samples needed | ~5,000 | ~10,000+ |
| Computational cost | Low | Medium |
| Human correlation | Moderate | Better |
import numpy as np
import torch
import torchvision
from scipy import linalg
def calculate_fid(real_features, fake_features):
    """
    Calculate Fréchet Inception Distance
    Args:
        real_features: Features from real images (N x D)
        fake_features: Features from generated images (M x D)
    Returns:
        fid_score: Scalar FID value (lower is better)
    """
    # Calculate mean and covariance
    mu_real = np.mean(real_features, axis=0)
    mu_fake = np.mean(fake_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    sigma_fake = np.cov(fake_features, rowvar=False)
    # Calculate squared difference of means
    diff = mu_real - mu_fake
    mean_diff = diff.dot(diff)
    # Calculate sqrt of product of covariances
    covmean = linalg.sqrtm(sigma_real.dot(sigma_fake))
    # Handle numerical errors (tiny imaginary parts from sqrtm)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    # Calculate FID
    fid = mean_diff + np.trace(sigma_real + sigma_fake - 2*covmean)
    return fid
# Extract features using pre-trained Inception
# (the older pretrained=True flag is deprecated in recent torchvision releases)
inception_model = torchvision.models.inception_v3(
    weights=torchvision.models.Inception_V3_Weights.DEFAULT)
inception_model.fc = torch.nn.Identity()  # Remove final layer
inception_model.eval()
def get_features(images):
    # images: float tensor (N, 3, 299, 299), ImageNet-normalized
    with torch.no_grad():
        features = inception_model(images)
    return features.cpu().numpy()
# Compute FID (real_images and generated_images are assumed to be prepared tensors)
real_features = get_features(real_images)
fake_features = get_features(generated_images)
fid_score = calculate_fid(real_features, fake_features)
print(f"FID Score: {fid_score:.2f}")
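Before running FID on real Inception features, it can help to check the formula on synthetic data where the answer is known. The sketch below uses random Gaussian vectors as stand-in "features" (an assumption; no Inception network is involved) and a compact re-implementation named `fid_from_features`, which is a hypothetical helper that mirrors the function above.

```python
import numpy as np
from scipy import linalg


def fid_from_features(real, fake):
    """Fréchet distance between Gaussians fitted to two (N, D) feature sets."""
    mu_r, mu_g = real.mean(axis=0), fake.mean(axis=0)
    sigma_r = np.cov(real, rowvar=False)
    sigma_g = np.cov(fake, rowvar=False)
    covmean = linalg.sqrtm(sigma_r @ sigma_g)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from sqrtm
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(sigma_r + sigma_g - 2 * covmean))


rng = np.random.default_rng(0)
real = rng.normal(size=(2000, 8))   # synthetic stand-in features

shift = np.zeros(8)
shift[0] = 3.0                      # move every sample by 3 along one axis

fid_same = fid_from_features(real, real)
fid_shifted = fid_from_features(real, real + shift)
print(fid_same, fid_shifted)
```

Identical feature sets should give a distance of about 0. Shifting every sample by a constant vector leaves the covariance unchanged, so the distance should equal the squared length of the shift, here 3² = 9.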
The Assignment 8 implementation on FashionMNIST provides valuable insights into the practical differences between Vanilla GAN and WGAN-GP.
Vanilla GAN instability observed:
WGAN-GP stability improvements:
Batch-level volatility analysis:
Examining loss at the batch level reveals the extent of training instability:
| Metric | Vanilla GAN | WGAN-GP |
|---|---|---|
| Generator loss std dev | ~0.45 | ~0.12 |
| Discriminator/Critic loss std dev | ~0.38 | ~0.09 |
| Stability improvement | Baseline | 3.7× less volatile |
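A volatility figure like the ones in the table can be obtained by taking the standard deviation of the per-batch loss values. The sketch below assumes that is how the metric is defined; the helper name `loss_volatility` is hypothetical, and the two loss traces are synthetic numbers invented only to exercise the helper, not measurements from the assignment.

```python
import numpy as np


def loss_volatility(batch_losses):
    """Standard deviation of a sequence of per-batch loss values."""
    return float(np.std(np.asarray(batch_losses, dtype=float)))


# Synthetic traces, made up purely for illustration
noisy_trace = [1.0, 2.0, 0.5, 2.5, 1.0, 2.0]
calm_trace = [1.5, 1.6, 1.4, 1.5, 1.6, 1.4]

ratio = loss_volatility(noisy_trace) / loss_volatility(calm_trace)
print(f"volatility ratio: {ratio:.1f}x")
```

In a real run the lists would be filled by appending `loss.item()` once per batch for each model, then comparing the two standard deviations.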
import torch
import torch.nn as nn
class Generator(nn.Module):
    """
    Generator network for FashionMNIST (28x28 grayscale images)
    Maps 64-dimensional noise to 784-dimensional image
    """
    def __init__(self, latent_dim=64, img_dim=784):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            # Input: latent_dim (64)
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm1d(256),
            # Hidden layer 1
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm1d(512),
            # Hidden layer 2
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm1d(1024),
            # Output: img_dim (784)
            nn.Linear(1024, img_dim),
            nn.Tanh()  # Output in [-1, 1] to match normalized images
        )
    def forward(self, z):
        """
        Args:
            z: Noise vector (batch_size, latent_dim)
        Returns:
            Generated image (batch_size, img_dim)
        """
        img = self.model(z)
        return img
class Critic(nn.Module):
    """
    Critic network for WGAN-GP
    NO sigmoid activation - outputs raw scores
    """
    def __init__(self, img_dim=784):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            # Input: img_dim (784)
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            # Hidden layer 1
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: single score (no sigmoid!)
            nn.Linear(256, 1)
        )
    def forward(self, img):
        """
        Args:
            img: Image (batch_size, img_dim)
        Returns:
            Critic score (batch_size, 1) - raw unbounded value
        """
        score = self.model(img)
        return score
def compute_gradient_penalty(critic, real_images, fake_images, device='cuda'):
    """
    Compute gradient penalty for WGAN-GP
    Enforces 1-Lipschitz constraint by penalizing gradients
    that deviate from norm = 1
    Args:
        critic: Critic network
        real_images: Batch of real images (B, 784)
        fake_images: Batch of generated images (B, 784)
        device: 'cuda' or 'cpu'
    Returns:
        gradient_penalty: Scalar penalty value
    """
    batch_size = real_images.size(0)
    # Random interpolation coefficient for each sample
    epsilon = torch.rand(batch_size, 1, device=device)
    # Interpolated samples between real and fake
    interpolated = epsilon * real_images + (1 - epsilon) * fake_images
    interpolated.requires_grad_(True)
    # Get critic scores for interpolated samples
    critic_interpolated = critic(interpolated)
    # Compute gradients of scores w.r.t. interpolated inputs
    gradients = torch.autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,   # Allow backprop through this operation
        retain_graph=True,   # Don't free computation graph
        only_inputs=True     # Only compute w.r.t. inputs
    )[0]
    # Compute L2 norm of gradients for each sample
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    # Penalize deviation from norm = 1
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()
    return gradient_penalty
# Hyperparameters
latent_dim = 64
img_dim = 28 * 28
lr = 1e-4
beta1 = 0.5
beta2 = 0.9
n_critic = 5
lambda_gp = 10
num_epochs = 5
batch_size = 128
# Device (falls back to CPU when no GPU is available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize models
generator = Generator(latent_dim, img_dim).to(device)
critic = Critic(img_dim).to(device)
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=lr, betas=(beta1, beta2))
optimizer_C = torch.optim.Adam(critic.parameters(),
                               lr=lr, betas=(beta1, beta2))
# Training loop
# (dataloader is assumed to yield FashionMNIST batches normalized to [-1, 1])
for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(dataloader):
        real_images = real_images.view(-1, img_dim).to(device)
        batch_size = real_images.size(0)
        # ==================
        # Train Critic (n_critic times per generator update)
        # ==================
        for _ in range(n_critic):
            optimizer_C.zero_grad()
            # Sample noise and generate fake images
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            # Critic scores on real and fake
            critic_real = critic(real_images).mean()
            critic_fake = critic(fake_images.detach()).mean()
            # Gradient penalty
            gp = compute_gradient_penalty(critic, real_images,
                                          fake_images.detach(), device)
            # Wasserstein loss with gradient penalty
            loss_C = critic_fake - critic_real + lambda_gp * gp
            loss_C.backward()
            optimizer_C.step()
        # ==================
        # Train Generator (once per n_critic critic updates)
        # ==================
        optimizer_G.zero_grad()
        # Generate fake images
        noise = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(noise)
        # Generator wants critic to output high scores for fakes
        loss_G = -critic(fake_images).mean()
        loss_G.backward()
        optimizer_G.step()
        # ==================
        # Logging
        # ==================
        if batch_idx % 100 == 0:
            wasserstein_dist = (critic_real - critic_fake).item()
            print(f"Epoch [{epoch}/{num_epochs}] Batch [{batch_idx}] "
                  f"Loss_G: {loss_G.item():.4f} "
                  f"Loss_C: {loss_C.item():.4f} "
                  f"W-dist: {wasserstein_dist:.4f} "
                  f"GP: {gp.item():.4f}")
| Problem | Symptom | Solution |
|---|---|---|
| Forgot to remove sigmoid | Critic outputs always in [0,1] | Ensure critic has no sigmoid activation |
| Wrong sign in losses | Losses increase instead of decrease | Critic minimizes critic(fake) − critic(real); generator minimizes −critic(fake) |
| Gradient penalty too low | Training unstable, mode collapse | Use λ = 10 (standard value) |
| Not enough critic updates | Generator dominates, poor samples | Use n_critic = 5 |
| Learning rate too high | Oscillating losses, instability | Use lr = 1e-4 (conservative) |
| Forgot create_graph=True | Penalty contributes no gradient to the critic, so the constraint is silently ignored | Enable in autograd.grad for GP |
| Feature | Vanilla GAN | WGAN | WGAN-GP |
|---|---|---|---|
| Distance Metric | JS Divergence (implicit) | Wasserstein Distance | Wasserstein Distance |
| Loss Function | Binary Cross-Entropy | Wasserstein loss | Wasserstein loss + GP |
| Output Activation | Sigmoid (0-1) | None (raw scores) | None (raw scores) |
| Network Name | Discriminator | Critic | Critic |
| Lipschitz Constraint | None | Weight clipping | Gradient penalty |
| Training Stability | Unstable | More stable | Usually the most stable of the three |
| Mode Collapse | Common | Less common | Less common still |
| Meaningful Metric | No | Yes (W-distance) | Yes (W-distance) |
| Learning Rate | ~2e-4 | ~5e-5 (RMSProp in the original paper) | ~1e-4 |
| Update Ratio (D/C:G) | 1:1 | 5:1 | 5:1 |
| Sample Quality | Good (if stable) | Better | Best |
| Training Time | Fastest | Moderate | Slowest (GP overhead) |
Use Vanilla GAN when:
Use WGAN when:
Use WGAN-GP when:
From Vanilla GAN:
From WGAN:
From WGAN-GP: