# The Trainer Class
To simplify the process of training a diffusion model, this library provides a `Trainer` class. It encapsulates the entire training loop, including data loading, optimization, EMA (Exponential Moving Average) updates, sampling, and saving checkpoints.
## Basic Usage

Using the `Trainer` is straightforward: instantiate it with your diffusion model and the path to your dataset, then call the `train()` method.
```python
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))

diffusion = GaussianDiffusion(model, image_size = 128, timesteps = 1000)

trainer = Trainer(
    diffusion,
    'path/to/your/images',
    train_batch_size = 32,
    train_num_steps = 700000
)

trainer.train()
```
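Once training completes (or a checkpoint has been loaded), you can generate images directly from the diffusion model via its `sample` method:

```python
# generate a batch of images from the trained model;
# the output is a tensor of shape (batch_size, channels, image_size, image_size)
sampled_images = diffusion.sample(batch_size = 16)
```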
## Key Parameters

The `Trainer` class accepts several parameters to customize the training process:
| Parameter | Type | Description |
|---|---|---|
| `diffusion_model` | `nn.Module` | An instance of `GaussianDiffusion` or one of its variants. |
| `folder` | `str` | The path to the folder containing your training images. |
| `train_batch_size` | `int` | The batch size for training. |
| `gradient_accumulate_every` | `int` | The number of steps to accumulate gradients over before performing an optimizer step. |
| `train_lr` | `float` | The learning rate for the Adam optimizer. |
| `train_num_steps` | `int` | The total number of training steps to perform. |
| `ema_decay` | `float` | The decay rate for the exponential moving average of the model weights. |
| `ema_update_every` | `int` | How often (in steps) to update the EMA model. |
| `amp` | `bool` | If `True`, enables automatic mixed-precision training for better performance. |
| `save_and_sample_every` | `int` | The interval (in steps) at which to save a checkpoint and generate sample images. |
| `num_samples` | `int` | The number of images to generate at each sampling interval. |
| `results_folder` | `str` | The directory where checkpoints and samples will be saved. |
| `calculate_fid` | `bool` | If `True`, computes the Fréchet Inception Distance (FID) at sampling intervals. |
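As a more fully configured sketch, the call below sets most of the options from the table. The parameter names come from the table above; the specific values are illustrative choices, not prescribed defaults:

```python
trainer = Trainer(
    diffusion,
    'path/to/your/images',
    train_batch_size = 16,
    gradient_accumulate_every = 2,   # effective batch size: 16 * 2 = 32
    train_lr = 1e-4,
    train_num_steps = 700000,
    ema_decay = 0.995,               # smoothing factor for the EMA weights
    ema_update_every = 10,           # refresh the EMA model every 10 steps
    amp = True,                      # automatic mixed precision
    save_and_sample_every = 1000,    # checkpoint + samples every 1000 steps
    num_samples = 25,
    results_folder = './results',
    calculate_fid = True
)
```

Note that gradient accumulation multiplies the effective batch size, which is useful when a larger batch would not fit in GPU memory.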
## Multi-GPU Training

The `Trainer` is built on top of 🤗 `accelerate`, making multi-GPU training seamless. You don't need to change your training script at all. Simply configure `accelerate` for your machine:
```bash
accelerate config
```
Then launch your script with `accelerate launch`:
```bash
accelerate launch your_script.py
```
The `Trainer` will automatically handle distributing the data and model across all available GPUs.
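If you prefer not to run the interactive configuration, the same settings can be passed on the command line. As a sketch using accelerate's standard `--multi_gpu` and `--num_processes` flags, this launches the script across four GPUs:

```bash
accelerate launch --multi_gpu --num_processes 4 your_script.py
```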
## FID Calculation

If `calculate_fid` is set to `True`, the trainer will periodically compare a large number of generated samples against your training dataset to compute the FID score, a common metric for evaluating the quality of generative models. This process can be time-consuming, but it provides a quantitative measure of your model's progress.
First, it will pre-calculate and cache the statistics for your real dataset. Then, at each sampling interval, it generates samples from the EMA model and calculates the FID score. The score is printed to the console.
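For reference, the FID score has a closed form: it fits a Gaussian to Inception-network features of the real and generated images and measures the distance between the two distributions (lower is better):

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)
$$

where $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ are the means and covariances of the features extracted from real and generated samples, respectively.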