Guided Diffusion
Guided Diffusion refers to a technique where the sampling process of a diffusion model is guided by the gradients of an external classifier. This allows you to steer the generation towards a desired class, improving sample fidelity and class-conditional generation quality.
The implementation for this can be found in denoising_diffusion_pytorch/guided_diffusion.py.
Core Concept
The method, introduced in "Diffusion Models Beat GANs on Image Synthesis" by Dhariwal and Nichol, modifies the reverse process sampling step. At each timestep t, after the diffusion model predicts the mean μ_θ(x_t) for the previous step, an additional term is added:
μ_guided = μ_θ(x_t) + s * Σ_θ * ∇_x_t log(p_φ(y|x_t))
Where:
- `s` is the guidance scale.
- `Σ_θ` is the variance of the reverse step.
- `p_φ(y|x_t)` is the probability of the target class `y` given the noisy image `x_t`, as predicted by a separate, noise-aware classifier `φ`.
- `∇_x_t log(p_φ(y|x_t))` is the gradient of that log-probability with respect to the noisy image `x_t`.
This gradient pushes the sampling process in a direction that makes the image more recognizable as class `y` to the classifier.
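As a rough illustration of the update itself (the tensor names `model_mean`, `variance`, and `grad_log_prob` below are placeholders, not the library's API), the guided mean can be computed directly from the formula:

```python
import torch

def guided_mean(model_mean: torch.Tensor,
                variance: torch.Tensor,
                grad_log_prob: torch.Tensor,
                s: float = 1.0) -> torch.Tensor:
    # mu_guided = mu_theta(x_t) + s * Sigma_theta * grad_{x_t} log p_phi(y | x_t)
    return model_mean + s * variance * grad_log_prob
```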
Implementation Details
The guided_diffusion.py file provides a GaussianDiffusion class with modified sampling methods.
The key change is in the p_sample method, which now accepts an optional cond_fn.
```python
# From guided_diffusion.py
def p_sample(self, x, t: int, x_self_cond = None, cond_fn = None, guidance_kwargs = None):
    # ... standard reverse step calculation ...
    model_mean, variance, model_log_variance, x_start = self.p_mean_variance(...)
    if exists(cond_fn) and exists(guidance_kwargs):
        model_mean = self.condition_mean(cond_fn, model_mean, variance, x, batched_times, guidance_kwargs)
    # ... add noise and return ...
```
The condition_mean method computes the gradient from the cond_fn and adds it to the predicted mean.
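A minimal sketch of such a method, assuming it follows the formula above (an illustration, not a verbatim copy of the implementation in guided_diffusion.py):

```python
def condition_mean(self, cond_fn, mean, variance, x, t, guidance_kwargs = None):
    # cond_fn returns grad_{x_t} log p_phi(y | x_t), already scaled by the guidance scale
    gradient = cond_fn(x, t, **guidance_kwargs)
    # shift the predicted mean in the direction that raises the classifier's log-probability
    return mean.float() + variance * gradient.float()
```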
Usage Example
To use guided diffusion, you need two components:
- A trained diffusion model (`Unet` + `GaussianDiffusion`).
- A separate classifier trained on noisy images.
```python
import torch
import torch.nn.functional as F
from torch import nn

from denoising_diffusion_pytorch.guided_diffusion import Unet, GaussianDiffusion

# 1. A pre-trained diffusion model
model = Unet(dim=64, dim_mults=(1, 2, 4, 8))

image_size = 128
diffusion = GaussianDiffusion(model, image_size=image_size, timesteps=1000)
# ... assume `diffusion` model is trained ...

# 2. A noise-aware classifier (example implementation)
class Classifier(nn.Module):
    def __init__(self, image_size, num_classes):
        super().__init__()
        # A simple linear classifier for demonstration
        self.linear_img = nn.Linear(image_size * image_size * 3, num_classes)

    def forward(self, x, t):
        # Classifier must accept both the noisy image `x` and timestep `t`
        # (this toy model ignores `t`; a real classifier should condition on it)
        x_flat = x.view(x.shape[0], -1)
        return self.linear_img(x_flat)

classifier = Classifier(image_size=image_size, num_classes=1000)
# ... assume `classifier` is trained on noisy images ...

# 3. A gradient function (`cond_fn`)
def classifier_cond_fn(x, t, classifier, y, classifier_scale=1.0):
    # returns the gradient of log p(y | x_t) with respect to x_t, scaled by `classifier_scale`
    assert y is not None
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), y.view(-1)]
        grad = torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale
    return grad

# 4. Sampling with guidance
batch_size = 4
target_classes = torch.randint(0, 1000, (batch_size,))

sampled_images = diffusion.sample(
    batch_size=batch_size,
    cond_fn=classifier_cond_fn,
    guidance_kwargs={
        "classifier": classifier,
        "y": target_classes,
        "classifier_scale": 1.0,  # The guidance scale `s`
    }
)
```
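For guidance to be effective, the classifier should be trained on images corrupted by the same forward process used at sampling time. Below is a minimal training sketch, assuming a `dataloader` (not defined above) that yields `(images, labels)` batches already normalized to the diffusion model's input range, and using the `q_sample` method and `num_timesteps` attribute of the `GaussianDiffusion` instance:

```python
# A sketch of noise-aware classifier training. `dataloader` is assumed to exist
# and yield (images, labels) with images in the diffusion model's input range.
opt = torch.optim.Adam(classifier.parameters(), lr=3e-4)

for images, labels in dataloader:
    # sample a random timestep per image and corrupt it with the forward process,
    # so the classifier sees the same noise levels it will encounter during sampling
    t = torch.randint(0, diffusion.num_timesteps, (images.shape[0],), device=images.device)
    noisy_images = diffusion.q_sample(x_start=images, t=t)

    loss = F.cross_entropy(classifier(noisy_images, t), labels)

    opt.zero_grad()
    loss.backward()
    opt.step()
```

Increasing `classifier_scale` strengthens the pull toward the target class, typically at the cost of sample diversity.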