1D Sequence Diffusion
Beyond images, diffusion models can be applied to any type of data, including 1D sequences. This is useful for generative tasks involving time-series data, audio, or any other kind of sequential information. This library provides dedicated Unet1D, GaussianDiffusion1D, and Trainer1D classes for this purpose.
Unet1D and GaussianDiffusion1D
These classes are direct analogues of their 2D counterparts, but with all 2D operations (Conv2d, Upsample, etc.) replaced by their 1D equivalents.
Unet1D: A U-Net architecture that operates on 1D data of shape (batch, channels, sequence_length).
GaussianDiffusion1D: Manages the 1D diffusion process, taking a Unet1D as its core predictive model.
Example Usage
Here's how to set up and train a 1D diffusion model.
1. Prepare the Data
Your data should be a PyTorch tensor. For this example, we'll create some random data.
import torch
# Let's create a dataset of 64 sequences.
# Each sequence has 32 channels (features) and a length of 128.
training_seq = torch.rand(64, 32, 128)
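The random data above lies in [0, 1), but real sequences usually do not. If your pipeline expects inputs in that range (an assumption worth checking against the library version you use), one common approach is per-channel min-max scaling. The sketch below is illustrative only: minmax_normalize and minmax_denormalize are hypothetical helpers, not part of the library.

```python
import torch

def minmax_normalize(seqs: torch.Tensor, eps: float = 1e-8):
    """Scale each channel of a (batch, channels, length) tensor to [0, 1].

    Returns the scaled tensor plus the per-channel minimum and range,
    which are needed to map generated samples back to original units.
    """
    # Reduce over batch and length so there is one statistic per channel.
    ch_min = seqs.amin(dim=(0, 2), keepdim=True)
    ch_max = seqs.amax(dim=(0, 2), keepdim=True)
    # Guard against constant channels, whose range would be zero.
    ch_range = (ch_max - ch_min).clamp_min(eps)
    return (seqs - ch_min) / ch_range, ch_min, ch_range

def minmax_denormalize(scaled: torch.Tensor, ch_min: torch.Tensor, ch_range: torch.Tensor):
    """Invert minmax_normalize using the stored statistics."""
    return scaled * ch_range + ch_min

# Hypothetical raw data with an arbitrary offset and scale.
raw = torch.randn(64, 32, 128) * 5.0 + 3.0
training_seq, ch_min, ch_range = minmax_normalize(raw)
restored = minmax_denormalize(training_seq, ch_min, ch_range)
```

Keeping ch_min and ch_range around lets you apply the inverse transform to sampled sequences later.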
2. Define the Model and Diffusion Process
Instantiate Unet1D and wrap it with GaussianDiffusion1D.
from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D
model = Unet1D(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels = 32  # Must match the number of channels in your data
)

diffusion = GaussianDiffusion1D(
    model,
    seq_length = 128,  # Must match the length of your sequences
    timesteps = 1000,
    objective = 'pred_v'
)
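The comments above note that channels and seq_length must match the data. A quick check before building the model can catch mismatches early. The sketch below also tests one further constraint that is common for U-Net style models: it assumes the sequence length is halved once per entry of dim_mults except the last, so the length should be divisible by 2 ** (len(dim_mults) - 1). That assumption about the architecture is not confirmed by this page, and check_config is a hypothetical helper, not a library function.

```python
def check_config(data_shape, channels, seq_length, dim_mults):
    """Return a list of human-readable problems (empty if none are found)."""
    if len(data_shape) != 3:
        return [f"expected (batch, channels, length), got {tuple(data_shape)}"]

    problems = []
    _, data_channels, data_length = data_shape
    if data_channels != channels:
        problems.append(f"channels={channels} but data has {data_channels}")
    if data_length != seq_length:
        problems.append(f"seq_length={seq_length} but data has {data_length}")

    # ASSUMPTION: one halving per resolution level except the last.
    factor = 2 ** (len(dim_mults) - 1)
    if seq_length % factor != 0:
        problems.append(f"seq_length={seq_length} is not divisible by {factor}")
    return problems

# Matches the configuration used in this example.
ok = check_config((64, 32, 128), channels=32, seq_length=128, dim_mults=(1, 2, 4, 8))

# Deliberately inconsistent: wrong channel count and an awkward length.
bad = check_config((64, 32, 100), channels=16, seq_length=100, dim_mults=(1, 2, 4, 8))
```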
3. Training
You can train with a manual loop:
# Training with a manual loop
loss = diffusion(training_seq)
loss.backward()
# ... optimizer step ...
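The optimizer step left open above can be written in many ways. One possible sketch uses Adam with the learning rate from the Trainer1D example. So that the snippet runs without the library installed, StandInDiffusion is a hypothetical stand-in whose forward pass returns a scalar loss, as diffusion(training_seq) does; it is not the real diffusion objective. In practice you would pass your own diffusion object to train_steps, which is also a hypothetical helper.

```python
import torch
from torch import nn

class StandInDiffusion(nn.Module):
    """Hypothetical stand-in: maps a batch of sequences to a scalar loss."""

    def __init__(self, channels: int):
        super().__init__()
        self.net = nn.Conv1d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        noise = torch.randn_like(x)
        # Toy denoising loss; a placeholder for the library's objective.
        return nn.functional.mse_loss(self.net(x + noise), noise)

def train_steps(diffusion: nn.Module, data: torch.Tensor,
                num_steps: int = 5, batch_size: int = 16, lr: float = 8e-5):
    """Run a few optimizer steps and return the loss at each step."""
    optimizer = torch.optim.Adam(diffusion.parameters(), lr=lr)
    losses = []
    for _ in range(num_steps):
        # Draw a random mini-batch of sequences.
        idx = torch.randint(0, data.shape[0], (batch_size,))
        loss = diffusion(data[idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

torch.manual_seed(0)
training_seq = torch.rand(64, 32, 128)
losses = train_steps(StandInDiffusion(channels=32), training_seq)
```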
Or, more conveniently, use the Trainer1D class.
from denoising_diffusion_pytorch import Trainer1D, Dataset1D
# The library provides a simple Dataset wrapper
dataset = Dataset1D(training_seq)
trainer = Trainer1D(
    diffusion,
    dataset = dataset,
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,       # total training steps
    gradient_accumulate_every = 2,  # gradient accumulation steps
    ema_decay = 0.995,              # exponential moving average decay
    amp = True                      # turn on mixed precision
)
trainer.train()
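Two of the arguments above are easier to reason about with numbers. Assuming gradients are accumulated over gradient_accumulate_every batches before each optimizer step, the effective batch size here is 32 × 2 = 64. Assuming the textbook EMA update, ema = decay * ema + (1 - decay) * weight, a decay of 0.995 averages over roughly 1 / (1 - 0.995) = 200 steps. The trainer's actual schedule (for example, any warm-up) may differ, and the function names below are hypothetical.

```python
def effective_batch_size(train_batch_size: int, gradient_accumulate_every: int) -> int:
    """Samples contributing to one optimizer step under gradient accumulation."""
    return train_batch_size * gradient_accumulate_every

def ema_update(ema_value: float, new_value: float, decay: float) -> float:
    """One step of a textbook exponential moving average."""
    return decay * ema_value + (1.0 - decay) * new_value

def ema_horizon(decay: float) -> float:
    """Approximate number of steps the average looks back over."""
    return 1.0 / (1.0 - decay)

batch = effective_batch_size(32, 2)
horizon = ema_horizon(0.995)

# Track a weight that jumps from 0.0 to 1.0 and then stays there.
ema = 0.0
for _ in range(200):
    ema = ema_update(ema, 1.0, decay=0.995)
# After one horizon (200 steps) the average has covered about 63% of the jump.
```

The slow response is the point: the averaged weights smooth over step-to-step noise in training.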
4. Sampling
After training, you can generate new sequences.
# after a lot of training
sampled_seq = diffusion.sample(batch_size = 4)
sampled_seq.shape # (4, 32, 128)
Note: Unlike the 2D Trainer, Trainer1D does not perform any evaluation (like FID) on the generated samples, as the nature of the 1D data is unknown. You can customize the training loop by performing an editable install (pip install -e .) and modifying the Trainer1D class to include metrics relevant to your specific data type.
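What counts as a relevant metric depends on your data. As one illustration, the sketch below compares per-channel mean and standard deviation between real and generated sequences. channel_stat_gap is a hypothetical helper, not part of the library. Matching these statistics is only a coarse sanity check, not a substitute for a domain-specific metric.

```python
import torch

def channel_stat_gap(real: torch.Tensor, generated: torch.Tensor) -> dict:
    """Compare per-channel statistics of two (batch, channels, length) tensors."""
    if real.shape[1] != generated.shape[1]:
        raise ValueError("real and generated must have the same channel count")

    # Reduce over batch and length, leaving one value per channel.
    real_mean, gen_mean = real.mean(dim=(0, 2)), generated.mean(dim=(0, 2))
    real_std, gen_std = real.std(dim=(0, 2)), generated.std(dim=(0, 2))

    return {
        "mean_gap": (real_mean - gen_mean).abs().mean().item(),
        "std_gap": (real_std - gen_std).abs().mean().item(),
    }

torch.manual_seed(0)
real = torch.rand(64, 32, 128)

# Stand-ins for generated samples: one from the same distribution, one shifted.
same_dist = torch.rand(4, 32, 128)
shifted = torch.rand(4, 32, 128) + 0.5

close = channel_stat_gap(real, same_dist)
far = channel_stat_gap(real, shifted)
```

In a customized trainer, a function like this could be called on the output of diffusion.sample(...) at regular intervals and logged alongside the loss.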