Artificial Intelligence 26 min read

Building a Simple Diffusion Model with Python

This tutorial walks through implementing a basic Denoising Diffusion Probabilistic Model in Python, explaining the forward noise schedule, reverse denoising training, and providing complete code for noise schedules, diffusion functions, residual and attention blocks, a UNet architecture, loss computation, and a training loop.

DaTaobao Tech
DaTaobao Tech
DaTaobao Tech
Building a Simple Diffusion Model with Python

Diffusion models are the backbone of most AIGC image generation models. They learn to recover an image from Gaussian noise by progressively denoising it. This article demonstrates how to implement a simple diffusion model from scratch using Python.

Theory

DDPM (Denoising Diffusion Probabilistic Models) consists of two parts:

A fixed forward process that gradually adds Gaussian noise to an image until it becomes pure noise.

A learnable reverse process that trains a neural network to denoise step by step, starting from pure noise.

Forward (Noise) Process

The forward process repeatedly adds a small amount of Gaussian noise to an image. After K diffusion steps, the original data has been transformed into an (approximately) pure random-noise matrix.

Training Process

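The reverse process learns to remove the noise. The network receives a noisy image together with its timestep and predicts the noise component that was added; the training loss compares this prediction with the true noise.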

Environment Packages

!pip install -q -U einops datasets matplotlib tqdm
%matplotlib inline
import math, torch, torch.nn as nn, torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from tqdm.auto import tqdm

Noise Schedules

# cosine schedule
def cosine_beta_schedule(timesteps, s=0.008):
    """Return ``timesteps`` betas that follow a cosine schedule.

    The cumulative alpha product is a squared cosine evaluated on
    ``timesteps + 1`` evenly spaced points and normalised by its first entry.
    Each beta is one minus the ratio of two consecutive cumulative products,
    clipped to ``[0.0001, 0.9999]``.

    Args:
        timesteps: number of diffusion steps (length of the returned tensor).
        s: offset added to the normalised step before the cosine is taken.

    Returns:
        1-D tensor of ``timesteps`` beta values.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    cum_alpha = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    # Normalise so the cumulative product starts at exactly 1.
    cum_alpha = cum_alpha / cum_alpha[0]
    step_ratio = cum_alpha[1:] / cum_alpha[:-1]
    return torch.clip(1 - step_ratio, 0.0001, 0.9999)

# linear schedule
def linear_beta_schedule(timesteps):
    """Return ``timesteps`` betas evenly spaced from 0.0001 to 0.02 (inclusive)."""
    lowest_beta = 0.0001
    highest_beta = 0.02
    betas = torch.linspace(lowest_beta, highest_beta, timesteps)
    return betas

# quadratic schedule
def quadratic_beta_schedule(timesteps):
    """Return ``timesteps`` betas from 0.0001 to 0.02, spaced evenly in sqrt-space.

    The square roots of the two endpoints are interpolated linearly and the
    result is squared, so the betas grow quadratically with the step index.
    """
    beta_start, beta_end = 0.0001, 0.02
    sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps)
    return sqrt_betas ** 2

# sigmoid schedule
def sigmoid_beta_schedule(timesteps):
    """Return ``timesteps`` betas shaped by a sigmoid between 0.0001 and 0.02.

    A sigmoid is evaluated on ``timesteps`` points evenly spaced over
    ``[-6, 6]`` and then rescaled affinely onto the beta range.
    """
    beta_start, beta_end = 0.0001, 0.02
    ramp = torch.sigmoid(torch.linspace(-6, 6, timesteps))
    return ramp * (beta_end - beta_start) + beta_start

Forward Diffusion Function

def q_sample(x_start, t, noise=None):
    """Return a noised version of ``x_start`` for timestep(s) ``t``.

    The result is a weighted sum of the clean input and the noise, with both
    weights looked up via ``extract``.

    NOTE(review): ``extract``, ``sqrt_alphas_cumprod`` and
    ``sqrt_one_minus_alphas_cumprod`` are not defined in this listing; they
    must be available at module level. Presumably ``extract`` gathers the
    per-timestep coefficient and reshapes it to broadcast against
    ``x_start`` -- confirm against its definition.

    Args:
        x_start: clean input tensor.
        t: timestep index/indices passed to ``extract``.
        noise: optional noise tensor; standard-normal noise shaped like
            ``x_start`` is drawn when omitted.
    """
    eps = torch.randn_like(x_start) if noise is None else noise
    signal_weight = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_weight = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_weight * x_start + noise_weight * eps

Residual Block

class Block(nn.Module):
    """3x3 convolution, group normalisation, optional scale/shift, then SiLU.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        groups: number of groups used by the ``GroupNorm`` layer.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Apply the block to ``x``.

        When ``scale_shift`` is a ``(scale, shift)`` pair, the normalised
        features are modulated as ``features * (scale + 1) + shift`` before
        the activation.
        """
        features = self.norm(self.proj(x))
        if scale_shift is None:
            return self.act(features)
        scale, shift = scale_shift
        return self.act(features * (scale + 1) + shift)

class ResnetBlock(nn.Module):
    """Two stacked ``Block`` layers with a residual connection.

    An optional time embedding is projected to ``dim_out`` channels and added
    to the features between the two blocks.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        time_emb_dim: width of the time embedding; when falsy no projection
            MLP is created and the embedding is ignored.
        groups: number of groups forwarded to each ``Block``.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if time_emb_dim:
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
        else:
            self.mlp = None
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # A 1x1 convolution matches the channel count on the skip path only
        # when the input and output widths differ.
        if dim == dim_out:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)

    def forward(self, x, time_emb=None):
        """Return ``block2(block1(x) [+ time projection]) + res_conv(x)``."""
        hidden = self.block1(x)
        has_time_signal = self.mlp is not None and time_emb is not None
        if has_time_signal:
            projected = self.mlp(time_emb)
            # Append two singleton axes so the projection broadcasts over
            # the spatial dimensions.
            hidden = hidden + rearrange(projected, "b c -> b c 1 1")
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)

Attention Mechanisms

class Attention(nn.Module):
    """Multi-head softmax self-attention over the spatial positions of a feature map.

    Args:
        dim: number of input (and output) channels.
        heads: number of attention heads.
        dim_head: channel width of each head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        # A single 1x1 convolution produces queries, keys and values at once.
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        """Attend across all ``height * width`` positions of ``x`` (b, c, h, w)."""
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        q = q * self.scale
        scores = torch.einsum('b h d i, b h d j -> b h i j', q, k)
        # Subtract the (detached) row maximum before the softmax.
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)
        attended = torch.einsum('b h i j, b h d j -> b h i d', weights, v)
        attended = rearrange(attended, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(attended)

class LinearAttention(nn.Module):
    """Attention variant that builds a per-head context matrix from keys and values.

    Keys and values are first contracted over the spatial positions into a
    ``dim_head x dim_head`` context, which is then applied to the queries, so
    no position-by-position similarity matrix is formed.

    Args:
        dim: number of input (and output) channels.
        heads: number of attention heads.
        dim_head: channel width of each head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        # The output projection is followed by a single-group GroupNorm.
        self.to_out = nn.Sequential(nn.Conv2d(inner_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        """Apply linear attention to a ``(b, c, h, w)`` feature map."""
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        # Queries are normalised over the per-head channel axis, keys over
        # the flattened spatial axis.
        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)
        kv_context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        result = torch.einsum('b h d e, b h d n -> b h e n', kv_context, q)
        result = rearrange(result, "b h c (x y) -> b (h c) x y", h=self.heads, x=height, y=width)
        return self.to_out(result)

Time Embedding

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D tensor of timesteps to sinusoidal embeddings.

    The first half of each embedding holds sines and the second half cosines
    of the timestep multiplied by geometrically spaced frequencies that run
    from 1 down to 1/10000.

    Args:
        dim: requested embedding width. The output has ``2 * (dim // 2)``
            columns, so an odd ``dim`` yields ``dim - 1`` columns.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return embeddings of shape ``(len(time), 2 * (dim // 2))``."""
        half_dim = self.dim // 2
        # The original divided by ``half_dim - 1`` directly, which raises
        # ZeroDivisionError when dim is 2 or 3. With a single frequency the
        # divisor is irrelevant (the only exponent is 0), so clamping it to 1
        # fixes that case and leaves every dim >= 4 unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        # Outer product: one row per timestep, one column per frequency.
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

U‑Net Architecture

class Unet(nn.Module):
    """U-Net that takes an image batch and its timesteps and returns a same-layout prediction.

    The network is an initial 7x7 convolution, a stack of down stages, a
    middle section (block, attention, block), a stack of up stages that
    consume skip connections from the down path, and a final block followed
    by a 1x1 convolution.

    NOTE(review): ``ConvNextBlock``, ``Residual``, ``PreNorm``, ``Downsample``
    and ``Upsample`` are not defined in the visible listing and must be
    provided elsewhere; with the default ``use_convnext=True`` construction
    needs ``ConvNextBlock``.

    Args:
        dim: base channel width; stage widths are ``dim * m`` for each ``m``
            in ``dim_mults``.
        init_dim: channels produced by the initial convolution; defaults to
            ``dim // 3 * 2`` when falsy.
        out_dim: channels of the final output; defaults to ``channels``.
        dim_mults: per-stage multipliers applied to ``dim``.
        channels: channels of the input image.
        with_time_emb: build the timestep-embedding MLP (width ``dim * 4``)
            when true; otherwise blocks receive ``None`` as time input.
        resnet_block_groups: ``groups`` passed to ``ResnetBlock`` when
            ``use_convnext`` is false.
        use_convnext: choose ``ConvNextBlock`` (true) or ``ResnetBlock``.
        convnext_mult: ``mult`` passed to ``ConvNextBlock``.
    """

    def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1,2,4,8),
                 channels=3, with_time_emb=True, resnet_block_groups=8,
                 use_convnext=True, convnext_mult=2):
        # The listing's import cell never imports ``partial``, so constructing
        # the model raised NameError; importing locally keeps this block
        # self-sufficient.
        from functools import partial

        super().__init__()
        self.channels = channels
        init_dim = init_dim or dim // 3 * 2
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
        # Channel widths: the stem width followed by one width per stage.
        dims = [init_dim, *[dim * m for m in dim_mults]]
        in_out = list(zip(dims[:-1], dims[1:]))
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        # Down path: two blocks, linear attention, then downsample (identity
        # on the last stage).
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity(),
            ]))
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        # Up path: mirrors every down stage except the first. The first block
        # takes ``dim_out * 2`` channels because the skip tensor is
        # concatenated on the channel axis in ``forward``.
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # This loop has len(in_out) - 1 iterations, so ``ind`` never
            # reaches len(in_out) - 1 and ``is_last`` is always False: every
            # up stage ends with Upsample, matching the number of Downsample
            # layers on the way down.
            is_last = ind >= (len(in_out) - 1)
            self.ups.append(nn.ModuleList([
                block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity(),
            ]))
        out_dim = out_dim or channels
        self.final_conv = nn.Sequential(
            block_klass(dim, dim),
            nn.Conv2d(dim, out_dim, 1),
        )

    def forward(self, x, time):
        """Run the network on image batch ``x`` with timesteps ``time``."""
        x = self.init_conv(x)
        t = self.time_mlp(time) if self.time_mlp is not None else None
        # Skip connections collected on the way down. One is pushed per down
        # stage but only len(self.ups) are popped, so the first one pushed is
        # never consumed.
        h = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        for block1, block2, attn, upsample in self.ups:
            # Concatenate the most recent skip tensor along the channel axis.
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        return self.final_conv(x)

Loss Function

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    """Return the training loss between the true and the predicted noise.

    ``x_start`` is noised with ``q_sample`` at timestep(s) ``t``, the model
    predicts the noise from the noisy input and ``t``, and the prediction is
    compared with the noise that was actually used.

    Args:
        denoise_model: callable invoked as ``denoise_model(x_noisy, t)``.
        x_start: clean input batch.
        t: timesteps passed to both ``q_sample`` and the model.
        noise: optional noise tensor; standard-normal noise shaped like
            ``x_start`` is drawn when omitted.
        loss_type: ``'l1'``, ``'l2'`` (mean squared error) or ``"huber"``
            (smooth L1).

    Raises:
        NotImplementedError: for any other ``loss_type``. The check happens
            after the model has been called.
    """
    target = torch.randn_like(x_start) if noise is None else noise
    noisy_input = q_sample(x_start=x_start, t=t, noise=target)
    prediction = denoise_model(noisy_input, t)

    if loss_type == 'l1':
        criterion = F.l1_loss
    elif loss_type == 'l2':
        criterion = F.mse_loss
    elif loss_type == "huber":
        criterion = F.smooth_l1_loss
    else:
        raise NotImplementedError()
    return criterion(target, prediction)

Training Loop

if __name__ == "__main__":
    # Training driver. ``epochs``, ``dataloader``, ``optimizer``, ``device``,
    # ``timesteps``, ``model`` and ``save_and_sample_every`` are not defined
    # in this listing and must exist before this block runs.
    for epoch in range(epochs):
        # Wrap the dataloader itself rather than the enumerate object: an
        # enumerate has no length, so tqdm(enumerate(...)) cannot show a
        # total or ETA even when the dataloader defines one.
        for step, batch in enumerate(tqdm(dataloader, desc='Training')):
            optimizer.zero_grad()
            # presumably each batch is a sequence whose first item is the
            # image tensor -- verify against the dataloader
            batch = batch[0].to(device)
            batch_size = batch.shape[0]
            # One random timestep in [0, timesteps) per sample in the batch.
            t = torch.randint(0, timesteps, (batch_size,), device=device).long()
            loss = p_losses(model, batch, t, loss_type="huber")
            if step % 50 == 0:
                print("Loss:", loss.item())
            loss.backward()
            optimizer.step()
            # ``step`` restarts at 0 each epoch, so this fires on per-epoch
            # step counts and skips the first step of every epoch.
            if step % save_and_sample_every == 0 and step != 0:
                # generate and save samples
                # NOTE(review): no sampling code is present here; only the
                # whole model object is serialised, overwriting "train.pt".
                torch.save(model, "train.pt")

References

Diffusion Model Principles – http://www.egbenz.com/#/my_article/12

Detailed DDPM Explanation – https://zhuanlan.zhihu.com/p/582072317

DDPM Architecture – https://zhuanlan.zhihu.com/p/637815071

Transformer Positional Encoding – https://zhuanlan.zhihu.com/p/637815071

Additional resources – https://zhuanlan.zhihu.com/p/632809634

Tags: Python · deep learning · diffusion model · Attention · DDPM · U-Net
DaTaobao Tech
Written by

DaTaobao Tech

Official account of DaTaobao Technology

0 followers
Reader feedback

How this landed with the community

Sign in to like

Rate this article

Was this worth your time?

Sign in to rate
Discussion

0 Comments

Thoughtful readers leave field notes, pushback, and hard-won operational detail here.