# Elucidated Diffusion
This module implements the model from the paper "Elucidating the Design Space of Diffusion-Based Generative Models" by Karras et al. (2022). This work provides a new theoretical framework for diffusion models and proposes a set of design choices that lead to significantly improved performance and faster sampling.
The implementation can be found in `denoising_diffusion_pytorch/elucidated_diffusion.py`.
## Core Concepts
Karras et al. re-formulate the diffusion process using stochastic differential equations (SDEs) and introduce a number of key improvements:
- Noise Schedule: A new noise schedule is proposed, defined directly in terms of the noise standard deviation `σ` rather than `α` and `β`. The schedule covers a wide range of noise levels, stepping from `σ_max` down to `σ_min` during sampling (see the sketch after this list).
- Preconditioning: The U-Net's input, output, skip connection, and noise-level conditioning are re-scaled using the functions `c_in(σ)`, `c_out(σ)`, `c_skip(σ)`, and `c_noise(σ)`. This preconditioning ensures that the network operates on signals of roughly unit variance, which stabilizes training and improves performance.
- Solver: The sampling process is viewed as solving an ordinary differential equation (ODE). The paper proposes a second-order solver (Heun's method) for more accurate and efficient sampling.
- Stochastic Sampling: A "churn" parameter `S_churn` is introduced to inject a controlled amount of noise during sampling, which can help correct errors accumulated in earlier steps.
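For concreteness, the ρ-warped schedule from the paper (Eq. 5) can be written in a few lines. This is a minimal sketch rather than the repository's exact `sample_schedule` code, and the function name `karras_schedule` is chosen here just for illustration:

```python
import torch

def karras_schedule(num_steps, sigma_min = 0.002, sigma_max = 80., rho = 7.):
    # Interpolate between sigma_max and sigma_min in rho-warped space;
    # larger rho concentrates more of the steps near sigma_min.
    steps = torch.arange(num_steps, dtype = torch.float32)
    inv_rho = 1. / rho
    sigmas = (sigma_max ** inv_rho + steps / (num_steps - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
    # Append sigma = 0 so the final step lands exactly on the clean image.
    return torch.cat((sigmas, torch.zeros(1)))
```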
## Implementation: `ElucidatedDiffusion`

The `ElucidatedDiffusion` class wraps a `Unet` and implements the logic described in the paper.
### Initialization
```python
from denoising_diffusion_pytorch import Unet
from denoising_diffusion_pytorch.elucidated_diffusion import ElucidatedDiffusion

# The U-Net must have `random_or_learned_sinusoidal_cond = True`
model = Unet(
    dim = 64,
    random_or_learned_sinusoidal_cond = True
)

diffusion = ElucidatedDiffusion(
    model,
    image_size = 128,
    num_sample_steps = 32,  # far fewer steps are needed than with DDPM
    sigma_min = 0.002,
    sigma_max = 80,
    rho = 7,                # controls the curvature of the noise schedule
    # stochastic sampling parameters
    S_churn = 80,
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
)
```
### Preconditioned Network

The core of the model is the `preconditioned_network_forward` method, which applies the `c_*` scaling functions before and after passing the data through the U-Net.
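As a rough sketch of what this method computes (not the repository's exact code), here are the scaling functions from Table 1 of the paper, assuming the common default `sigma_data = 0.5`; `preconditioned_forward` and `net` are illustrative names:

```python
import torch

def preconditioned_forward(net, noised_images, sigma, sigma_data = 0.5):
    # sigma is assumed to be a tensor broadcastable against noised_images,
    # e.g. of shape (batch, 1, 1, 1).
    c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
    c_out = sigma * sigma_data / (sigma ** 2 + sigma_data ** 2) ** 0.5
    c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
    c_noise = torch.log(sigma) * 0.25

    # The network sees a unit-variance input and predicts a residual that is
    # blended with the noised input, so its training target also has unit variance.
    net_out = net(c_in * noised_images, c_noise)
    return c_skip * noised_images + c_out * net_out
```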
### Sampling

The `sample` method implements the second-order ODE solver with stochastic churn. It first generates a noise schedule (`sample_schedule`) and then iteratively denoises the image, applying Heun's second-order correction at each step.
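One step of this sampler (Algorithm 2 in the paper) looks roughly as follows. This is a sketch rather than the repository's exact code: `denoise(x, sigma)` is a hypothetical helper standing in for the preconditioned network forward pass:

```python
import math
import torch

def heun_step(denoise, x, sigma, sigma_next, num_sample_steps,
              S_churn = 80., S_tmin = 0.05, S_tmax = 50., S_noise = 1.003):
    # "Churn": temporarily raise the noise level before stepping down,
    # but only for noise levels inside [S_tmin, S_tmax].
    gamma = min(S_churn / num_sample_steps, math.sqrt(2) - 1) if S_tmin <= sigma <= S_tmax else 0.
    sigma_hat = sigma * (1 + gamma)
    if gamma > 0:
        x = x + math.sqrt(sigma_hat ** 2 - sigma ** 2) * S_noise * torch.randn_like(x)

    # Euler step from sigma_hat down to sigma_next
    d = (x - denoise(x, sigma_hat)) / sigma_hat
    x_next = x + (sigma_next - sigma_hat) * d

    # Heun's second-order correction, skipped on the final step (sigma_next == 0)
    if sigma_next > 0:
        d_prime = (x_next - denoise(x_next, sigma_next)) / sigma_next
        x_next = x + (sigma_next - sigma_hat) * 0.5 * (d + d_prime)

    return x_next
```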
This implementation also includes an alternative sampler, `sample_using_dpmpp`, which uses DPM-Solver++ for potentially even faster and higher-quality sampling.
```python
import torch

# Training is similar to the other models in this repository.
# Random tensors stand in for a batch of real images here.
images = torch.randn(2, 3, 128, 128)
loss = diffusion(images)
loss.backward()

# Sampling with the default Heun solver and stochastic churn
sampled_images = diffusion.sample(batch_size = 4)

# Sampling with DPM-Solver++
sampled_images_dpm = diffusion.sample_using_dpmpp(batch_size = 4)
```