Classifier-Free Guidance

Classifier-Free Guidance is a technique for steering the diffusion process toward a specific class or condition without training a separate classifier model, making it a simple yet powerful way to control the output of a conditional diffusion model.

The implementation for this technique can be found in denoising_diffusion_pytorch/classifier_free_guidance.py.

Core Concept

The key idea is to train a single conditional diffusion model that can operate in two modes: conditional and unconditional.

  1. Conditional Prediction: The model makes a prediction conditioned on class information (e.g., a class label y). Let's call this ε_θ(x_t, y).
  2. Unconditional Prediction: The model makes an unconditional prediction. This is achieved by randomly replacing the class label with a special null embedding during training. Let's call this ε_θ(x_t, ∅).

During sampling, both predictions are made at each step. The final prediction is a combination of the two, extrapolated away from the unconditional prediction and towards the conditional one:

ε_final = ε_θ(x_t, ∅) + cond_scale * (ε_θ(x_t, y) - ε_θ(x_t, ∅))

  • cond_scale is a guidance scale parameter. A value of 1.0 means no guidance, while values greater than 1.0 push the generation to be more representative of the target class.
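The extrapolation formula above can be sketched in a few lines. The tensors here are toy stand-ins for the two noise predictions; `guided_prediction` is a hypothetical helper name, not part of the library.

```python
import torch

def guided_prediction(pred_cond, pred_uncond, cond_scale):
    # Extrapolate away from the unconditional prediction, toward the conditional one.
    # cond_scale = 1.0 reproduces the conditional prediction exactly.
    return pred_uncond + cond_scale * (pred_cond - pred_uncond)

# toy tensors standing in for eps(x_t, y) and eps(x_t, null)
pred_cond = torch.full((1, 3, 4, 4), 2.0)
pred_uncond = torch.full((1, 3, 4, 4), 1.0)

guided = guided_prediction(pred_cond, pred_uncond, cond_scale = 6.)
# with these toy values: 1 + 6 * (2 - 1) = 7 everywhere
```

Note that with `cond_scale = 1.0` the unconditional term cancels, which is why a scale of 1.0 corresponds to no guidance.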

Implementation

The file classifier_free_guidance.py contains a Unet and GaussianDiffusion class specifically designed for this purpose.

Unet with Conditional Dropout

The Unet model in this file is modified to accept class embeddings. During the forward pass, it uses a cond_drop_prob (conditional dropout probability) to randomly replace the true class embedding with a learned null_classes_emb.

forward_with_cond_scale

The Unet has a special method, forward_with_cond_scale, which automates the guidance logic. It performs two forward passes internally:

  1. One with cond_drop_prob = 0.0 (fully conditional).
  2. One with cond_drop_prob = 1.0 (fully unconditional).

It then combines the results using the cond_scale parameter (when cond_scale is exactly 1.0, the conditional prediction is returned directly and the unconditional pass can be skipped).
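The two-pass logic can be sketched with a toy module; the class `GuidedModel` and its trivial forward are stand-ins for the real Unet, but the control flow of forward_with_cond_scale mirrors the description above.

```python
import torch

class GuidedModel(torch.nn.Module):
    """Toy module illustrating the two-pass guidance logic (not the library's Unet)."""

    def forward(self, x, classes, cond_drop_prob = 0.):
        # Stand-in network: the conditional pass adds the class index to the input,
        # the fully unconditional pass (cond_drop_prob = 1.0) ignores the classes.
        if cond_drop_prob == 1.0:
            return torch.zeros_like(x)
        return x + classes.view(-1, 1).float()

    def forward_with_cond_scale(self, x, classes, cond_scale = 1.):
        cond = self.forward(x, classes, cond_drop_prob = 0.)    # fully conditional
        if cond_scale == 1.:
            return cond                                         # guidance disabled
        uncond = self.forward(x, classes, cond_drop_prob = 1.)  # fully unconditional
        # extrapolate away from the unconditional prediction
        return uncond + cond_scale * (cond - uncond)

model = GuidedModel()
x = torch.zeros(2, 3)
classes = torch.tensor([1, 2])
out = model.forward_with_cond_scale(x, classes, cond_scale = 2.)
```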

Usage Example

Here's how you would set up and sample from a classifier-free guidance model.

import torch
from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion

num_classes = 10

# 1. Define the special U-Net that accepts class conditions
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    num_classes = num_classes,
    cond_drop_prob = 0.5  # Set dropout probability for training
)

# 2. Wrap it with the diffusion model
diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000
).cuda()

# --- Training Loop ---
# During training, you provide both images and class labels
training_images = torch.randn(8, 3, 128, 128).cuda()
image_classes = torch.randint(0, num_classes, (8,)).cuda()

loss = diffusion(training_images, classes = image_classes)
loss.backward()
# ... training continues ...

# --- Sampling ---
# During sampling, you provide the target classes and a guidance scale
sampled_images = diffusion.sample(
    classes = image_classes,      # Target classes for the generated images
    cond_scale = 6.               # Guidance scale (values from 3 to 8 often work well)
)

sampled_images.shape # (8, 3, 128, 128)