# Classifier-Free Guidance
Classifier-Free Guidance is a technique that allows for guiding the diffusion process towards a specific class or condition without needing a separate classifier model. It's a powerful method for controlling the output of a conditional diffusion model.
The implementation for this technique can be found in `denoising_diffusion_pytorch/classifier_free_guidance.py`.
## Core Concept
The key idea is to train a single conditional diffusion model that can operate in two modes: conditional and unconditional.
- **Conditional Prediction:** The model makes a prediction conditioned on class information (e.g., a class label `y`). Let's call this `ε_θ(x_t, y)`.
- **Unconditional Prediction:** The model makes an unconditional prediction. This is achieved by randomly replacing the class label with a special `null` embedding during training. Let's call this `ε_θ(x_t, ∅)`.
During sampling, both predictions are made at each step. The final prediction is a combination of the two, extrapolated away from the unconditional prediction and towards the conditional one:

```
ε_final = ε_θ(x_t, ∅) + cond_scale * (ε_θ(x_t, y) - ε_θ(x_t, ∅))
```

Here `cond_scale` is the guidance scale parameter. A value of `1.0` recovers the plain conditional prediction (no extra guidance), while values greater than `1.0` push the generation to be more representative of the target class.
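As a quick sanity check of the formula, here is a toy computation (tensor shapes and values are illustrative only):

```python
import torch

eps_cond = torch.randn(1, 3, 8, 8)    # ε_θ(x_t, y): conditional prediction
eps_uncond = torch.randn(1, 3, 8, 8)  # ε_θ(x_t, ∅): unconditional prediction

def guide(cond_scale):
    return eps_uncond + cond_scale * (eps_cond - eps_uncond)

# cond_scale = 1.0 recovers the plain conditional prediction
assert torch.allclose(guide(1.0), eps_cond)

# cond_scale > 1.0 extrapolates past the conditional prediction,
# amplifying the direction that separates it from the unconditional one
eps_final = guide(6.0)
```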
## Implementation
The file `classifier_free_guidance.py` contains a `Unet` and a `GaussianDiffusion` class specifically designed for this purpose.
### `Unet` with Conditional Dropout
The `Unet` model in this file is modified to accept class embeddings. During the forward pass, it uses a `cond_drop_prob` (conditional dropout probability) to randomly replace the true class embedding with a learned `null_classes_emb`.
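To make the dropout mechanism concrete, here is a minimal, self-contained sketch of how class embeddings can be randomly swapped for a learned null embedding. The `ConditionalDropout` module and its names are illustrative, not taken from the library:

```python
import torch
import torch.nn as nn

class ConditionalDropout(nn.Module):
    """Illustrative module: randomly replaces class embeddings with a learned null embedding."""
    def __init__(self, num_classes, dim):
        super().__init__()
        self.classes_emb = nn.Embedding(num_classes, dim)
        self.null_classes_emb = nn.Parameter(torch.randn(dim))

    def forward(self, classes, cond_drop_prob):
        emb = self.classes_emb(classes)  # (batch, dim)
        if cond_drop_prob > 0:
            # Per-sample Bernoulli mask: True = keep the real class embedding
            keep = torch.rand(classes.shape[0], device=classes.device) >= cond_drop_prob
            null = self.null_classes_emb.expand(classes.shape[0], -1)
            emb = torch.where(keep.unsqueeze(-1), emb, null)
        return emb

drop = ConditionalDropout(num_classes = 10, dim = 64)
classes = torch.randint(0, 10, (8,))
emb = drop(classes, cond_drop_prob = 0.5)  # about half the rows become the null embedding
```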
### `forward_with_cond_scale`

The `Unet` has a special method, `forward_with_cond_scale`, which automates the guidance logic. It performs two forward passes internally:
- One with `cond_drop_prob = 0.0` (fully conditional).
- One with `cond_drop_prob = 1.0` (fully unconditional).
It then combines the results using the `cond_scale` parameter.
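In code, the combination looks roughly like the following sketch. It assumes the model's forward pass accepts a `cond_drop_prob` keyword as described above; the actual method in the repository may differ in its details:

```python
import torch

@torch.no_grad()
def forward_with_cond_scale(model, x, t, classes, cond_scale = 1.0):
    # Fully conditional pass: every sample keeps its true class embedding
    cond_out = model(x, t, classes, cond_drop_prob = 0.)
    if cond_scale == 1.0:
        return cond_out  # no guidance: just the conditional prediction

    # Fully unconditional pass: every class embedding is replaced by the null embedding
    uncond_out = model(x, t, classes, cond_drop_prob = 1.)

    # Extrapolate away from the unconditional and towards the conditional prediction
    return uncond_out + cond_scale * (cond_out - uncond_out)
```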
## Usage Example
Here's how you would set up and sample from a classifier-free guidance model.
```python
import torch
from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion

num_classes = 10

# 1. Define the special U-Net that accepts class conditions
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    num_classes = num_classes,
    cond_drop_prob = 0.5  # set dropout probability for training
)

# 2. Wrap it with the diffusion model
diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000
).cuda()

# --- Training Loop ---
# During training, you provide both images and class labels
training_images = torch.randn(8, 3, 128, 128).cuda()
image_classes = torch.randint(0, num_classes, (8,)).cuda()

loss = diffusion(training_images, classes = image_classes)
loss.backward()

# ... training continues ...

# --- Sampling ---
# During sampling, you provide the target classes and a guidance scale
sampled_images = diffusion.sample(
    classes = image_classes,  # target classes for the generated images
    cond_scale = 6.           # guidance scale (values from 3 to 8 often work well)
)

sampled_images.shape  # (8, 3, 128, 128)
```
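In practice, higher `cond_scale` values tend to produce samples that are more clearly recognizable as the target class, at the cost of reduced sample diversity, so the value is worth tuning for your dataset.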