Simple Diffusion (U-ViT)
This module implements the model from the paper "simple diffusion: End-to-end diffusion for high resolution images" by Hoogeboom et al. This paper introduces a novel architecture that replaces the bottleneck of the standard U-Net with a Vision Transformer (ViT). This hybrid U-ViT design combines the multi-scale feature extraction of a U-Net with the global modeling capabilities of a Transformer.
The implementation can be found in `denoising_diffusion_pytorch/simple_diffusion.py`.
Core Concept: The U-ViT
The U-ViT architecture maintains the overall structure of a U-Net with a downsampling path, an upsampling path, and skip connections. The key innovation is at the bottleneck:
- Downsampling: The input image is processed through several stages of ResNet blocks and downsampling layers, just like a standard U-Net.
- Transformer Bottleneck: At the lowest resolution, the feature map is flattened into a sequence of tokens. This sequence is then processed by a standard Transformer encoder, which consists of multiple layers of self-attention and feed-forward networks. This allows the model to capture long-range dependencies across the entire feature map (a minimal sketch of this step follows the list).
- Upsampling: The output from the Transformer is reshaped back into a 2D feature map and then processed by the upsampling path, which uses skip connections from the encoder to reconstruct the high-resolution output.
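The bottleneck step can be summarized in a few lines. Below is an illustrative toy, not the library's implementation (`ToyViTBottleneck` is a hypothetical name): flatten the low-resolution features into tokens, apply global self-attention, and reshape back to 2D.

```python
import torch
from einops import rearrange
from torch import nn

class ToyViTBottleneck(nn.Module):  # hypothetical helper, for illustration only
    def __init__(self, dim, depth, heads):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model = dim, nhead = heads, batch_first = True)
        self.transformer = nn.TransformerEncoder(layer, num_layers = depth)

    def forward(self, x):                                 # x: (batch, dim, h, w)
        h, w = x.shape[-2:]
        tokens = rearrange(x, 'b c h w -> b (h w) c')     # flatten map into a token sequence
        tokens = self.transformer(tokens)                 # global self-attention over all tokens
        return rearrange(tokens, 'b (h w) c -> b c h w', h = h, w = w)

feats = torch.randn(4, 512, 8, 8)                         # lowest-resolution U-Net features
out = ToyViTBottleneck(dim = 512, depth = 6, heads = 8)(feats)
assert out.shape == feats.shape
```

Because attention runs only at the lowest resolution, its quadratic cost in sequence length stays manageable even for high-resolution inputs.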
Implementation
The `simple_diffusion.py` file contains the `UViT` class and a `GaussianDiffusion` wrapper adapted for the paper's continuous-time formulation.
UViT Class
```python
from denoising_diffusion_pytorch.simple_diffusion import UViT

model = UViT(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    vit_depth = 6,
    vit_dropout = 0.2,
    attn_dim_head = 32,
    attn_heads = 4
)
```
- The U-Net backbone is configured by `dim` and `dim_mults`.
- The Transformer bottleneck is configured by `vit_depth`, `attn_heads`, `attn_dim_head`, and `vit_dropout` (a small worked example of the resulting stage widths follows).
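As a concrete illustration of how these arguments interact, the per-stage channel widths are `dim` scaled by each multiplier, and the ViT bottleneck operates at the deepest width:

```python
# With the configuration above, the U-Net stage widths are dim * mult;
# the Transformer bottleneck then runs on tokens of the deepest width.
dim = 64
dim_mults = (1, 2, 4, 8)
stage_widths = [dim * m for m in dim_mults]
print(stage_widths)   # [64, 128, 256, 512] -> bottleneck tokens have dimension 512
```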
GaussianDiffusion Wrapper
The paper uses a continuous-time diffusion formulation with a specific log-SNR noise schedule. The `GaussianDiffusion` class in this file is designed for this setup.
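For intuition, the cosine log-SNR schedule maps a continuous time $t \in [0, 1]$ to a log signal-to-noise ratio via $\lambda(t) = -2 \log \tan(\pi t / 2)$, clamped to a finite range for numerical stability. The sketch below follows that form; the library's own schedule function may use different names or defaults.

```python
import math
import torch

def logsnr_cosine(t, logsnr_min = -15.0, logsnr_max = 15.0):
    # map t in [0, 1] into a bounded tan argument so the resulting
    # log-SNR stays within [logsnr_min, logsnr_max]
    t_min = math.atan(math.exp(-0.5 * logsnr_max))
    t_max = math.atan(math.exp(-0.5 * logsnr_min))
    return -2.0 * torch.log(torch.tan(t_min + t * (t_max - t_min)))

t = torch.linspace(0.0, 1.0, 5)
print(logsnr_cosine(t))   # high log-SNR (little noise) at t = 0, low at t = 1
```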
```python
from denoising_diffusion_pytorch.simple_diffusion import GaussianDiffusion

diffusion = GaussianDiffusion(
    model,                       # an instance of UViT
    image_size = 128,
    pred_objective = 'v',        # 'v' prediction is recommended
    num_sample_steps = 250
)
```
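The `'v'` objective refers to the velocity parameterization of Salimans & Ho (2022). Under a variance-preserving schedule, $\alpha^2 = \sigma(\lambda)$ and $\sigma^2 = \sigma(-\lambda)$ for log-SNR $\lambda$, and the network regresses $v = \alpha\,\epsilon - \sigma\,x_0$. The following is a sketch of that target, not the library's code:

```python
import torch

def v_target(x_start, noise, logsnr):
    # variance-preserving coefficients: alpha**2 + sigma**2 = 1
    alpha = torch.sigmoid(logsnr).sqrt()
    sigma = torch.sigmoid(-logsnr).sqrt()
    return alpha * noise - sigma * x_start   # the network is trained to predict this

x0  = torch.randn(4, 3, 128, 128)
eps = torch.randn_like(x0)
lam = torch.zeros(4, 1, 1, 1)                # log-SNR = 0, i.e. alpha = sigma
v   = v_target(x0, eps, lam)
```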
Usage Example
```python
import torch
from denoising_diffusion_pytorch.simple_diffusion import UViT, GaussianDiffusion

# 1. Define the U-ViT model
model = UViT(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    vit_depth = 6
)

# 2. Wrap with the continuous-time GaussianDiffusion
diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    channels = 3,
    num_sample_steps = 250
)

# --- Training ---
# The forward pass expects images with pixel values in [0, 1] and returns the loss
training_images = torch.rand(4, 3, 128, 128)
loss = diffusion(training_images)
loss.backward()

# --- Sampling ---
# After training, generate samples
sampled_images = diffusion.sample(batch_size = 4)
sampled_images.shape # (4, 3, 128, 128)
```
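For context, a minimal training loop around these objects might look like the following. The optimizer choice and hyperparameters are illustrative, not prescribed by the paper:

```python
from torch.optim import Adam

opt = Adam(diffusion.parameters(), lr = 3e-4)   # illustrative optimizer settings

for step in range(10):                          # replace with a real dataloader loop
    images = torch.rand(4, 3, 128, 128)         # stand-in batch with values in [0, 1]
    loss = diffusion(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
```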