# The Gaussian Diffusion Process

The `GaussianDiffusion` class is the central component that implements the forward (noising) and reverse (denoising) processes. It orchestrates the training and sampling logic, using a provided neural network (typically a `Unet`) to predict the noise at each step.
## Initialization

When you create an instance of `GaussianDiffusion`, you configure the entire diffusion process.
```python
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,
    sampling_timesteps = 250,
    objective = 'pred_v',
    beta_schedule = 'sigmoid'
)
```
## Key Parameters

- `model`: The neural network that will learn to reverse the diffusion process. This is almost always a `Unet` or a variant.
- `image_size` (int or tuple): The size of the images you are training on. This is used to define the shape of the noise tensor during sampling.
- `timesteps` (int): The total number of noising steps in the forward process. A common value is 1000.
- `sampling_timesteps` (int, optional): If provided and less than `timesteps`, the model will use DDIM for faster sampling. If not provided, it defaults to `timesteps` and uses the DDPM sampling algorithm.
- `objective` (str): Defines what the model is trained to predict. This is a crucial choice:
  - `'pred_noise'`: The standard DDPM objective, where the model predicts the noise that was added to the image.
  - `'pred_x0'`: The model predicts the original, clean image (`x_0`).
  - `'pred_v'`: The model predicts `v`, a target from the "v-parameterization" proposed in the progressive distillation paper. This is often a good default.
- `beta_schedule` (str): The schedule for adding noise (defining the variance `β_t` at each timestep `t`).
  - `'linear'`: The original schedule from the DDPM paper.
  - `'cosine'`: A schedule that adds noise more slowly at the beginning, often leading to better results.
  - `'sigmoid'`: A schedule that performs well for higher-resolution images (>64x64).
## The Forward Process: `q_sample`

During training, the `forward` method of the `GaussianDiffusion` class is called. Internally, it performs these steps:

1. Picks a random timestep `t` for each image in the batch.
2. Calls `q_sample(x_start, t, noise)` to apply noise to the clean image `x_start` according to the beta schedule, producing the noised image `x_t`.
3. Passes `x_t` and `t` to the underlying `model` to get a prediction.
4. Calculates the loss between the model's prediction and the true target (which depends on the `objective`).

This process is encapsulated within the `p_losses` method.
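The noising in step 2 has a well-known closed form: `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`, where `alpha_bar_t` is the cumulative product of `1 - beta`. A minimal sketch (the function name is illustrative; the library's `q_sample` precomputes and buffers these coefficients):

```python
import torch

def q_sample_sketch(x_start, t, betas, noise=None):
    # Closed form of the forward process:
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    noise = noise if noise is not None else torch.randn_like(x_start)
    alphas_bar = torch.cumprod(1.0 - betas, dim=0)
    # Gather per-sample coefficients and broadcast over (B, C, H, W)
    a = alphas_bar[t].sqrt().view(-1, 1, 1, 1)
    b = (1 - alphas_bar[t]).sqrt().view(-1, 1, 1, 1)
    return a * x_start + b * noise

betas = torch.linspace(1e-4, 0.02, 1000)
x0 = torch.randn(4, 3, 32, 32)          # a batch of "clean" images
t = torch.randint(0, 1000, (4,))        # a random timestep per image
xt = q_sample_sketch(x0, t, betas)
```

Because of this closed form, training can jump straight to any timestep `t` without simulating the chain step by step.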
## The Reverse Process: Sampling

To generate new images, we start with pure noise and iteratively denoise it.

### DDPM Sampling (`p_sample_loop`)

If `sampling_timesteps` is not set or is equal to `timesteps`, this method is used. It iterates from `t = T-1` down to `t = 0`, using the model's prediction at each step to estimate the image at `t-1`.
```python
# Assumes diffusion is already trained
sampled_images = diffusion.p_sample_loop((16, 3, 128, 128))  # shape = (batch, channels, height, width)
```
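A single DDPM reverse step can be sketched as follows. This assumes the `'pred_noise'` objective and uses `sigma_t^2 = beta_t` as the step variance, one common choice; the function name is illustrative and the library's actual step uses a precomputed posterior variance.

```python
import torch

def p_sample_step(x_t, t, pred_noise, betas):
    # One DDPM reverse step (sketch): form the posterior mean from the
    # model's noise prediction, then add sigma_t * z for t > 0.
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    # mean = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)
    mean = (x_t - betas[t] / (1 - alphas_bar[t]).sqrt() * pred_noise) / alphas[t].sqrt()
    if t == 0:
        return mean  # no noise is added at the final step
    return mean + betas[t].sqrt() * torch.randn_like(x_t)

betas = torch.linspace(1e-4, 0.02, 1000)
x = torch.randn(1, 3, 32, 32)                       # start from pure noise
x_prev = p_sample_step(x, 999, torch.randn_like(x), betas)
```

Running this from `t = 999` down to `t = 0`, with the model supplying `pred_noise` at each step, is exactly the loop that `p_sample_loop` performs.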
### DDIM Sampling (`ddim_sample`)

If `sampling_timesteps` is set to a value less than `timesteps`, this method is used. Denoising Diffusion Implicit Models (DDIM) allow for a much faster sampling process by skipping steps in the reverse chain. For example, with `timesteps=1000` and `sampling_timesteps=250`, you can get a good sample in just 250 steps instead of 1000.
```python
# Assumes diffusion is already trained
sampled_images = diffusion.ddim_sample((16, 3, 128, 128))
```
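The step-skipping itself amounts to choosing an evenly spaced subset of the original timesteps and denoising only along that subset. A sketch (the function name is illustrative, and the library builds its timestep pairs differently, but the idea is the same):

```python
import numpy as np

def ddim_timesteps(total_timesteps, sampling_timesteps):
    # Evenly spaced subset of [0, total_timesteps - 1], in decreasing
    # order, so the reverse loop visits only sampling_timesteps points.
    times = np.linspace(0, total_timesteps - 1, sampling_timesteps)
    return times.round().astype(int)[::-1]

steps = ddim_timesteps(1000, 250)  # 250 timesteps, from 999 down to 0
```

Each DDIM update then jumps directly from one timestep in this subset to the next, rather than stepping through every intermediate `t`.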
## The `sample` Method

The `diffusion.sample(batch_size=16)` method is a convenient wrapper that automatically calls either `p_sample_loop` or `ddim_sample` based on your configuration.
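The dispatch can be pictured roughly like this. This is a minimal standalone sketch with illustrative names, not the library's actual code:

```python
class SamplerDispatch:
    # Sketch of how a GaussianDiffusion-like object could decide between
    # full DDPM sampling and shortened DDIM sampling.
    def __init__(self, timesteps, sampling_timesteps=None):
        self.timesteps = timesteps
        # Default to the full chain when sampling_timesteps is not given
        self.sampling_timesteps = (
            sampling_timesteps if sampling_timesteps is not None else timesteps
        )
        self.is_ddim_sampling = self.sampling_timesteps < self.timesteps

    def sampler_name(self):
        return 'ddim_sample' if self.is_ddim_sampling else 'p_sample_loop'

fast = SamplerDispatch(1000, 250)  # fewer sampling steps -> DDIM
full = SamplerDispatch(1000)       # no override -> full DDPM loop
```

In other words, simply setting `sampling_timesteps` at construction time is enough to switch `sample` over to the faster path.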