Elucidated Diffusion
This module implements the model from the paper "Elucidating the Design Space of Diffusion-Based Generative Models" by Karras et al. (2022). This work provides a new theoretical framework for diffusion models and proposes a set of design choices that lead to significantly improved performance and faster sampling.
The implementation can be found in denoising_diffusion_pytorch/elucidated_diffusion.py.
Core Concepts
Karras et al. re-formulate the diffusion process using stochastic differential equations (SDEs) and introduce a number of key improvements:
- Noise Schedule: A new noise schedule is proposed that is defined in terms of the noise standard deviation σ rather than α and β. The schedule is designed to cover a wide range of noise levels, from σ_max down to σ_min.
- Preconditioning: The U-Net's input, output, and time embedding are re-scaled using functions c_in(σ), c_out(σ), c_skip(σ), and c_noise(σ). This preconditioning ensures that the network operates on a signal of roughly constant variance across noise levels, which stabilizes training and improves performance.
- Solver: The sampling process is viewed as solving an ordinary differential equation (ODE). The paper proposes a second-order solver (Heun's method) for more accurate and efficient sampling.
- Stochastic Sampling: A "churn" parameter S_churn is introduced to inject a controlled amount of fresh noise during sampling, which can help correct errors accumulated by the solver in earlier steps.
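The noise schedule is controlled by sigma_min, sigma_max, and rho: it interpolates linearly in σ^(1/ρ) space, so larger ρ concentrates steps near σ_min. A minimal sketch of this schedule (the function name `karras_sigmas` is ours; the module computes an equivalent schedule internally in `sample_schedule`):

```python
import torch

def karras_sigmas(num_steps=32, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # interpolate linearly in sigma^(1/rho) space, then raise back to the rho power;
    # rho > 1 concentrates steps near sigma_min, where fine detail is resolved
    ramp = torch.linspace(0, 1, num_steps)
    inv_rho = 1.0 / rho
    sigmas = (sigma_max ** inv_rho + ramp * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
    # append sigma = 0 so the final step lands exactly on the clean image
    return torch.cat([sigmas, torch.zeros(1)])
```

With the defaults above, the schedule starts at σ = 80 and decays monotonically to σ = 0.002 before the final σ = 0 entry.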
Implementation: ElucidatedDiffusion
The ElucidatedDiffusion class wraps a Unet and implements the logic described in the paper.
Initialization
from denoising_diffusion_pytorch import Unet
from denoising_diffusion_pytorch.elucidated_diffusion import ElucidatedDiffusion
# The U-Net must have `random_or_learned_sinusoidal_cond = True`
model = Unet(
dim=64,
random_or_learned_sinusoidal_cond=True
)
diffusion = ElucidatedDiffusion(
model,
image_size = 128,
    num_sample_steps = 32, # far fewer steps are needed than in standard DDPM sampling
sigma_min = 0.002,
sigma_max = 80,
rho = 7, # Controls the noise schedule curve
# Stochastic sampling parameters
S_churn = 80,
S_tmin = 0.05,
S_tmax = 50,
S_noise = 1.003,
)
Preconditioned Network
The core of the model is the preconditioned_network_forward method, which applies the c_* scaling functions before and after passing the data through the U-Net.
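As a rough sketch of what these scaling functions compute, here are the paper's formulas with σ_data denoting the assumed standard deviation of the data (the constant 0.5 and the helper names below are illustrative, not the module's exact API):

```python
import torch

SIGMA_DATA = 0.5  # assumed data standard deviation (the paper's sigma_data)

def c_skip(sigma):
    # how much of the noisy input to pass through unchanged
    return SIGMA_DATA ** 2 / (sigma ** 2 + SIGMA_DATA ** 2)

def c_out(sigma):
    # scale of the network's contribution to the denoised output
    return sigma * SIGMA_DATA / (sigma ** 2 + SIGMA_DATA ** 2) ** 0.5

def c_in(sigma):
    # rescales the noisy input to roughly unit variance
    return 1.0 / (sigma ** 2 + SIGMA_DATA ** 2) ** 0.5

def c_noise(sigma):
    # conditioning value fed to the time embedding
    return torch.log(sigma) * 0.25

def denoise(net, noised, sigma):
    # D(x; sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma))
    return c_skip(sigma) * noised + c_out(sigma) * net(c_in(sigma) * noised, c_noise(sigma))
```

Note how at small σ, c_skip approaches 1 and c_out approaches 0, so the network only needs to predict a small correction to an already nearly clean image.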
Sampling
The sample method implements the second-order ODE solver with stochastic churn. It first generates a noise schedule (sample_schedule) and then iteratively denoises the image, applying Heun's second-order correction at each step.
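In outline, one sampler iteration looks like the sketch below. The helper names are hypothetical, and `denoise` stands in for the preconditioned network output D(x; σ):

```python
import torch

def churn(x, sigma, gamma, s_noise=1.0):
    # temporarily raise the noise level to sigma_hat and add matching fresh noise
    sigma_hat = sigma * (1 + gamma)
    x = x + (sigma_hat ** 2 - sigma ** 2) ** 0.5 * s_noise * torch.randn_like(x)
    return x, sigma_hat

def heun_step(denoise, x, sigma, sigma_next):
    # one step of the probability-flow ODE  dx/dsigma = (x - D(x; sigma)) / sigma
    d = (x - denoise(x, sigma)) / sigma           # slope at the current noise level
    x_euler = x + (sigma_next - sigma) * d        # Euler predictor
    if sigma_next == 0:
        return x_euler                            # final step: no correction possible
    d_next = (x_euler - denoise(x_euler, sigma_next)) / sigma_next
    return x + (sigma_next - sigma) * 0.5 * (d + d_next)  # trapezoidal corrector
```

In the paper, gamma is derived from S_churn (roughly min(S_churn / num_steps, √2 − 1)) and churn is applied only when σ lies within [S_tmin, S_tmax]; S_noise slightly inflates the injected noise.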
This implementation also includes an alternative sampler, sample_using_dpmpp, which uses the DPM-Solver++ for potentially even faster and higher-quality sampling.
import torch

# Training is similar to other models
images = torch.randn(2, 3, 128, 128)
loss = diffusion(images)
loss.backward()
# Sampling
# Using the default Heun solver with stochastic churn
sampled_images = diffusion.sample(batch_size = 4)
# Using the DPM-Solver++
sampled_images_dpm = diffusion.sample_using_dpmpp(batch_size = 4)