U-Net Architecture
The primary neural network architecture used in this library is the U-Net, a model originally designed for biomedical image segmentation but found to be exceptionally effective for diffusion models. Its structure, featuring skip connections between downsampling and upsampling paths, allows it to process information at multiple resolutions simultaneously.
The Core Unet Class
The main U-Net implementation is denoising_diffusion_pytorch.denoising_diffusion_pytorch.Unet, which can be imported directly from the package:
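from denoising_diffusion_pytorch import Unet

model = Unet(
    dim = 64,                    # base channel width
    init_dim = None,             # channels after the initial convolution; defaults to dim
    out_dim = None,              # output channels; defaults to channels
    dim_mults = (1, 2, 4, 8),    # per-stage multipliers applied to dim
    channels = 3,                # input image channels (3 for RGB)
    self_condition = False,
    flash_attn = True            # requires PyTorch 2.0+
)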
Key Parameters
- dim (int): The base channel dimension. The initial convolution outputs init_dim channels, and init_dim defaults to dim when left as None.
- dim_mults (tuple of ints): Determines the channel dimension at each stage of the U-Net; dim is multiplied by each value in the tuple. For dim=64 and dim_mults=(1, 2, 4, 8), the stage channel dimensions in the downsampling path are 64, 128, 256, and 512.
- channels (int): The number of input image channels (e.g., 3 for RGB).
- self_condition (bool): If True, the model can condition its prediction on its own previous prediction. This can improve sample quality but slows training by about 25%.
- flash_attn (bool): If True, attention uses FlashAttention, a highly efficient attention implementation. This requires PyTorch 2.0+ and a compatible GPU.
- out_dim (int, optional): The number of output channels. Defaults to channels if not specified.
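The channel bookkeeping implied by these parameters can be sketched in plain Python. The helper unet_channel_plan below is hypothetical, not part of the library. It assumes that init_dim falls back to dim and that self-conditioning concatenates the previous prediction to the input along the channel axis (doubling the input channels), and it ignores the library's other constructor arguments.

```python
def unet_channel_plan(dim, dim_mults=(1, 2, 4, 8), init_dim=None,
                      out_dim=None, channels=3, self_condition=False):
    """Hypothetical helper: channel widths implied by the Unet parameters."""
    # Assumption: init_dim defaults to dim when not given.
    init_dim = dim if init_dim is None else init_dim
    # Assumption: self-conditioning stacks the previous prediction onto the input.
    input_channels = channels * (2 if self_condition else 1)
    # Width before the first stage, then dim * m for each multiplier.
    dims = [init_dim] + [dim * m for m in dim_mults]
    # Each stage maps (dim_in -> dim_out) between consecutive widths.
    stages = list(zip(dims[:-1], dims[1:]))
    return {
        "input_channels": input_channels,
        "stages": stages,
        "bottleneck": dims[-1],
        "out_dim": channels if out_dim is None else out_dim,
    }


if __name__ == "__main__":
    print(unet_channel_plan(64))
```

With the defaults shown above, the first stage maps 64 to 64 channels because init_dim falls back to dim; each later stage doubles the width, ending at a 512-channel bottleneck.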
Architectural Components
1. Time Embeddings
The diffusion timestep t tells the model how much noise is present in its input, so it is a critical piece of information. It is converted into a vector embedding using sinusoidal position embeddings, similar to those used in Transformers. This embedding is then processed by a small MLP (time_mlp) and injected into each ResNet block, where it modulates the block's features with a per-channel scale and shift.
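A minimal sketch of the two ideas in this step, written in plain Python for a single scalar timestep. sinusoidal_embedding follows the common Transformer-style formula (half sines, half cosines over geometrically spaced frequencies, with an assumed base theta = 10000). scale_shift shows one way a ResNet block can use the embedding, as a FiLM-style per-channel scale and shift. Both function names are hypothetical, and the learned MLP and projection layers that would sit between the two steps are omitted.

```python
import math


def sinusoidal_embedding(t, dim, theta=10000.0):
    """Embed a scalar timestep t as [sin(t * f_i)..., cos(t * f_i)...]."""
    assert dim % 2 == 0 and dim >= 4, "this sketch assumes an even dim of at least 4"
    half = dim // 2
    # Frequencies decay geometrically from 1 down to 1 / theta.
    step = math.log(theta) / (half - 1)
    freqs = [math.exp(-step * i) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]


def scale_shift(features, scale, shift):
    """FiLM-style modulation: x * (scale + 1) + shift, one pair per channel."""
    return [x * (s + 1.0) + b for x, s, b in zip(features, scale, shift)]


if __name__ == "__main__":
    emb = sinusoidal_embedding(250.0, 8)
    print([round(v, 4) for v in emb])
```

Adding 1 to the scale means a zero scale leaves the features unchanged, so an untrained projection starts close to an identity modulation.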
2. Downsampling Path
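The U-Net starts with an initial convolution (init_conv) followed by a series of downsampling stages. Each stage consists of:
- Two ResnetBlock modules.
- An Attention or LinearAttention module. By default, LinearAttention is used in every stage except the innermost, which uses full Attention.
- A Downsample module that halves the spatial resolution and increases the channel dimension. In the last stage, a plain convolution changes the channel count without reducing the resolution.
Skip connections are stored after the first ResNet block and after the second ResNet block plus attention, giving two per stage.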
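The shape flow through this path can be traced without a deep learning framework. The function down_path_shapes below is a hypothetical sketch, and its assumptions are worth checking against the library source: ResNet blocks and attention keep the stage's input width, two skips are recorded per stage (after the first block and after the second block plus attention), and the last stage changes the channel count without halving the resolution.

```python
def down_path_shapes(image_size, dim=64, dim_mults=(1, 2, 4, 8), init_dim=None):
    """Trace (channels, height, width) through the downsampling path."""
    init_dim = dim if init_dim is None else init_dim
    dims = [init_dim] + [dim * m for m in dim_mults]
    stages = list(zip(dims[:-1], dims[1:]))
    c, h, w = init_dim, image_size, image_size  # shape after init_conv
    skips = []
    for i, (dim_in, dim_out) in enumerate(stages):
        assert c == dim_in
        skips.append((c, h, w))  # after ResnetBlock 1
        skips.append((c, h, w))  # after ResnetBlock 2 + attention
        is_last = i == len(stages) - 1
        if not is_last:
            # Assumed Downsample: halve height and width.
            assert h % 2 == 0 and w % 2 == 0, "resolution must stay divisible by 2"
            h, w = h // 2, w // 2
        c = dim_out  # Downsample (or the last stage's conv) widens to dim_out
    return skips, (c, h, w)


if __name__ == "__main__":
    skips, bottleneck = down_path_shapes(32)
    print(len(skips), bottleneck)
```

For a 32x32 input with the default settings, this trace yields eight skip tensors and a 512-channel 4x4 bottleneck; the input size has to stay divisible by 2 at every halving step.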
3. Middle Block
At the bottleneck of the U-Net, there is a middle block consisting of:
- A ResnetBlock.
- A full Attention module.
- Another ResnetBlock.
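To illustrate what "full" attention means here, the sketch below flattens a (C, H, W) feature map into H*W tokens and lets every token attend to every other one, followed by a residual add. It is a simplification, not the library's Attention module: there are no learned query/key/value projections, there is only a single head, and the function names are hypothetical. Because the score matrix has (H*W)^2 entries, full attention is cheapest where the resolution is lowest, as it is at the bottleneck.

```python
import math


def flatten_feature_map(fmap):
    """Turn a (C, H, W) feature map (nested lists) into H*W tokens of length C."""
    c, h, w = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [[fmap[ch][i][j] for ch in range(c)] for i in range(h) for j in range(w)]


def full_self_attention(tokens):
    """Single-head softmax attention with q = k = v = tokens (no learned projections)."""
    d = len(tokens[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in tokens:
        # One score per token: looping over all pairs is the O((H*W)^2) cost.
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in tokens]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        out.append([sum(wt * v[ch] for wt, v in zip(weights, tokens)) for ch in range(d)])
    return out


def attention_with_residual(tokens):
    """x + Attention(x), applied token by token."""
    attended = full_self_attention(tokens)
    return [[a + x for a, x in zip(a_vec, x_vec)] for a_vec, x_vec in zip(attended, tokens)]


if __name__ == "__main__":
    fmap = [[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [2.0, -1.0]]]  # C=2, H=2, W=2
    print(attention_with_residual(flatten_feature_map(fmap)))
```

LinearAttention in the other stages exists to avoid this quadratic score matrix at higher resolutions.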
4. Upsampling Path
The upsampling path mirrors the downsampling path, consuming the stored skip connections in reverse order. Each stage consists of:
- Two ResnetBlock modules, each preceded by concatenation with the corresponding skip connection from the downsampling path.
- An Attention or LinearAttention module.
- An Upsample module that doubles the spatial resolution and decreases the channel dimension. In the last stage, a plain convolution takes its place.
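The skip-connection bookkeeping in this path can be checked with a shape trace. The function up_path_shapes is a hypothetical sketch that assumes the following per-stage order (my reading of the library's forward pass, worth verifying): concatenate a skip, apply a ResNet block, repeat once, then apply attention and upsample. It also assumes each ResNet block outputs the stage's wider channel count and that the last stage keeps its resolution. Skips are consumed last-in, first-out.

```python
def up_path_shapes(bottleneck, skips, dim=64, dim_mults=(1, 2, 4, 8), init_dim=None):
    """Trace (channels, height, width) through the upsampling path."""
    init_dim = dim if init_dim is None else init_dim
    dims = [init_dim] + [dim * m for m in dim_mults]
    stages = list(zip(dims[:-1], dims[1:]))
    skips = list(skips)  # copy so the caller's list is not consumed
    c, h, w = bottleneck
    for i, (dim_in, dim_out) in enumerate(reversed(stages)):
        for _ in range(2):  # two ResnetBlocks, each preceded by a concatenation
            sc, sh, sw = skips.pop()
            assert (sh, sw) == (h, w), "skip and feature map must match spatially"
            assert c + sc == dim_out + dim_in, "concatenated width must match the block input"
            c = dim_out  # the ResnetBlock maps the concatenated channels back to dim_out
        is_last = i == len(stages) - 1
        if not is_last:
            h, w = h * 2, w * 2  # assumed Upsample: double height and width
        c = dim_in  # Upsample (or the last stage's conv) narrows to dim_in
    assert not skips, "every skip connection should be consumed"
    return c, h, w


if __name__ == "__main__":
    # Skip shapes for a 32x32 input with the default settings.
    skips = [(64, 32, 32), (64, 32, 32), (64, 16, 16), (64, 16, 16),
             (128, 8, 8), (128, 8, 8), (256, 4, 4), (256, 4, 4)]
    print(up_path_shapes((512, 4, 4), skips))
```

The assertions make the design constraint explicit: each block's input width is the sum of the incoming feature width and the skip width, which is why the up path must pop skips in exactly the reverse order they were pushed.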
5. Final Block
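After the last upsampling stage, the result is concatenated with the output of the initial convolution and passed through a final residual block (final_res_block) and a final 1x1 convolution (final_conv). The output has the same spatial size as the input image and out_dim channels, which equals the input channel count by default.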
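The final 1x1 convolution is a per-pixel linear map across channels: each output pixel is a weight matrix of shape (out_dim, C_in) applied to that pixel's channel vector, plus a bias. The plain-Python conv1x1 below is a hypothetical illustration of that operation on nested lists, not the library's code.

```python
def conv1x1(fmap, weight, bias=None):
    """Apply a 1x1 convolution to a (C_in, H, W) feature map given as nested lists.

    weight has shape (C_out, C_in); each output pixel is weight @ pixel + bias.
    """
    c_in, h, w = len(fmap), len(fmap[0]), len(fmap[0][0])
    assert all(len(row) == c_in for row in weight), "weight must have C_in columns"
    if bias is None:
        bias = [0.0] * len(weight)
    return [
        [[sum(weight[o][c] * fmap[c][i][j] for c in range(c_in)) + bias[o]
          for j in range(w)]
         for i in range(h)]
        for o in range(len(weight))
    ]


if __name__ == "__main__":
    fmap = [[[1.0, 2.0]], [[3.0, 4.0]]]  # C_in=2, H=1, W=2
    print(conv1x1(fmap, [[1.0, 1.0]]))   # one output channel: sum of the inputs
```

Because the kernel covers a single pixel, this layer changes only the channel count (from the final residual block's width to out_dim) and never mixes neighboring positions, so the spatial size is preserved.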