Building a Simple Diffusion Model with Python
This tutorial walks through implementing a basic Denoising Diffusion Probabilistic Model in Python, explaining the forward noise schedule, reverse denoising training, and providing complete code for noise schedules, diffusion functions, residual and attention blocks, a UNet architecture, loss computation, and a training loop.
Diffusion models are the backbone of most AIGC image generation models. They learn to recover an image from Gaussian noise by progressively denoising it. This article demonstrates how to implement a simple diffusion model from scratch using Python.
Theory
DDPM (Denoising Diffusion Probabilistic Models) consists of two parts:
A fixed forward process that gradually adds Gaussian noise to an image until it becomes pure noise.
A learnable reverse process that trains a neural network to denoise step by step, starting from pure noise.
Forward (Noise) Process
The forward process repeatedly adds noise to an image. With K diffusion steps, the original data is transformed into a random noise matrix.
Training Process
The reverse process learns to remove the noise. The network receives a noisy image and predicts the noise component.
Environment Packages
!pip install -q -U einops datasets matplotlib tqdm
%matplotlib inline
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from tqdm.auto import tqdm

Noise Schedules
# cosine schedule
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine variance schedule (Nichol & Dhariwal, 2021).

    Builds the cumulative alpha curve with a squared cosine over
    ``timesteps + 1`` points, then derives the per-step betas from the
    ratio of consecutive cumulative products.
    """
    num_points = timesteps + 1
    grid = torch.linspace(0, timesteps, num_points)
    cumprod = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    cumprod = cumprod / cumprod[0]
    betas = 1 - (cumprod[1:] / cumprod[:-1])
    # clamp away from 0 and 1 to keep the diffusion numerically stable
    return torch.clip(betas, 0.0001, 0.9999)
# linear schedule
def linear_beta_schedule(timesteps):
    """Betas increase linearly from 1e-4 to 2e-2 (DDPM paper defaults)."""
    return torch.linspace(0.0001, 0.02, timesteps)
# quadratic schedule
def quadratic_beta_schedule(timesteps):
    """Betas follow a quadratic ramp between the DDPM default endpoints."""
    start, end = 0.0001, 0.02
    root_space = torch.linspace(start ** 0.5, end ** 0.5, timesteps)
    return root_space ** 2
# sigmoid schedule
def sigmoid_beta_schedule(timesteps):
beta_start, beta_end = 0.0001, 0.02
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_startForward Diffusion Function
def q_sample(x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noiseResidual Block
class Block(nn.Module):
def __init__(self, dim, dim_out, groups=8):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift=None):
x = self.proj(x)
x = self.norm(x)
if scale_shift is not None:
scale, shift = scale_shift
x = x * (scale + 1) + shift
return self.act(x)
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
super().__init__()
self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) if time_emb_dim else None
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
h = self.block1(x)
if self.mlp is not None and time_emb is not None:
h = h + rearrange(self.mlp(time_emb), "b c -> b c 1 1")
h = self.block2(h)
return h + self.res_conv(x)Attention Mechanisms
class Attention(nn.Module):
    """Full (quadratic) self-attention over the spatial positions of a map.

    Implemented with native torch ops (reshape/permute/einsum) instead of
    einops, removing the third-party dependency while keeping behavior
    identical.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # (b, heads*dim_head, h, w) -> (b, heads, dim_head, h*w)
        q, k, v = (t.reshape(b, self.heads, -1, h * w) for t in qkv)
        q = q * self.scale
        sim = torch.einsum("b h d i, b h d j -> b h i j", q, k)
        # subtract the row-wise max for a numerically stable softmax
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        out = torch.einsum("b h i j, b h d j -> b h i d", attn, v)
        # (b, heads, h*w, dim_head) -> (b, heads*dim_head, h, w)
        out = out.permute(0, 1, 3, 2).reshape(b, -1, h, w)
        return self.to_out(out)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = [rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads) for t in qkv]
q = q.softmax(dim=-2) * self.scale
k = k.softmax(dim=-1)
context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
return self.to_out(out)Time Embedding
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = time[:, None] * emb[None, :]
return torch.cat((emb.sin(), emb.cos()), dim=-1)U‑Net Architecture
class Unet(nn.Module):
def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1,2,4,8),
channels=3, with_time_emb=True, resnet_block_groups=8,
use_convnext=True, convnext_mult=2):
super().__init__()
self.channels = channels
init_dim = init_dim or dim // 3 * 2
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
dims = [init_dim, *[dim * m for m in dim_mults]]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ConvNextBlock, mult=convnext_mult) if use_convnext else partial(ResnetBlock, groups=resnet_block_groups)
if with_time_emb:
time_dim = dim * 4
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(dim),
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim),
)
else:
time_dim = None
self.time_mlp = None
self.downs = nn.ModuleList()
self.ups = nn.ModuleList()
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
self.downs.append(nn.ModuleList([
block_klass(dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else nn.Identity(),
]))
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
self.ups.append(nn.ModuleList([
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Upsample(dim_in) if not is_last else nn.Identity(),
]))
out_dim = out_dim or channels
self.final_conv = nn.Sequential(
block_klass(dim, dim),
nn.Conv2d(dim, out_dim, 1),
)
def forward(self, x, time):
x = self.init_conv(x)
t = self.time_mlp(time) if self.time_mlp is not None else None
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)Loss Function
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
if noise is None:
noise = torch.randn_like(x_start)
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
predicted_noise = denoise_model(x_noisy, t)
if loss_type == 'l1':
loss = F.l1_loss(noise, predicted_noise)
elif loss_type == 'l2':
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "huber":
loss = F.smooth_l1_loss(noise, predicted_noise)
else:
raise NotImplementedError()
return lossTraining Loop
if __name__ == "__main__":
    # NOTE(review): `epochs`, `dataloader`, `optimizer`, `model`, `device`,
    # `timesteps` and `save_and_sample_every` must be defined before this
    # runs — the article never shows their setup; confirm against the
    # reader's own training script.
    for epoch in range(epochs):
        for step, batch in tqdm(enumerate(dataloader), desc="Training"):
            optimizer.zero_grad()
            batch = batch[0].to(device)
            batch_size = batch.shape[0]
            # sample one uniform-random timestep per image
            t = torch.randint(0, timesteps, (batch_size,), device=device).long()
            loss = p_losses(model, batch, t, loss_type="huber")
            if step % 50 == 0:
                print("Loss:", loss.item())
            loss.backward()
            optimizer.step()
            if step % save_and_sample_every == 0 and step != 0:
                # generate and save samples
                # NOTE(review): this pickles the whole model object; consider
                # torch.save(model.state_dict(), "train.pt") for portability.
                torch.save(model, "train.pt")

References
Diffusion Model Principles – http://www.egbenz.com/#/my_article/12
Detailed DDPM Explanation – https://zhuanlan.zhihu.com/p/582072317
DDPM Architecture – https://zhuanlan.zhihu.com/p/637815071
Transformer Positional Encoding – https://zhuanlan.zhihu.com/p/637815071
Additional resources – https://zhuanlan.zhihu.com/p/632809634
DaTaobao Tech
Official account of DaTaobao Technology
How this landed with the community
Was this worth your time?
0 Comments
Thoughtful readers leave field notes, pushback, and hard-won operational detail here.