Karras U-Net (Magnitude Preserving)

This module implements the U-Net architecture from the paper "Analyzing and Improving the Training Dynamics of Diffusion Models" by Karras et al. (2023). The architecture is designed to be "magnitude preserving": the expected magnitude of activations stays roughly constant throughout the network. This property stabilizes training and improves the final performance of the diffusion model.

Implementations are available for 2D images (karras_unet.py), 1D sequences (karras_unet_1d.py), and 3D volumes (karras_unet_3d.py).

Core Concepts

The Karras U-Net departs from traditional designs by removing all sources of implicit magnitude changes:

  • No Biases: All Conv and Linear layers are bias-free.
  • No Normalization Layers: Layers like LayerNorm or GroupNorm are completely removed.
  • Forced Weight Normalization: The weights of all Conv and Linear layers are explicitly re-normalized at every forward pass during training.
  • Magnitude-Preserving Operations: Standard operations are replaced with magnitude-preserving equivalents:
    • MPAdd: Replaces residual additions (x + res).
    • MPCat: Replaces channel-wise concatenation for skip connections.
    • MPSiLU: A scaled SiLU activation function.
    • PixelNorm: Normalizes features across the channel dimension.
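These operations are small enough to sketch directly from the paper's definitions. The sketch below follows the formulas in Karras et al. (2023); the exact repository code may differ in details (the `t` mixing parameter, `eps` handling, and the 0.596 constant are taken from the paper, where 0.596 is the RMS of `silu(x)` for `x ~ N(0, 1)`):

```python
import torch
import torch.nn.functional as F

def mp_silu(x):
    # SiLU rescaled so a unit-magnitude input yields a unit-magnitude output
    # (0.596 is approximately the RMS of silu(x) for x ~ N(0, 1))
    return F.silu(x) / 0.596

def pixel_norm(x, dim = 1, eps = 1e-4):
    # normalize each feature vector across the channel dimension
    # to norm sqrt(C), so per-entry magnitude is ~1
    return x * (x.shape[dim] ** 0.5) / (x.norm(dim = dim, keepdim = True) + eps)

def mp_sum(a, b, t = 0.3):
    # magnitude-preserving replacement for `a + b`: interpolate toward b,
    # then divide by the std of the mix under independent unit-variance inputs
    return a.lerp(b, t) / ((1 - t) ** 2 + t ** 2) ** 0.5

def mp_cat(a, b, t = 0.5, dim = 1):
    # magnitude-preserving concatenation for skip connections:
    # each branch is scaled so the concatenated vector keeps expected
    # squared norm Na + Nb
    Na, Nb = a.shape[dim], b.shape[dim]
    C = ((Na + Nb) / ((1 - t) ** 2 + t ** 2)) ** 0.5
    a = a * ((1 - t) / Na ** 0.5)
    b = b * (t / Nb ** 0.5)
    return torch.cat((a, b), dim = dim) * C
```

With `t = 0.5` and two independent unit-variance inputs, both `mp_sum` and `mp_cat` return unit-variance outputs, which is the invariant the architecture maintains end to end.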

Implementation

The primary class is KarrasUnet. It is built from Encoder and Decoder blocks that use the special magnitude-preserving modules.

KarrasUnet

This is the main 2D U-Net. Its structure is defined by num_downsamples and num_blocks_per_stage.

import torch
from denoising_diffusion_pytorch.karras_unet import KarrasUnet

unet = KarrasUnet(
    image_size = 64,
    dim = 192,
    dim_max = 768,
    num_classes = 1000, # for class-conditional models
    channels = 4,
    num_downsamples = 3,
    num_blocks_per_stage = 4,
    attn_res = (16, 8) # resolutions at which to use attention
)

images = torch.randn(2, 4, 64, 64)

denoised_images = unet(
    images,
    time = torch.ones(2,),
    class_labels = torch.randint(0, 1000, (2,))
)

KarrasUnet1D

A 1D version of the Karras U-Net, suitable for sequential data.

import torch
from denoising_diffusion_pytorch.karras_unet_1d import KarrasUnet1D

unet_1d = KarrasUnet1D(
    seq_len = 256,
    dim = 192,
    channels = 2
)

sequences = torch.randn(2, 2, 256)
denoised_sequences = unet_1d(sequences, time = torch.ones(2,))

KarrasUnet3D

A 3D version for volumetric data like videos or medical scans.

import torch
from denoising_diffusion_pytorch.karras_unet_3d import KarrasUnet3D

unet_3d = KarrasUnet3D(
    frames = 32,
    image_size = 64,
    dim = 8,
    num_downsamples = 3,
    factorize_space_time_attn = True # Use separate attention for space and time
)

video = torch.randn(2, 4, 32, 64, 64)
denoised_video = unet_3d(video, time = torch.ones(2,))

The paper also prescribes an inverse-square-root learning rate decay for training this architecture, available here as InvSqrtDecayLRSched.

from denoising_diffusion_pytorch.karras_unet import InvSqrtDecayLRSched
from torch.optim import Adam

optimizer = Adam(unet.parameters())
scheduler = InvSqrtDecayLRSched(optimizer)

# In your training loop:
# optimizer.step()
# scheduler.step()
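The decay law itself is simple: the learning rate is held constant for a reference number of steps, then decays as the inverse square root of training time. A minimal sketch of the multiplier, assuming a hypothetical reference step count `t_ref` (the actual default in the repository may differ):

```python
def inv_sqrt_decay_factor(step, t_ref = 70000):
    # multiplier applied to the base learning rate:
    # constant at 1.0 until t_ref steps, then lr decays as 1 / sqrt(step / t_ref)
    return 1.0 / max(step / t_ref, 1.0) ** 0.5
```

For example, the multiplier stays at 1.0 through step `t_ref` and drops to 0.5 by step `4 * t_ref`, giving a long stable warm phase followed by a gentle decay.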