Guided Diffusion

Guided diffusion (also called classifier guidance) steers the sampling process of a diffusion model with the gradients of an external classifier. This pushes generation toward a desired class and improves class-conditional sample fidelity, typically at some cost in sample diversity.

The implementation for this can be found in denoising_diffusion_pytorch/guided_diffusion.py.

Core Concept

The method, introduced in "Diffusion Models Beat GANs on Image Synthesis" by Dhariwal and Nichol, modifies the sampling step of the reverse process. At each timestep t, after the diffusion model predicts the mean μ_θ(x_t) of the next denoised sample x_{t-1}, that mean is shifted by a classifier-gradient term:

μ_guided = μ_θ(x_t) + s * Σ_θ * ∇_x_t log(p_φ(y|x_t))

Where:

  • s is the guidance scale.
  • Σ_θ is the variance of the reverse step.
  • p_φ(y|x_t) is the probability of the target class y given the noisy image x_t, as predicted by a separate, noise-aware classifier φ.
  • ∇_x_t log(p_φ(y|x_t)) is the gradient of the log-probability with respect to the noisy image x_t.

The gradient points in the direction that most increases the classifier's log-probability for class y, so adding it shifts the mean toward images the classifier recognizes as class y. Larger values of s strengthen this shift.
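
As a worked instance of the update above, the following minimal sketch uses a toy two-class linear classifier on a one-dimensional x_t, where the gradient of the log-probability has a closed form. The weights, mean, variance, and scale are made-up values for illustration and are not taken from guided_diffusion.py.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def grad_log_prob(x, weights, biases, y):
    """d/dx log p(y | x) for logits_k = weights[k] * x + biases[k] (scalar x)."""
    logits = [w * x + b for w, b in zip(weights, biases)]
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    # Closed form for a linear-softmax classifier: w_y - sum_k p_k * w_k
    return weights[y] - sum(p * w for p, w in zip(probs, weights))

def guided_mean(mean, variance, grad, scale):
    """mu_guided = mu + s * Sigma * grad, with scalars standing in for tensors."""
    return mean + scale * variance * grad

# Hypothetical toy values (assumptions for illustration only)
weights = [-1.0, 1.0]   # class 1 prefers larger x
biases = [0.0, 0.0]
x_t = 0.0               # current noisy sample
mu = 0.0                # mean predicted by the diffusion model
sigma2 = 0.5            # variance of the reverse step
s = 2.0                 # guidance scale

g = grad_log_prob(x_t, weights, biases, y=1)
mu_guided = guided_mean(mu, sigma2, g, s)
```

With these values the classifier is undecided at x_t = 0, the gradient toward class 1 is positive, and the guided mean moves in the direction class 1 prefers.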

Implementation Details

The guided_diffusion.py file provides a GaussianDiffusion class with modified sampling methods.

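The key change is in the p_sample method, which now accepts an optional cond_fn together with guidance_kwargs. Guidance is applied only when both are provided; otherwise the step behaves like an unguided reverse step.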

# From guided_diffusion.py

def p_sample(self, x, t: int, x_self_cond = None, cond_fn=None, guidance_kwargs=None):
    # ... standard reverse step calculation ...
    model_mean, variance, model_log_variance, x_start = self.p_mean_variance(...)

    if exists(cond_fn) and exists(guidance_kwargs):
        model_mean = self.condition_mean(cond_fn, model_mean, variance, x, batched_times, guidance_kwargs)

    # ... add noise and return ...

The condition_mean method calls cond_fn to obtain the classifier gradient, scales it by the variance of the reverse step, and adds the result to the predicted mean.
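
A minimal sketch of what such a method might look like, written as a free function. It assumes cond_fn is called as cond_fn(x, t, **guidance_kwargs) and returns a gradient that already includes the guidance scale, which is how the cond_fn in the usage example is written; the actual body in guided_diffusion.py may differ. Plain floats stand in for tensors so the snippet runs without PyTorch, and toy_cond_fn is a hypothetical stand-in for a classifier gradient.

```python
def condition_mean(cond_fn, mean, variance, x, t, guidance_kwargs=None):
    """Shift the predicted mean by variance * gradient (sketch, not the library's code)."""
    kwargs = guidance_kwargs or {}
    gradient = cond_fn(x, t, **kwargs)  # assumed to already include the guidance scale
    return mean + variance * gradient

# Hypothetical cond_fn: scaled gradient of log N(x; target, 1) with respect to x
def toy_cond_fn(x, t, target, scale=1.0):
    return scale * (target - x)

new_mean = condition_mean(
    toy_cond_fn,
    mean=0.0,
    variance=0.25,
    x=0.0,
    t=10,
    guidance_kwargs={"target": 2.0, "scale": 2.0},
)
```

Because the function only uses addition and multiplication, the same arithmetic would broadcast over batched tensors of matching shape.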

Usage Example

To use guided diffusion, you need two components:

  1. A trained diffusion model (Unet + GaussianDiffusion).
  2. A separate classifier trained on noisy images.

import torch
import torch.nn.functional as F
from torch import nn
from denoising_diffusion_pytorch.guided_diffusion import Unet, GaussianDiffusion

# 1. A pre-trained diffusion model
model = Unet(dim=64, dim_mults=(1, 2, 4, 8))
image_size = 128
diffusion = GaussianDiffusion(model, image_size=image_size, timesteps=1000)
# ... assume `diffusion` model is trained ...

# 2. A noise-aware classifier (example implementation)
class Classifier(nn.Module):
    def __init__(self, image_size, num_classes):
        super().__init__()
        # A simple linear classifier for demonstration
        self.linear_img = nn.Linear(image_size * image_size * 3, num_classes)

    def forward(self, x, t):
        # The classifier must accept both the noisy image `x` and the timestep `t`.
        # This toy model ignores `t`; a real noise-aware classifier would condition on it.
        x_flat = x.reshape(x.shape[0], -1)
        return self.linear_img(x_flat)

classifier = Classifier(image_size=image_size, num_classes=1000)
# ... assume `classifier` is trained on noisy images ...

# 3. A gradient function (`cond_fn`)
def classifier_cond_fn(x, t, classifier, y, classifier_scale=1.0):
    assert y is not None
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), y.view(-1)]
        grad = torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale
        return grad

# 4. Sampling with guidance
batch_size = 4
target_classes = torch.randint(0, 1000, (batch_size,))

sampled_images = diffusion.sample(
    batch_size=batch_size,
    cond_fn=classifier_cond_fn,
    guidance_kwargs={
        "classifier": classifier,
        "y": target_classes,
        "classifier_scale": 1.0,  # The guidance scale `s`
    }
)
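
The classifier in step 2 is assumed to have been trained on noisy images. One way to produce such training inputs is to apply the forward diffusion process to clean data at a randomly chosen timestep. The sketch below uses the standard closed form x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε with a linear β schedule. The schedule endpoints and the helper names are assumptions for illustration, plain Python lists stand in for image tensors, and the schedule may not match the one GaussianDiffusion uses.

```python
import math
import random

def linear_alphas_cumprod(timesteps, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a linear beta schedule (assumed values)."""
    out, running = [], 1.0
    for i in range(timesteps):
        beta = beta_start + (beta_end - beta_start) * i / max(timesteps - 1, 1)
        running *= 1.0 - beta
        out.append(running)
    return out

def q_sample(x_start, t, alphas_cumprod, rng):
    """Noise a flat list of pixel values to timestep t using the closed form."""
    a = alphas_cumprod[t]
    return [
        math.sqrt(a) * v + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0)
        for v in x_start
    ]

rng = random.Random(0)
alphas_cumprod = linear_alphas_cumprod(1000)
x0 = [0.5, -0.25, 1.0]       # stand-in for a normalized image
t = rng.randrange(1000)      # random timestep, drawn per training example
x_t = q_sample(x0, t, alphas_cumprod, rng)
# (x_t, t, label) would then form one training example for the classifier
```

Training on inputs noised this way, with the timestep passed alongside the image, is what lets the classifier give useful gradients at every step of the reverse process.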