RePaint: Inpainting with Diffusion Models

RePaint is an algorithm for image inpainting that uses a pre-trained, unconditional diffusion model. It was introduced in the paper "RePaint: Inpainting using Denoising Diffusion Probabilistic Models" by Lugmayr et al. The key idea is to condition the sampling process on the known pixels of an image without retraining: at every reverse step, the known region is overwritten with a forward-diffused (noised) copy of the original image, while the unknown region is sampled by the model.

This library provides an implementation of the RePaint algorithm in denoising_diffusion_pytorch/repaint.py.

Core Concept

The RePaint algorithm modifies the standard DDPM sampling loop (p_sample_loop). At each reverse step t, it does the following:

  1. Sample the unknown region: It runs one reverse step of the diffusion model on x_t to sample x_{t-1} for the entire image; only the unknown region of this sample is kept.
  2. Noise the known region: It takes the ground truth image x_0 (the parts we want to keep) and applies the forward process to noise it to step t-1.
  3. Combine: It replaces the pixels in the known region of the sampled image from step 1 with the noised ground truth pixels from step 2.

This ensures that at every step of the reverse process, the known regions are consistent with the original image, while the unknown (masked) regions are generated coherently by the diffusion model.
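
Written as equations, the three steps above look roughly as follows. This is a sketch in the notation commonly used for DDPMs, not a quotation from the paper: m is the mask (1 for known pixels), \bar{\alpha}_t is the cumulative product of (1 - \beta_s) up to step t, and \mu_\theta, \Sigma_\theta are the model's predicted mean and covariance.

```latex
\begin{aligned}
x_{t-1}^{\text{known}}   &\sim \mathcal{N}\!\left(\sqrt{\bar{\alpha}_{t-1}}\, x_0,\; (1 - \bar{\alpha}_{t-1})\, \mathbf{I}\right) \\
x_{t-1}^{\text{unknown}} &\sim \mathcal{N}\!\left(\mu_\theta(x_t, t),\; \Sigma_\theta(x_t, t)\right) \\
x_{t-1} &= m \odot x_{t-1}^{\text{known}} + (1 - m) \odot x_{t-1}^{\text{unknown}}
\end{aligned}
```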

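To improve quality, RePaint also introduces a "resampling" step: the sampler jumps back several steps by re-noising the current image with the forward process, then denoises it again, and repeats this a number of times. Each pass lets the model see the generated region next to the noised known region, which helps harmonize the boundary between the known and unknown pixels.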
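
Jumping back amounts to applying the forward process to the current sample. As an illustration only (this is not the library's code), the closed-form jump from step t to step t + jump can be sketched as below. The helper name renoise and the linear beta schedule are assumptions made for the example.

```python
import torch

# Assumed linear beta schedule for illustration; a trained model defines its own.
betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def renoise(x_t, alphas_cumprod, t, jump):
    """Sample x_{t+jump} from q(x_{t+jump} | x_t) in closed form (hypothetical helper)."""
    # Product of alpha_s for s in (t, t + jump], written as a ratio of cumulative products
    ratio = (alphas_cumprod[t + jump] / alphas_cumprod[t]).to(x_t.dtype)
    return ratio.sqrt() * x_t + (1.0 - ratio).sqrt() * torch.randn_like(x_t)

x_t = torch.randn(1, 3, 8, 8)
x_back = renoise(x_t, alphas_cumprod, t=100, jump=10)  # noisier version of x_t
```

After the jump, the sampler runs the usual reverse steps again from t + jump down to t, merging in the noised known region at each step.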

Implementation in repaint.py

The GaussianDiffusion class in repaint.py overrides the sampling methods to accept a ground truth image (gt) and a mask.

  • Mask: The mask should be a tensor of the same height and width as the image, with a value of 1 for known pixels and 0 for unknown (inpainting) pixels.

p_sample

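The p_sample method is modified to include the core RePaint logic. As the excerpt shows, the merge is applied to the incoming x at step t, before the denoising step, so the ground truth is noised to step t here rather than to step t-1: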

# From repaint.py's GaussianDiffusion class
def p_sample(self, x, t: int, ..., gt=None, mask=None):
    if mask is not None:
        # ... noise the ground truth image to step t ...
        weighed_gt = gt_part + noise_part
        # Combine the denoised unknown region with the noised known region
        x = (mask * weighed_gt) + ((1 - mask) * x)

    # ... proceed with the standard denoising step ...
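
For readers who want to see the elided noising step spelled out, here is a minimal, self-contained sketch. It assumes a linear beta schedule and uses hypothetical helpers (linear_alphas_cumprod, merge_known_region) that are not taken from repaint.py; only the names gt_part, noise_part and weighed_gt mirror the excerpt above.

```python
import torch

def linear_alphas_cumprod(timesteps=1000):
    """Cumulative product of (1 - beta_t) for an assumed linear beta schedule."""
    betas = torch.linspace(1e-4, 0.02, timesteps, dtype=torch.float64)
    return torch.cumprod(1.0 - betas, dim=0)

def merge_known_region(x, gt, mask, alphas_cumprod, t):
    """Overwrite the known pixels of x (mask == 1) with gt noised to step t."""
    a_bar = alphas_cumprod[t].to(x.dtype)
    gt_part = a_bar.sqrt() * gt                               # scaled ground truth
    noise_part = (1.0 - a_bar).sqrt() * torch.randn_like(x)   # matching Gaussian noise
    weighed_gt = gt_part + noise_part
    return mask * weighed_gt + (1.0 - mask) * x

torch.manual_seed(0)
alphas_cumprod = linear_alphas_cumprod(1000)
gt = torch.rand(2, 3, 8, 8) * 2 - 1      # stand-in images in [-1, 1]
x = torch.randn_like(gt)                 # stand-in for the current noisy sample
mask = torch.ones_like(gt)
mask[:, :, 2:6, 2:6] = 0                 # 0 marks the region to inpaint
merged = merge_known_region(x, gt, mask, alphas_cumprod, t=500)
```

The unknown region of merged is untouched, while the known region carries the same noise level as a true sample at step t.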

p_sample_loop

The sampling loop orchestrates the process and includes the resampling logic.

# From repaint.py's GaussianDiffusion class
def p_sample_loop(self, ..., resample=True, resample_iter=10, ...):
    # ... standard loop from T to 0 ...
        img, x_start = self.p_sample(x=img, t=t, ..., gt=gt, mask=mask)

        # Resampling loop
        if resample is True and (t > 0) and (t % resample_every == 0):
            # ... jump back t_jump steps and re-denoise for resample_iter iterations ...
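
To make the order of steps concrete, the sketch below lists which denoising steps a jump-back sampler would visit. It is a hypothetical illustration of the idea, not the logic in repaint.py: the function name resampling_schedule and the exact jump semantics (re-noise by resample_jump steps, then denoise back down to t) are assumptions, and the parameter names are borrowed from the sampling arguments for readability.

```python
def resampling_schedule(timesteps, resample_every, resample_jump, resample_iter):
    """List the denoising steps t in the order a jump-back sampler visits them."""
    last = timesteps - 1
    schedule = []
    for t in range(last, -1, -1):
        schedule.append(t)  # regular reverse step at t, producing x_{t-1}
        if t > 0 and t % resample_every == 0:
            # Re-noising x_{t-1} by resample_jump steps lands at this step
            # (clipped so it never exceeds the last timestep)
            top = min(t + resample_jump - 1, last)
            for _ in range(resample_iter):
                # Denoise again from `top` back down to t
                schedule.extend(range(top, t - 1, -1))
    return schedule

steps = resampling_schedule(timesteps=6, resample_every=3, resample_jump=2, resample_iter=1)
print(steps)  # [5, 4, 3, 4, 3, 2, 1, 0]
```

The length of the returned list is the number of model evaluations, which shows how quickly larger resample_iter or resample_jump values increase sampling cost.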

Usage Example

To perform inpainting, you need a pre-trained unconditional diffusion model.

import torch
from denoising_diffusion_pytorch.repaint import Unet, GaussianDiffusion

# 1. Load a pre-trained unconditional model
model = Unet(dim=64, dim_mults=(1, 2, 4, 8))

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    timesteps=1000
)

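# Load trained weights before sampling. The path is a placeholder, and this
# assumes the file holds a state_dict for the GaussianDiffusion module.
# diffusion.load_state_dict(torch.load('./path/to/model.pt'))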

# 2. Prepare the ground truth image and mask
# Here: a batch of 4 RGB images at 128x128
gt_images = torch.randn(4, 3, 128, 128) # Placeholder; use your ground truth images here

# Create a mask (e.g., a square in the middle to inpaint)
mask = torch.ones_like(gt_images)
mask[:, :, 64-16:64+16, 64-16:64+16] = 0 # 0 for unknown region, 1 for known

# 3. Call the sample method with gt and mask
inpainted_images = diffusion.sample(
    batch_size = gt_images.shape[0],  # Assumed to be accepted as in the base sample(); should match gt_images
    gt = gt_images,
    mask = mask,
    resample = True,      # Enable resampling for better quality
    resample_iter = 5,    # Number of resampling iterations
    resample_jump = 10,   # How many steps to jump back for resampling
    resample_every = 50   # How often (in timesteps) to perform resampling
)

# inpainted_images will contain the original image with the masked region filled in