Elucidated Diffusion

This module implements the model from the paper "Elucidating the Design Space of Diffusion-Based Generative Models" by Karras et al. (2022). This work provides a new theoretical framework for diffusion models and proposes a set of design choices that lead to significantly improved performance and faster sampling.

The implementation can be found in denoising_diffusion_pytorch/elucidated_diffusion.py.

Core Concepts

Karras et al. reformulate the diffusion process in terms of stochastic differential equations (SDEs) and introduce several key improvements:

  1. Noise Schedule: A new noise schedule is proposed that is defined in terms of noise standard deviation σ rather than α and β. The schedule is designed to cover a wide range of noise levels, from σ_max to σ_min.
  2. Preconditioning: The U-Net's input, output, and time embedding are re-scaled using functions c_in(σ), c_out(σ), c_skip(σ), and c_noise(σ). This preconditioning ensures that the network operates on a signal of constant variance, which stabilizes training and improves performance.
  3. Solver: The sampling process is viewed as solving an ordinary differential equation (ODE). The paper proposes a second-order solver (Heun's method) for more accurate and efficient sampling.
  4. Stochastic Sampling: A "churn" parameter S_churn is introduced to inject a controlled amount of noise during sampling, which can help the solver escape local minima and find better solutions.
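The noise schedule from point 1 can be written in closed form: σ_i = (σ_max^(1/ρ) + i/(N−1) · (σ_min^(1/ρ) − σ_max^(1/ρ)))^ρ. A minimal sketch of computing it, using the paper's default values (the function name `karras_schedule` is our own, not part of the library):

```python
def karras_schedule(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    inv_rho = 1.0 / rho
    sigmas = []
    for i in range(num_steps):
        t = i / (num_steps - 1)
        sigma = (sigma_max ** inv_rho
                 + t * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
        sigmas.append(sigma)
    return sigmas + [0.0]  # a final sigma of 0 terminates sampling
```

Larger ρ concentrates more of the steps near σ_min, where fine detail is resolved; ρ = 7 is the paper's recommended trade-off.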

Implementation: ElucidatedDiffusion

The ElucidatedDiffusion class wraps a Unet and implements the logic described in the paper.

Initialization

from denoising_diffusion_pytorch import Unet
from denoising_diffusion_pytorch.elucidated_diffusion import ElucidatedDiffusion

# The U-Net must have `random_or_learned_sinusoidal_cond = True`
model = Unet(
    dim=64,
    random_or_learned_sinusoidal_cond=True
)

diffusion = ElucidatedDiffusion(
    model,
    image_size = 128,
    num_sample_steps = 32, # far fewer steps than the ~1000 of DDPM
    sigma_min = 0.002,
    sigma_max = 80,
    rho = 7,              # Controls the noise schedule curve
    # Stochastic sampling parameters
    S_churn = 80,
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
)

Preconditioned Network

The core of the model is the preconditioned_network_forward method, which applies the c_* scaling functions before and after passing the data through the U-Net.
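The scaling functions follow closed-form expressions from the paper, parameterized by the data standard deviation σ_data. A minimal scalar sketch (assuming the paper's default σ_data = 0.5; `denoise` and `net` here are illustrative stand-ins, not the library's API):

```python
import math

SIGMA_DATA = 0.5  # assumed standard deviation of the data (paper default)

def c_skip(sigma):  return SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)
def c_out(sigma):   return sigma * SIGMA_DATA / math.sqrt(sigma**2 + SIGMA_DATA**2)
def c_in(sigma):    return 1.0 / math.sqrt(sigma**2 + SIGMA_DATA**2)
def c_noise(sigma): return 0.25 * math.log(sigma)  # time embedding input

def denoise(net, x, sigma):
    # D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise)
    return c_skip(sigma) * x + c_out(sigma) * net(c_in(sigma) * x, c_noise(sigma))
```

Note that c_in is chosen so that the network's input has unit variance at every noise level: c_in(σ)² · (σ² + σ_data²) = 1.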

Sampling

The sample method implements the second-order ODE solver with stochastic churn. It first generates a noise schedule (sample_schedule) and then iteratively denoises the image, applying Heun's second-order correction at every step except the last.
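The structure of one Heun step can be sketched as follows. This is a simplified illustration, not the library's code: `denoise` stands in for a call to the preconditioned network, and churn is omitted.

```python
def heun_step(x, sigma, sigma_next, denoise):
    # denoise(x, sigma) -> estimate of the clean image (hypothetical helper)
    d = (x - denoise(x, sigma)) / sigma           # ODE derivative dx/dsigma
    x_euler = x + (sigma_next - sigma) * d        # first-order (Euler) step
    if sigma_next == 0:
        return x_euler                            # final step: no correction
    # second-order correction: average the derivative at both endpoints
    d_next = (x_euler - denoise(x_euler, sigma_next)) / sigma_next
    return x + (sigma_next - sigma) * 0.5 * (d + d_next)
```

Each Heun step costs two network evaluations instead of one, but the higher-order accuracy allows far fewer steps overall.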

This implementation also includes an alternative sampler, sample_using_dpmpp, which uses the DPM-Solver++ for potentially even faster and higher-quality sampling.

import torch

# Training is similar to other models
images = torch.randn(2, 3, 128, 128)
loss = diffusion(images)
loss.backward()

# Sampling
# Using the default Heun solver with stochastic churn
sampled_images = diffusion.sample(batch_size = 4)

# Using the DPM-Solver++
sampled_images_dpm = diffusion.sample_using_dpmpp(batch_size = 4)