Guided Diffusion
Guided Diffusion refers to a technique where the sampling process of a diffusion model is guided by the gradients of an external classifier. This allows you to steer the generation towards a desired class, improving sample fidelity and class-conditional generation quality.
The implementation for this can be found in denoising_diffusion_pytorch/guided_diffusion.py.
Core Concept
The method, introduced in "Diffusion Models Beat GANs on Image Synthesis" by Dhariwal and Nichol, modifies the reverse process sampling step. At each timestep t, after the diffusion model predicts the mean μ_θ(x_t) for the previous step, an additional term is added:
μ_guided = μ_θ(x_t) + s * Σ_θ * ∇_x_t log(p_φ(y|x_t))
Where:
- s is the guidance scale.
- Σ_θ is the variance of the reverse step.
- p_φ(y|x_t) is the probability of the target class y given the noisy image x_t, as predicted by a separate, noise-aware classifier φ.
- ∇_x_t log(p_φ(y|x_t)) is the gradient of the log-probability with respect to the noisy image x_t.
This gradient pushes the sampling process in a direction that makes the image more recognizable as class y to the classifier.
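As a rough, self-contained illustration of this update (the guided_mean helper and its arguments below are hypothetical, not part of the library), the gradient can be obtained with autograd and folded into the predicted mean:

import torch
import torch.nn.functional as F

def guided_mean(model_mean, variance, x_t, t, y, classifier, s=1.0):
    # Compute ∇_x_t log p_φ(y|x_t) with autograd ...
    with torch.enable_grad():
        x_in = x_t.detach().requires_grad_(True)
        log_probs = F.log_softmax(classifier(x_in, t), dim=-1)
        selected = log_probs[torch.arange(x_in.shape[0]), y]
        grad = torch.autograd.grad(selected.sum(), x_in)[0]
    # ... and shift the predicted mean by s * Σ_θ * gradient
    return model_mean + s * variance * grad

In practice the library splits this into two pieces: a user-supplied cond_fn that returns the (scaled) gradient, and a condition_mean method that applies it, as described below.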
Implementation Details
The guided_diffusion.py file provides a GaussianDiffusion class with modified sampling methods.
The key change is in the p_sample method, which now accepts an optional cond_fn.
# From guided_diffusion.py
def p_sample(self, x, t: int, x_self_cond = None, cond_fn=None, guidance_kwargs=None):
    # ... standard reverse step calculation (the integer t is broadcast to the tensor `batched_times`) ...
    model_mean, variance, model_log_variance, x_start = self.p_mean_variance(...)
    if exists(cond_fn) and exists(guidance_kwargs):
        # shift the predicted mean using the classifier gradient
        model_mean = self.condition_mean(cond_fn, model_mean, variance, x, batched_times, guidance_kwargs)
    # ... add noise and return ...
The condition_mean method calls cond_fn to obtain the gradient and adds it, scaled by the variance, to the predicted mean.
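In essence, it applies the μ_guided update from above. A minimal sketch of the idea (a paraphrase, not a verbatim quote of the file):

def condition_mean(self, cond_fn, mean, variance, x, t, guidance_kwargs=None):
    # cond_fn returns ∇_x log p_φ(y|x_t), already multiplied by the guidance scale
    gradient = cond_fn(x, t, **guidance_kwargs)
    # shift the reverse-step mean by variance * gradient
    return mean.float() + variance * gradient.float()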
Usage Example
To use guided diffusion, you need two components:
- A trained diffusion model (Unet + GaussianDiffusion).
- A separate classifier trained on noisy images.
import torch
import torch.nn.functional as F
from torch import nn
from denoising_diffusion_pytorch.guided_diffusion import Unet, GaussianDiffusion
# 1. A pre-trained diffusion model
model = Unet(dim=64, dim_mults=(1, 2, 4, 8))
image_size = 128
diffusion = GaussianDiffusion(model, image_size=image_size, timesteps=1000)
# ... assume `diffusion` model is trained ...
# 2. A noise-aware classifier (example implementation)
class Classifier(nn.Module):
    def __init__(self, image_size, num_classes):
        super().__init__()
        # A simple linear classifier for demonstration
        self.linear_img = nn.Linear(image_size * image_size * 3, num_classes)

    def forward(self, x, t):
        # Classifier must accept both the noisy image `x` and timestep `t`
        # (this toy example ignores `t`)
        x_flat = x.view(x.shape[0], -1)
        return self.linear_img(x_flat)
classifier = Classifier(image_size=image_size, num_classes=1000)
# ... assume `classifier` is trained on noisy images ...
# 3. A gradient function (`cond_fn`)
def classifier_cond_fn(x, t, classifier, y, classifier_scale=1.0):
    assert y is not None
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), y.view(-1)]
        grad = torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale
    return grad
# 4. Sampling with guidance
batch_size = 4
target_classes = torch.randint(0, 1000, (batch_size,))
sampled_images = diffusion.sample(
    batch_size=batch_size,
    cond_fn=classifier_cond_fn,
    guidance_kwargs={
        "classifier": classifier,
        "y": target_classes,
        "classifier_scale": 1.0,  # The guidance scale `s`
    }
)
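The classifier above is only a placeholder; for guidance to work it must be trained on images noised at every diffusion timestep, so that its gradients stay meaningful along the whole sampling trajectory. Below is a rough sketch of such a training loop, assuming a labelled dataloader and relying on the diffusion object's q_sample and num_timesteps (adapt preprocessing and normalization to match what the diffusion model sees):

# Hypothetical training loop for the noise-aware classifier (illustrative only).
# Assumes `dataloader` yields (images, labels) at the diffusion resolution.
optimizer = torch.optim.Adam(classifier.parameters(), lr=3e-4)

for images, labels in dataloader:
    # pick a random timestep per image and noise the batch accordingly
    t = torch.randint(0, diffusion.num_timesteps, (images.shape[0],)).long()
    noisy_images = diffusion.q_sample(x_start=images, t=t)
    logits = classifier(noisy_images, t)
    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()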