A Minimal Diffusion Model from Scratch#
Outline#
Setup & Imports
Generate Simple Data (e.g. 2D Gaussian blobs or MNIST)
Forward Process (Adding Noise)
Reverse Process (Learning to Denoise)
Sampling from Noise
Visualization & Insights
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.animation as animation
Generate Simple Data#
Let’s use 2D Gaussians so we can visualize easily:
# Generate 2D Gaussian data (2 clusters)
def generate_data(n=1000):
    mean1 = [2, 2]
    mean2 = [-2, -2]
    cov = [[0.1, 0], [0, 0.1]]
    data1 = np.random.multivariate_normal(mean1, cov, n // 2)
    data2 = np.random.multivariate_normal(mean2, cov, n // 2)
    data = np.vstack([data1, data2])
    np.random.shuffle(data)
    return torch.tensor(data, dtype=torch.float32)
data = generate_data()
plt.scatter(data[:, 0], data[:, 1], alpha=0.5)
plt.title("Toy 2D Data")
plt.show()
Forward Process (Diffusion)#
We define a noise schedule and use it to produce progressively noisier versions of the data.
T = 100 # number of diffusion steps
beta = torch.linspace(1e-4, 0.02, T)
alpha = 1. - beta
alpha_hat = torch.cumprod(alpha, dim=0)
def q_sample(x0, t):
    # Sample x_t ~ q(x_t | x_0) in closed form using the cumulative schedule
    noise = torch.randn_like(x0)
    sqrt_alpha_hat = alpha_hat[t].sqrt().unsqueeze(1)
    sqrt_one_minus = (1 - alpha_hat[t]).sqrt().unsqueeze(1)
    return sqrt_alpha_hat * x0 + sqrt_one_minus * noise, noise
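As a quick sanity check, we can noise the data at a few timesteps and watch the signal scale shrink while the noise scale grows (the printed values depend on the schedule above):
# Quick check: how much signal remains, and how spread out x_t is, at a few timesteps
for t_check in [0, T // 2, T - 1]:
    x_t, _ = q_sample(data, torch.tensor([t_check] * data.shape[0]))
    print(f"t={t_check:3d}  "
          f"sqrt(alpha_hat)={alpha_hat[t_check].sqrt().item():.3f}  "
          f"sqrt(1-alpha_hat)={(1 - alpha_hat[t_check]).sqrt().item():.3f}  "
          f"std(x_t)={x_t.std().item():.3f}")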
fig, ax = plt.subplots()
sc = ax.scatter(data[:, 0], data[:, 1], alpha=0.5)
ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
frames = []
for t in range(0, T, 5):
    xt, _ = q_sample(data, torch.tensor([t] * data.shape[0]))
    frames.append(xt.numpy())

def update(frame):
    sc.set_offsets(frame)
    return sc,
ani = animation.FuncAnimation(fig, update, frames=frames, interval=100)
plt.close()
from IPython.display import HTML
HTML(ani.to_jshtml())
Reverse Process (Train a Denoiser)#
A simple MLP that takes the noisy point and its timestep and predicts the added noise:
class DenoiseMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 + 1, 64),  # input: 2D point plus a scalar timestep
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 2),  # output: predicted noise for each coordinate
        )

    def forward(self, x, t):
        t_embed = t.float().unsqueeze(1) / T  # normalize the timestep to [0, 1)
        x_input = torch.cat([x, t_embed], dim=1)
        return self.net(x_input)
model = DenoiseMLP()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
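Before training, a quick shape check with a dummy batch confirms that the network maps (x, t) pairs back to 2D noise predictions:
# Shape check with a dummy batch: 4 points in 2D, 4 random timesteps
x_dummy = torch.randn(4, 2)
t_dummy = torch.randint(0, T, (4,))
print(model(x_dummy, t_dummy).shape)  # expected: torch.Size([4, 2])
print(sum(p.numel() for p in model.parameters()), "parameters")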
Training Loop#
dataset = TensorDataset(data)
loader = DataLoader(dataset, batch_size=128, shuffle=True)
for epoch in range(100):
    for (batch,) in loader:  # TensorDataset yields 1-element tuples
        t = torch.randint(0, T, (batch.shape[0],))  # random timestep per example
        x_t, noise = q_sample(batch, t)
        pred_noise = model(x_t, t)
        loss = F.mse_loss(pred_noise, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: loss = {loss.item():.4f}")
Sample#
The function p_sample implements one reverse diffusion step:
Given a noisy input \(x_t\) at timestep \(t\), the model predicts the noise \(\epsilon_\theta(x_t, t)\); from that prediction we can estimate the clean sample \(x_0\) and the mean of the reverse transition.
From the diffusion process:
\[
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon \tag{0}
\]
Solving for \(x_0\):
\[
\hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \left(x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_\theta(x_t, t)\right) \tag{1}
\]
We can then define the mean of the reverse distribution \(p(x_{t-1} \mid x_t)\) by plugging \(\hat{x}_0\) into the mean of the forward posterior \(q(x_{t-1} \mid x_t, x_0)\) and simplifying:
\[
\mu_t(x_t, \epsilon_\theta) = \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \cdot \epsilon_\theta(x_t, t)\right) \tag{2}
\]
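For reference, the posterior mean being substituted into is the standard expression
\[
\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1 - \bar{\alpha}_t}\, x_0 + \frac{\sqrt{\alpha_t}\, (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t
\]
and replacing \(x_0\) with \(\hat{x}_0\) from (1), then simplifying with \(\bar{\alpha}_t = \alpha_t \bar{\alpha}_{t-1}\) and \(\beta_t = 1 - \alpha_t\), gives exactly (2).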
And sample from:
\[
x_{t-1} \sim \mathcal{N}(\mu_t, \sigma_t^2 I), \quad \text{where } \sigma_t^2 = \beta_t \tag{3}
\]
This implements the denoising step for one timestep.
@torch.no_grad()
def p_sample(model, x, t):
    beta_t = beta[t]
    sqrt_one_minus = (1 - alpha_hat[t]).sqrt()
    sqrt_recip_alpha = (1. / alpha[t]).sqrt()
    pred_noise = model(x, torch.tensor([t] * x.shape[0]))
    x0_pred = (x - sqrt_one_minus * pred_noise) / alpha_hat[t].sqrt()  # Eqn 1 above (not used below; shown for illustration)
    mean = sqrt_recip_alpha * (x - beta_t / sqrt_one_minus * pred_noise)  # Eqn 2 above
    if t > 0:
        z = torch.randn_like(x)  # inject fresh noise except at the final step
    else:
        z = 0
    return mean + beta_t.sqrt() * z  # Eqn 3 above, with sigma_t^2 = beta_t
@torch.no_grad()
def sample(model, n_samples=1000, return_trajectory=False):
    x = torch.randn(n_samples, 2)  # start from pure Gaussian noise
    trajectory = [x.clone()]
    for t in reversed(range(T)):  # denoise from t = T-1 down to t = 0
        x = p_sample(model, x, t)
        if return_trajectory:
            trajectory.append(x.clone())
    if return_trajectory:
        return x, trajectory
    return x
samples, trajectory = sample(model, return_trajectory=True)
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
plt.title("Generated Samples")
plt.axis("equal")
plt.show()
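Beyond the scatter plot, a rough quantitative check is to count how the generated points split between the two true cluster centers used in generate_data:
# Rough check: fraction of generated samples closer to each true cluster center
centers = torch.tensor([[2., 2.], [-2., -2.]])
dists = torch.cdist(samples, centers)  # (n_samples, 2) distances to each center
closest = dists.argmin(dim=1)
print("fraction near (+2, +2):", (closest == 0).float().mean().item())
print("fraction near (-2, -2):", (closest == 1).float().mean().item())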
Visualizing the Reverse Process#
fig, ax = plt.subplots()
sc = ax.scatter([], [], alpha=0.5)
ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
ax.set_title("Reverse Process")
ax.set_aspect("equal")
def animate(i):
    x = trajectory[i]  # trajectory[0] is pure noise; later entries are progressively denoised
    sc.set_offsets(x.numpy())
    ax.set_title(f"Step {i}/{T}")
    return sc,

ani = animation.FuncAnimation(fig, animate, frames=len(trajectory), interval=60, blit=True)
plt.close()
HTML(ani.to_jshtml())
len(trajectory)  # T + 1 entries: the initial noise plus one per reverse step