The Trainer Class

To simplify the process of training a diffusion model, this library provides a Trainer class. It encapsulates the entire training loop: data loading, optimization, EMA (Exponential Moving Average) updates, periodic sampling, and checkpoint saving.

Basic Usage

Using the Trainer is straightforward. You instantiate it with your diffusion model and a path to your dataset, then call the train() method.

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

# the U-Net backbone that predicts the noise at each timestep
model = Unet(dim=64, dim_mults=(1, 2, 4, 8))

# wrap the backbone in the diffusion process
diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000)

trainer = Trainer(
    diffusion,
    'path/to/your/images',   # folder of training images
    train_batch_size = 32,
    train_num_steps = 700000
)

trainer.train()

Key Parameters

The Trainer class accepts several parameters to customize the training process:

diffusion_model (nn.Module): An instance of GaussianDiffusion or one of its variants.
folder (str): The path to the folder containing your training images.
train_batch_size (int): The batch size for training.
gradient_accumulate_every (int): The number of steps over which to accumulate gradients before each optimizer step.
train_lr (float): The learning rate for the Adam optimizer.
train_num_steps (int): The total number of training steps to perform.
ema_decay (float): The decay rate for the exponential moving average (EMA) of the model weights.
ema_update_every (int): How often (in steps) to update the EMA model.
amp (bool): If True, enables automatic mixed-precision training, which can speed up training and reduce memory use on supported GPUs.
save_and_sample_every (int): The interval (in steps) at which to save a checkpoint and generate sample images.
num_samples (int): The number of images to generate at each sampling interval.
results_folder (str): The directory where checkpoints and samples will be saved.
calculate_fid (bool): If True, computes the Fréchet Inception Distance (FID) at each sampling interval.
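
Putting these together, a more fully configured Trainer might look like the sketch below. The values are illustrative, not tuned recommendations; every keyword corresponds to a parameter from the list above.

trainer = Trainer(
    diffusion,
    'path/to/your/images',
    train_batch_size = 32,
    gradient_accumulate_every = 2,    # effective batch size of 32 * 2 = 64
    train_lr = 8e-5,
    train_num_steps = 700000,
    ema_decay = 0.995,
    ema_update_every = 10,
    amp = True,                       # mixed-precision training
    save_and_sample_every = 1000,     # checkpoint and sample grid every 1000 steps
    num_samples = 25,                 # images generated at each sampling interval
    results_folder = './results',
    calculate_fid = True
)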

Multi-GPU Training

The Trainer is built on top of 🤗 accelerate, making multi-GPU training seamless. You don't need to change your training script at all. Simply configure accelerate for your machine:

accelerate config

And then launch your script with accelerate launch:

accelerate launch your_script.py

The Trainer will automatically handle distributing the data and model across all available GPUs. Here, your_script.py is just your ordinary training script; the basic usage example above works unmodified.

FID Calculation

If calculate_fid is set to True, the trainer will periodically compare a large number of generated samples to your training dataset to compute the FID score, a common metric for evaluating the quality of generative models. This process can be time-consuming but provides a quantitative measure of your model's progress.

First, the trainer pre-computes and caches the Inception feature statistics for your real dataset. Then, at each sampling interval, it generates samples from the EMA model, computes the FID score against those cached statistics, and prints the score to the console.
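
Below is a minimal sketch of enabling FID. The num_fid_samples keyword is an assumption about the installed version of the library; check the Trainer signature in your version before relying on it.

trainer = Trainer(
    diffusion,
    'path/to/your/images',
    train_batch_size = 32,
    calculate_fid = True,
    num_fid_samples = 50000   # assumed parameter: number of generated images used for the FID estimate
)

trainer.train()   # the FID score is printed at each save_and_sample_every interval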