U-Net Architecture

The primary neural network architecture used in this library is the U-Net, a model originally designed for biomedical image segmentation that has proven highly effective as the denoising network in diffusion models. Its downsampling path captures coarse, global context at low resolution, while skip connections carry fine spatial detail directly to the matching stage of the upsampling path, so the network combines information from several resolutions at once.

The Core Unet Class

The main U-Net implementation is in denoising_diffusion_pytorch.denoising_diffusion_pytorch.Unet.

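import torch
from denoising_diffusion_pytorch import Unet

model = Unet(
    dim = 64,                  # base channel width
    init_dim = None,           # channels after init_conv (None -> dim)
    out_dim = None,            # output channels (None -> channels)
    dim_mults = (1, 2, 4, 8),  # width multiplier for each resolution stage
    channels = 3,              # RGB input
    self_condition = False,
    flash_attn = True          # requires PyTorch 2.0+
)

images = torch.randn(8, 3, 128, 128)   # batch of noisy images
times = torch.randint(0, 1000, (8,))   # one diffusion timestep per image
pred = model(images, times)            # same shape as images: (8, 3, 128, 128)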

Key Parameters

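  • dim (int): The base channel dimension from which every stage width is derived (see dim_mults).
  • init_dim (int, optional): The number of channels after the initial convolution. If not specified, it defaults to dim.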
  • dim_mults (tuple of ints): Channel multipliers, one per resolution stage. Stage i uses dim * dim_mults[i] channels, so the length of the tuple sets the number of resolution levels. For dim=64 and dim_mults=(1, 2, 4, 8), the channel dimensions are 64, 128, 256 and 512 in the downsampling path.
  • channels (int): The number of input channels in the image (e.g., 3 for RGB).
  • self_condition (bool): If True, the model can condition its prediction on its own previous prediction, which is concatenated to the input along the channel dimension. This can improve sample quality but slows training by about 25%.
  • flash_attn (bool): If True, uses FlashAttention, a highly efficient attention implementation. This requires PyTorch 2.0+ and a compatible GPU.
  • out_dim (int, optional): The number of output channels. If not specified, it defaults to channels.
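The channel and resolution bookkeeping implied by dim, init_dim and dim_mults can be sketched in plain Python. This is a minimal sketch, not the library's code: stage_widths and resolutions are hypothetical helpers, and resolutions assumes that every stage except the deepest one halves the spatial size.

```python
def stage_widths(dim, dim_mults, init_dim=None):
    """Return the (in_channels, out_channels) pair of each downsampling stage.

    Hypothetical helper: assumes init_dim falls back to dim and that the
    stage widths are dim * m for each multiplier m.
    """
    init_dim = dim if init_dim is None else init_dim
    dims = [init_dim] + [dim * m for m in dim_mults]
    return list(zip(dims[:-1], dims[1:]))


def resolutions(image_size, num_stages):
    """Spatial size seen by each stage, assuming every stage except the
    deepest one halves height and width."""
    sizes = []
    size = image_size
    for i in range(num_stages):
        sizes.append(size)
        if i < num_stages - 1:
            if size % 2 != 0:
                raise ValueError(f"size {size} is not divisible by 2 at stage {i}")
            size //= 2
    return sizes


print(stage_widths(64, (1, 2, 4, 8)))  # [(64, 64), (64, 128), (128, 256), (256, 512)]
print(resolutions(64, 4))              # [64, 32, 16, 8]
```

Under that assumption, four multipliers mean three halvings, so image height and width should be divisible by 2**3 = 8.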

Architectural Components

1. Time Embeddings

The diffusion timestep t tells the model how much noise is present in its input. It is converted into a vector embedding using sinusoidal position embeddings, the same scheme Transformers use to encode token positions. This embedding is then processed by a small MLP (time_mlp) and injected into every ResNet block, so each block can adapt its behavior to the current noise level.
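A minimal PyTorch sketch of a sinusoidal timestep embedding followed by a small MLP. The class name SinusoidalPosEmb, the frequency spacing, the GELU activation and the 4x hidden width (time_dim = dim * 4) are assumptions of this sketch rather than details taken from the library.

```python
import math

import torch
from torch import nn


class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of a batch of scalar timesteps."""

    def __init__(self, dim, theta=10000):
        super().__init__()
        assert dim % 2 == 0 and dim >= 4, "dim must be even (half sin, half cos) and >= 4"
        self.dim = dim
        self.theta = theta

    def forward(self, t):
        half = self.dim // 2
        # Geometric progression of frequencies from 1 down to 1/theta.
        freqs = torch.exp(-math.log(self.theta) * torch.arange(half, device=t.device) / (half - 1))
        args = t.float()[:, None] * freqs[None, :]          # (batch, half)
        return torch.cat((args.sin(), args.cos()), dim=-1)  # (batch, dim)


dim = 64
time_dim = dim * 4  # assumed hidden width of the time MLP
time_mlp = nn.Sequential(
    SinusoidalPosEmb(dim),
    nn.Linear(dim, time_dim),
    nn.GELU(),
    nn.Linear(time_dim, time_dim),
)

t = torch.tensor([0, 10, 500, 999])
emb = time_mlp(t)
print(emb.shape)  # torch.Size([4, 256])
```

The resulting time_dim-sized vector is what each ResNet block would receive alongside its feature map.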

2. Downsampling Path

The U-Net starts with an initial convolution (init_conv), followed by one downsampling stage per entry in dim_mults. Each stage typically consists of:

  • Two ResnetBlock modules.
  • An Attention or LinearAttention module; by default, the cheaper LinearAttention is used at the higher-resolution stages.
  • A Downsample module that halves the spatial resolution and increases the channel dimension. The deepest stage keeps its resolution and only changes the channel count.

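A skip connection is saved after each ResNet block (for the second block, after the attention module has been applied), so each stage contributes two skips to the upsampling path.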
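One downsampling stage can be sketched in PyTorch as follows. This is a simplified stand-in, not the library's ResnetBlock or Downsample: the time embedding is injected as a per-channel scale and shift (one common choice, assumed here), Downsample uses nn.PixelUnshuffle followed by a 1x1 convolution, and the attention module is omitted to keep the example short.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ResnetBlock(nn.Module):
    """Simplified residual block (not the library's): two 3x3 convs, with the
    time embedding turned into a per-channel scale and shift (assumed design)."""

    def __init__(self, dim_in, dim_out, time_dim, groups=8):
        super().__init__()
        self.time_proj = nn.Linear(time_dim, dim_out * 2)
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, padding=1)
        self.norm1 = nn.GroupNorm(groups, dim_out)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding=1)
        self.norm2 = nn.GroupNorm(groups, dim_out)
        # 1x1 conv on the residual path only when the channel count changes.
        self.res_conv = nn.Conv2d(dim_in, dim_out, 1) if dim_in != dim_out else nn.Identity()

    def forward(self, x, t_emb):
        scale, shift = self.time_proj(F.silu(t_emb)).chunk(2, dim=-1)
        h = self.norm1(self.conv1(x))
        h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]  # inject time
        h = F.silu(h)
        h = F.silu(self.norm2(self.conv2(h)))
        return h + self.res_conv(x)


class Downsample(nn.Module):
    """Halve H and W, then change the channel count (hypothetical sketch)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.unshuffle = nn.PixelUnshuffle(2)          # (C, H, W) -> (4C, H/2, W/2)
        self.proj = nn.Conv2d(dim_in * 4, dim_out, 1)

    def forward(self, x):
        return self.proj(self.unshuffle(x))


# One stage with dim_in=64, dim_out=128 on a 32x32 feature map.
time_dim = 256
block1 = ResnetBlock(64, 64, time_dim)
block2 = ResnetBlock(64, 64, time_dim)
down = Downsample(64, 128)

x = torch.randn(2, 64, 32, 32)
t_emb = torch.randn(2, time_dim)
skips = []
x = block1(x, t_emb)
skips.append(x)
x = block2(x, t_emb)  # an attention module would follow here
skips.append(x)
x = down(x)
print(x.shape)        # torch.Size([2, 128, 16, 16])
```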

3. Middle Block

At the bottleneck of the U-Net, there is a middle block consisting of:

  • A ResnetBlock.
  • A full Attention module, which is affordable here because the feature map has the fewest spatial positions.
  • Another ResnetBlock.
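Full attention treats every spatial position as a token. The sketch below uses PyTorch's nn.MultiheadAttention with a residual connection; FullSpatialAttention and attn_matrix_entries are hypothetical names, and this is not the library's attention implementation. The helper shows the cost that makes full attention attractive only at the bottleneck: one head's attention matrix has (H*W)**2 entries.

```python
import torch
from torch import nn


class FullSpatialAttention(nn.Module):
    """Self-attention over all H*W positions of a feature map, with a residual
    connection. Built on nn.MultiheadAttention; hypothetical, not the library's module."""

    def __init__(self, dim, heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(1, dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        tokens = self.norm(x).flatten(2).transpose(1, 2)         # (b, h*w, c)
        out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        out = out.transpose(1, 2).reshape(b, c, h, w)           # back to (b, c, h, w)
        return x + out


def attn_matrix_entries(h, w):
    """Entries in one head's attention matrix: every position attends to every position."""
    return (h * w) ** 2


mid_attn = FullSpatialAttention(512)
y = mid_attn(torch.randn(2, 512, 8, 8))
print(y.shape)                      # torch.Size([2, 512, 8, 8])
print(attn_matrix_entries(8, 8))    # 4096
print(attn_matrix_entries(64, 64))  # 16777216
```

At 8x8 that is 4,096 entries per head; at 64x64 it is about 16.8 million, which is the usual motivation for linear attention at higher resolutions.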

4. Upsampling Path

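The upsampling path mirrors the downsampling path and consumes the stored skip connections in reverse order. Each stage consists of:

  • Two ResnetBlock modules, each preceded by concatenation with one skip connection from the corresponding downsampling stage.
  • An Attention or LinearAttention module.
  • An Upsample module that doubles the spatial resolution and decreases the channel dimension. The final stage, already at full resolution, only changes the channel count.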
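How skips are pushed on the way down and popped on the way up can be traced in plain Python. trace_unet is a hypothetical helper, and it assumes one particular ordering: two skips per stage, a skip concatenated before each of the two up-path ResNet blocks, the resolution change at the end of each stage, and no resolution change in the deepest stage. It returns the channel count entering each up-path ResNet block.

```python
def trace_unet(image_size, dim, dim_mults):
    """Hypothetical trace of channel counts through a U-Net's skip stack."""
    dims = [dim] + [dim * m for m in dim_mults]
    in_out = list(zip(dims[:-1], dims[1:]))
    n = len(in_out)

    skips = []                      # stack of (channels, spatial size)
    ch, size = dim, image_size      # feature map right after init_conv
    for i, (d_in, d_out) in enumerate(in_out):
        skips.append((d_in, size))  # after the first ResNet block
        skips.append((d_in, size))  # after the second ResNet block (+ attention)
        ch = d_out                  # Downsample (or a plain conv) changes width to d_out
        if i < n - 1:
            size //= 2              # deepest stage keeps its resolution

    # The middle block leaves (ch, size) unchanged.

    concat_inputs = []              # channels entering each up-path ResNet block
    for i, (d_in, d_out) in enumerate(reversed(in_out)):
        for _ in range(2):
            skip_ch, skip_size = skips.pop()     # last in, first out
            assert skip_size == size, "skip and feature map must match spatially"
            concat_inputs.append(ch + skip_ch)  # channel-wise concatenation
            ch = d_out                           # ResNet block outputs d_out
        ch = d_in                                # Upsample changes width to d_in
        if i < n - 1:
            size *= 2
    return concat_inputs, (ch, size), skips


inputs, final_shape, leftover = trace_unet(64, 64, (1, 2, 4, 8))
print(inputs)       # [768, 768, 384, 384, 192, 192, 128, 128]
print(final_shape)  # (64, 64)
print(leftover)     # []
```

The stack empties exactly, and each concatenation's width is the current feature width plus the skip's width, which is why the up-path ResNet blocks take more input channels than they output.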

5. Final Block

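After the last upsampling stage, the output is concatenated with the features saved right after init_conv, passed through a final residual block (final_res_block), and projected by a 1x1 convolution (final_conv) to out_dim channels. The result has the same height and width as the input image and, with the default out_dim, the same number of channels.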
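A sketch of the output head in PyTorch. FinalHead is a hypothetical module: it assumes the up-path output is concatenated with the features saved right after init_conv, and it replaces final_res_block with a simplified conv block plus a 1x1 residual projection. The point it illustrates is that the 1x1 final_conv only changes the channel count, never the height or width.

```python
import torch
from torch import nn


class FinalHead(nn.Module):
    """Hypothetical output head: concatenate with the init_conv features, fuse
    with a simplified residual conv block, then project to out_dim channels."""

    def __init__(self, init_dim, out_dim):
        super().__init__()
        # Simplified stand-in for final_res_block.
        self.fuse = nn.Sequential(
            nn.Conv2d(init_dim * 2, init_dim, 3, padding=1),
            nn.GroupNorm(8, init_dim),
            nn.SiLU(),
        )
        self.res_proj = nn.Conv2d(init_dim * 2, init_dim, 1)
        self.final_conv = nn.Conv2d(init_dim, out_dim, 1)  # 1x1: H and W unchanged

    def forward(self, x, init_features):
        h = torch.cat((x, init_features), dim=1)  # (b, 2 * init_dim, H, W)
        h = self.fuse(h) + self.res_proj(h)        # (b, init_dim, H, W)
        return self.final_conv(h)                  # (b, out_dim, H, W)


head = FinalHead(init_dim=64, out_dim=3)
x = torch.randn(2, 64, 32, 32)  # output of the last up-path stage
r = torch.randn(2, 64, 32, 32)  # features saved right after init_conv
out = head(x, r)
print(out.shape)                # torch.Size([2, 3, 32, 32])
```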