Simple Diffusion (U-ViT)

This module implements the model from the paper "simple diffusion: End-to-end diffusion for high resolution images" by Hoogeboom et al. This paper introduces a novel architecture that replaces the bottleneck of the standard U-Net with a Vision Transformer (ViT). This hybrid U-ViT design combines the multi-scale feature extraction of a U-Net with the global modeling capabilities of a Transformer.

The implementation can be found in denoising_diffusion_pytorch/simple_diffusion.py.

Core Concept: The U-ViT

The U-ViT architecture maintains the overall structure of a U-Net with a downsampling path, an upsampling path, and skip connections. The key innovation is at the bottleneck:

  1. Downsampling: The input image is processed through several stages of ResNet blocks and downsampling layers, just like a standard U-Net.
  2. Transformer Bottleneck: At the lowest resolution, the feature map is flattened into a sequence of tokens. This sequence is then processed by a Transformer encoder, consisting of multiple layers of self-attention and feed-forward networks, which lets the model capture long-range dependencies across the entire feature map (see the sketch after this list).
  3. Upsampling: The output from the Transformer is reshaped back into a 2D feature map and then processed by the upsampling path, which uses skip connections from the encoder to reconstruct the high-resolution output.
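
The flatten, transform, and reshape step at the bottleneck can be illustrated with a minimal, self-contained sketch. This is not the code in simple_diffusion.py (which uses its own attention and feed-forward blocks); it only shows how the 2D feature map becomes a token sequence and back:

import torch
import torch.nn as nn

class TransformerBottleneck(nn.Module):
    def __init__(self, dim, depth = 6, heads = 4, dropout = 0.):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model = dim,
            nhead = heads,
            dim_feedforward = dim * 4,
            dropout = dropout,
            batch_first = True,
            norm_first = True
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers = depth)

    def forward(self, x):
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)   # (b, h*w, c) token sequence
        tokens = self.transformer(tokens)       # global self-attention over all positions
        return tokens.transpose(1, 2).reshape(b, c, h, w)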

Implementation

The simple_diffusion.py file contains the UViT class and a GaussianDiffusion wrapper adapted for its continuous-time formulation.

UViT Class

from denoising_diffusion_pytorch.simple_diffusion import UViT

model = UViT(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    vit_depth = 6,
    vit_dropout = 0.2,
    attn_dim_head = 32,
    attn_heads = 4
)
  • The U-Net part is defined by dim and dim_mults, which set the channel width at each resolution stage (see the sketch below).
  • The Transformer bottleneck is defined by vit_depth, vit_dropout, attn_dim_head, and attn_heads.
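
For example, assuming the usual convention in this repository that each resolution stage has width dim * mult, the settings above give the following stage widths, with the ViT operating on the deepest stage:

dim = 64
dim_mults = (1, 2, 4, 8)

# per-stage channel widths, assuming the dim * mult convention
dims = [dim * m for m in dim_mults]
print(dims)  # [64, 128, 256, 512] -> the transformer bottleneck sees 512-channel features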

GaussianDiffusion Wrapper

The paper uses a continuous-time diffusion formulation in which the noise level is parameterized by log-SNR (by default, a cosine schedule). The GaussianDiffusion class in this file is designed for this setup.
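
As a point of reference, a cosine schedule written in log-SNR terms looks roughly as follows. This is a hedged sketch: the function name cosine_logsnr and the clamp values here are illustrative; see the schedule functions in simple_diffusion.py for the actual implementation.

import math
import torch

def cosine_logsnr(t, logsnr_min = -15., logsnr_max = 15.):
    # cosine schedule in log-SNR form: logsnr(t) ~ -2 * log(tan(pi * t / 2)),
    # with the endpoints remapped so the log-SNR stays inside [logsnr_min, logsnr_max]
    t_min = math.atan(math.exp(-0.5 * logsnr_max))
    t_max = math.atan(math.exp(-0.5 * logsnr_min))
    return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))

t = torch.linspace(0., 1., steps = 5)
print(cosine_logsnr(t))  # high log-SNR (nearly clean) at t = 0, low log-SNR (nearly pure noise) at t = 1

The wrapper itself is constructed as follows: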

from denoising_diffusion_pytorch.simple_diffusion import GaussianDiffusion

diffusion = GaussianDiffusion(
    model, # an instance of UViT
    image_size = 128,
    pred_objective = 'v',       # 'v' prediction is recommended
    num_sample_steps = 250
)
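
For context, pred_objective = 'v' trains the network to regress the "velocity" target of Salimans & Ho (2022) instead of the raw noise. A minimal sketch of that target, assuming the variance-preserving parameterization where alpha^2 = sigmoid(logsnr) and sigma^2 = sigmoid(-logsnr):

import torch

def v_target(x_start, noise, logsnr):
    # v-prediction target: v = alpha * noise - sigma * x_start
    # (logsnr must broadcast against the image batch, e.g. shape (b, 1, 1, 1))
    alpha = torch.sigmoid(logsnr).sqrt()
    sigma = torch.sigmoid(-logsnr).sqrt()
    return alpha * noise - sigma * x_start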

Usage Example

import torch
from denoising_diffusion_pytorch.simple_diffusion import UViT, GaussianDiffusion

# 1. Define the U-ViT model
model = UViT(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    vit_depth = 6
)

# 2. Wrap with the continuous-time GaussianDiffusion
diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    channels = 3,
    num_sample_steps = 250
)

# --- Training ---
# The forward pass expects images scaled to [0, 1] and returns the diffusion loss
training_images = torch.rand(4, 3, 128, 128)  # dummy batch in [0, 1]
loss = diffusion(training_images)
loss.backward()

# --- Sampling ---
# After training, generate samples
sampled_images = diffusion.sample(batch_size = 4)
sampled_images.shape # (4, 3, 128, 128)
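
A minimal training loop around these calls might look as follows; the optimizer choice, learning rate, and dummy data here are illustrative, not prescribed by the library:

from torch.optim import Adam

optimizer = Adam(diffusion.parameters(), lr = 1e-4)

for step in range(1000):                     # illustrative number of steps
    images = torch.rand(4, 3, 128, 128)      # replace with batches from a real dataloader
    loss = diffusion(images)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()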