Open In Colab

# A Minimal Diffusion Model from Scratch

## Outline

  • Setup & Imports

  • Generate Simple Data (e.g. 2D Gaussian blobs or MNIST)

  • Forward Process (Adding Noise)

  • Reverse Process (Learning to Denoise)

  • Sampling from Noise

  • Visualization & Insights

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.animation as animation

## Generate Simple Data

We use a mixture of two 2D Gaussian clusters so that every stage of the process is easy to visualize:

# Generate 2D Gaussian data (2 clusters)
def generate_data(n=1000):
    """Return ``n`` points drawn from two isotropic 2-D Gaussian clusters.

    About half of the points come from N((2, 2), 0.1 * I) and the rest from
    N((-2, -2), 0.1 * I). Rows are shuffled so the two clusters are interleaved.

    Args:
        n: total number of points to return. For odd ``n`` the second cluster
            receives the extra point, so exactly ``n`` rows are always returned.

    Returns:
        A float32 tensor of shape ``(n, 2)``.
    """
    mean1 = [2, 2]
    mean2 = [-2, -2]
    cov = [[0.1, 0], [0, 0.1]]
    n1 = n // 2
    # Give any odd leftover point to the second cluster instead of dropping it.
    n2 = n - n1
    data1 = np.random.multivariate_normal(mean1, cov, n1)
    data2 = np.random.multivariate_normal(mean2, cov, n2)
    data = np.vstack([data1, data2])
    np.random.shuffle(data)
    return torch.tensor(data, dtype=torch.float32)

data = generate_data()

# Quick sanity plot: the two clusters should sit around (2, 2) and (-2, -2).
plt.scatter(*data.T, alpha=0.5)
plt.title("Toy 2D Data")
plt.show()

## Forward Process (Diffusion)

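We define a linear noise schedule \(\beta_t\), with \(\alpha_t = 1 - \beta_t\) and \(\bar{\alpha}_t = \prod_{s \le t} \alpha_s\), and use the closed form of \(q(x_t \mid x_0)\) to jump directly to any noisy step.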

T = 100  # total number of diffusion steps

# Linear variance schedule beta_t, its complement alpha_t, and the running
# product alpha_hat_t = prod_{s<=t} alpha_s used by the closed-form forward process.
beta = torch.linspace(start=1e-4, end=0.02, steps=T)
alpha = torch.ones_like(beta) - beta
alpha_hat = alpha.cumprod(dim=0)

def q_sample(x0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in closed form.

    Computes x_t = sqrt(alpha_hat_t) * x0 + sqrt(1 - alpha_hat_t) * eps, using
    the module-level ``alpha_hat`` schedule.

    Args:
        x0: clean data with the batch along the first dimension.
        t: timestep indices into ``alpha_hat``. Pass either one index per
            sample (a 1-D tensor of length ``x0.shape[0]``) or a single
            int / 0-d tensor that applies to the whole batch.
        noise: optional noise tensor shaped like ``x0``. If omitted, it is
            drawn from N(0, I).

    Returns:
        ``(x_t, noise)``: the noised sample and the noise that was added.
    """
    if noise is None:
        noise = torch.randn_like(x0)
    a_hat = alpha_hat[torch.as_tensor(t)]
    # Append singleton dims so the coefficients broadcast over every
    # non-batch dimension of x0. This works for 2-D point clouds,
    # higher-rank data, and a scalar t alike.
    a_hat = a_hat.reshape(tuple(a_hat.shape) + (1,) * (x0.dim() - a_hat.dim()))
    return a_hat.sqrt() * x0 + (1 - a_hat).sqrt() * noise, noise
fig, ax = plt.subplots()
sc = ax.scatter(data[:, 0], data[:, 1], alpha=0.5)
ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)

# Noise the whole dataset at every 5th timestep; each frame holds the noised
# points as a NumPy array.
n_points = data.shape[0]
frames = [
    q_sample(data, torch.full((n_points,), step, dtype=torch.long))[0].numpy()
    for step in range(0, T, 5)
]


def update(frame):
    """Move the scatter points to the positions stored in ``frame``."""
    sc.set_offsets(frame)
    return (sc,)


ani = animation.FuncAnimation(fig, update, frames=frames, interval=100)
plt.close()

from IPython.display import HTML

HTML(ani.to_jshtml())

## Reverse Process (Train a Denoiser)

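A small MLP that takes the noisy point \(x_t\) together with the scaled timestep \(t / T\) and predicts the added noise \(\epsilon\):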

class DenoiseMLP(nn.Module):
    """Small MLP that predicts the noise added to a 2-D point.

    The input is the noisy point concatenated with its timestep divided by
    ``T`` (one extra scalar feature). The output is a 2-D noise estimate.
    """

    def __init__(self):
        super().__init__()
        hidden = 64
        layers = [
            nn.Linear(2 + 1, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Predict the noise in ``x`` at integer timesteps ``t`` (one per row)."""
        # Scale t into [0, 1) and append it as an extra input column.
        t_feature = (t.float() / T)[:, None]
        return self.net(torch.cat((x, t_feature), dim=1))

# Instantiate the denoiser and an Adam optimizer over all of its parameters.
model = DenoiseMLP()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

## Training Loop

dataset = TensorDataset(data)
loader = DataLoader(dataset, batch_size=128, shuffle=True)

n_epochs = 100
for epoch in range(n_epochs):
    # Accumulate a sample-weighted loss so the report covers the whole epoch,
    # not just its last (possibly smaller) minibatch.
    epoch_loss = 0.0
    n_seen = 0
    for (batch,) in loader:
        # Draw an independent random timestep for every example in the batch.
        t = torch.randint(0, T, (batch.shape[0],))
        x_t, noise = q_sample(batch, t)
        pred_noise = model(x_t, t)
        loss = F.mse_loss(pred_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * batch.shape[0]
        n_seen += batch.shape[0]

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: mean loss = {epoch_loss / max(n_seen, 1):.4f}")

## Sampling

The function p_sample implements one reverse diffusion step:

Given a noisy input \(x_t\) at timestep \(t\), the model predicts the noise \(\epsilon_\theta(x_t, t)\) that was added. From this prediction we can estimate the clean sample \(x_0\) and, from that, the mean of the reverse step.

From the diffusion process:

\[ x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \tag{0} \]

Solving for \(x_0\):

\[ \hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(x_t, t) \right) \tag{1} \]

Substituting \(\hat{x}_0\) into the mean of the forward-process posterior \(q(x_{t-1} \mid x_t, x_0)\) gives the mean of the reverse distribution \(p_\theta(x_{t-1} \mid x_t)\):

\[ \mu_t(x_t, \epsilon_\theta) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t) \right) \tag{2} \]

And sample from:

\[ x_{t-1} \sim \mathcal{N}(\mu_t, \sigma_t^2 I), \quad \text{where } \sigma_t^2 = \beta_t \tag{3} \]

`p_sample` implements this update for a single timestep. At the last step (\(t = 0\)) it returns the mean without adding noise.

@torch.no_grad()
def p_sample(model, x, t):
    """Take one reverse-diffusion step, mapping x_t to a sample of x_{t-1}.

    Uses the module-level ``beta`` / ``alpha`` / ``alpha_hat`` schedule.

    Args:
        model: network called as ``model(x, t_batch)``. It is expected to
            return a noise prediction that broadcasts against ``x``.
        x: current noisy batch x_t, with the batch along the first dimension.
        t: integer timestep index shared by the whole batch.

    Returns:
        x_{t-1} = mu_t + sqrt(beta_t) * z with z ~ N(0, I). For ``t == 0`` no
        noise is added and the mean itself is returned.
    """
    beta_t = beta[t]
    sqrt_one_minus_alpha_hat = (1 - alpha_hat[t]).sqrt()
    sqrt_recip_alpha = (1.0 / alpha[t]).sqrt()

    # Every sample in the batch is at the same timestep. Build the index vector
    # on x's device so the model can concatenate it with x.
    t_batch = torch.full((x.shape[0],), t, dtype=torch.long, device=x.device)
    pred_noise = model(x, t_batch)

    # Eqn 2: the reverse-process mean written directly in terms of the predicted
    # noise. The x0 estimate of Eqn 1 is already folded into this expression, so
    # it is not computed separately.
    mean = sqrt_recip_alpha * (x - beta_t / sqrt_one_minus_alpha_hat * pred_noise)
    if t > 0:
        z = torch.randn_like(x)
        return mean + beta_t.sqrt() * z  # Eqn 3 with sigma_t^2 = beta_t
    # Final step (t == 0): return the mean without adding noise.
    return mean

@torch.no_grad()
def sample(model, n_samples=1000, return_trajectory=False):
    """Generate points by running the full reverse chain from pure noise.

    Starts from x ~ N(0, I) with shape ``(n_samples, 2)`` and applies
    ``p_sample`` for t = T-1, ..., 0.

    Args:
        model: noise-prediction network, passed through to ``p_sample``.
        n_samples: number of 2-D points to generate.
        return_trajectory: if True, also return every intermediate state.

    Returns:
        The final samples, or ``(samples, trajectory)`` when
        ``return_trajectory`` is set. ``trajectory`` holds T + 1 tensors,
        starting with the initial noise.
    """
    x = torch.randn(n_samples, 2)
    history = [x.clone()] if return_trajectory else None
    for step in range(T - 1, -1, -1):
        x = p_sample(model, x, step)
        if history is not None:
            history.append(x.clone())
    return (x, history) if return_trajectory else x

samples, trajectory = sample(model, return_trajectory=True)

# Plot the final state of the reverse chain.
plt.scatter(*samples.T, alpha=0.5)
plt.axis("equal")
plt.title("Generated Samples")
plt.show()

## Visualizing the Reverse Process

from IPython.display import HTML, display

fig, ax = plt.subplots()
sc = ax.scatter([], [], alpha=0.5)
ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
ax.set_title("Reverse Process")
ax.set_aspect("equal")


def animate(i):
    """Draw trajectory[i]: i = 0 is the initial noise, i = T the final samples."""
    # trajectory is already in generation order (noise first), so no reversal is needed.
    x = trajectory[i]
    sc.set_offsets(x.numpy())
    ax.set_title(f"Step {i}/{T}")
    return (sc,)


# trajectory holds T + 1 states (initial noise plus one per reverse step), so
# animate all of them; frames=T would drop the final, fully denoised state.
# Blitting is off because the title changes on every frame.
ani = animation.FuncAnimation(
    fig, animate, frames=len(trajectory), interval=60, blit=False
)
plt.close()
# Display explicitly so the animation still renders even though another
# expression follows in the same cell.
display(HTML(ani.to_jshtml()))
len(trajectory)