Karras U-Net (Magnitude Preserving)
This module implements the U-Net architecture from the paper "Analyzing and Improving the Training Dynamics of Diffusion Models" by Karras et al. (2023). The architecture is designed to be "magnitude preserving": the expected magnitude of activations stays approximately constant throughout the network, which stabilizes training and improves the final performance of the diffusion model.
Implementations are available for 2D images (karras_unet.py), 1D sequences (karras_unet_1d.py), and 3D volumes (karras_unet_3d.py).
Core Concepts
The Karras U-Net departs from traditional designs by removing all sources of implicit magnitude changes:
- No Biases: All Conv and Linear layers are bias-free.
- No Normalization Layers: Layers like LayerNorm or GroupNorm are completely removed.
- Forced Weight Normalization: The weights of all Conv and Linear layers are explicitly re-normalized at every forward pass during training.
- Magnitude-Preserving Operations: Standard operations are replaced with magnitude-preserving equivalents (a sketch of these follows the list):
  - MPAdd: Replaces residual additions (x + res).
  - MPCat: Replaces channel-wise concatenation for skip connections.
  - MPSiLU: A scaled SiLU activation function.
  - PixelNorm: Normalizes features across the channel dimension.
Implementation
The primary class is KarrasUnet. It is built from Encoder and Decoder blocks that use the magnitude-preserving modules described above.
KarrasUnet
This is the main 2D U-Net. Its structure is defined by num_downsamples and num_blocks_per_stage.
import torch
from denoising_diffusion_pytorch.karras_unet import KarrasUnet
unet = KarrasUnet(
image_size = 64,
dim = 192,
dim_max = 768,
num_classes = 1000, # for class-conditional models
channels = 4,
num_downsamples = 3,
num_blocks_per_stage = 4,
attn_res = (16, 8) # resolutions at which to use attention
)
images = torch.randn(2, 4, 64, 64)
denoised_images = unet(
images,
time = torch.ones(2,),
class_labels = torch.randint(0, 1000, (2,))
)
KarrasUnet1D
A 1D version of the Karras U-Net, suitable for sequential data.
import torch
from denoising_diffusion_pytorch.karras_unet_1d import KarrasUnet1D
unet_1d = KarrasUnet1D(
seq_len = 256,
dim = 192,
channels = 2
)
sequences = torch.randn(2, 2, 256)
denoised_sequences = unet_1d(sequences, time = torch.ones(2,))
KarrasUnet3D
A 3D version for volumetric data like videos or medical scans.
import torch
from denoising_diffusion_pytorch.karras_unet_3d import KarrasUnet3D
unet_3d = KarrasUnet3D(
frames = 32,
image_size = 64,
dim = 8,
num_downsamples = 3,
factorize_space_time_attn = True # Use separate attention for space and time
)
video = torch.randn(2, 4, 32, 64, 64)
denoised_video = unet_3d(video, time = torch.ones(2,))
Recommended Learning Rate Schedule
The paper recommends training this architecture with an inverse square root learning rate decay, provided here as InvSqrtDecayLRSched.
from denoising_diffusion_pytorch.karras_unet import InvSqrtDecayLRSched
from torch.optim import Adam
optimizer = Adam(unet.parameters())
scheduler = InvSqrtDecayLRSched(optimizer)
# In your training loop:
# optimizer.step()
# scheduler.step()
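For context, here is a minimal end-to-end sketch of how the optimizer and scheduler fit into a training loop. The random tensors and the plain MSE denoising objective are placeholders for illustration only; in practice the loss would come from your actual diffusion training objective.

import torch
import torch.nn.functional as F
from torch.optim import Adam
from denoising_diffusion_pytorch.karras_unet import KarrasUnet, InvSqrtDecayLRSched

unet = KarrasUnet(image_size = 64, dim = 192, channels = 4)

optimizer = Adam(unet.parameters())
scheduler = InvSqrtDecayLRSched(optimizer)

for step in range(10):
    # hypothetical batch: noisy inputs, clean targets, and per-sample times
    noisy = torch.randn(2, 4, 64, 64)
    target = torch.randn(2, 4, 64, 64)
    time = torch.rand(2,)

    pred = unet(noisy, time = time)
    loss = F.mse_loss(pred, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the learning rate schedule once per optimizer step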