Diffusion Fundamentals: DDPM, Score-Based Models, NCSN, and DDIM
This notebook introduces diffusion models from the ground up for readers who want both intuition and working implementations.
Aims¶
build physical intuition from energy landscapes, stochastic motion, and Langevin dynamics,
connect that intuition to the forward-noising / reverse-denoising recipe,
compare three concrete model families on the same toy dataset: NCSN / score-based models, DDPM, and DDIM.
Learning outcomes¶
By the end you should be able to:
explain what the score is and why Langevin dynamics matters,
describe what NCSN, DDPM, and DDIM learn and how they sample,
say clearly what DDPM and DDIM share and what changes between them,
recognize how the same logic carries over to crystals and materials.
If you want the applied crystal notebook next, open crystal-diffusion-from-scratch.ipynb.
Runtime expectations¶
The warm-up sections and plots run quickly on CPU.
The toy diffusion-model training sections are still real optimization loops, so expect roughly 10--20 minutes on CPU or faster on GPU for a full top-to-bottom first pass.
If you only want the conceptual comparison, you can stop after the plots and short model runs in Part 3.
Key references for this notebook¶
This notebook is pedagogical, but the model families and notation come from a small set of canonical papers:
Ho, Jain, and Abbeel, Denoising Diffusion Probabilistic Models (DDPM): https://arxiv.org/abs/2006.11239
Song and Ermon, Generative Modeling by Estimating Gradients of the Data Distribution (NCSN): https://arxiv.org/abs/1907.05600
Song et al., Score-Based Generative Modeling through Stochastic Differential Equations: https://arxiv.org/abs/2011.13456
Song, Meng, and Ermon, Denoising Diffusion Implicit Models (DDIM): https://arxiv.org/abs/2010.02502
The code here is intentionally small and CPU-friendly, so it illustrates the core ideas rather than reproducing the full training setups from the papers.
Table of Contents¶
If you already know the ODE/SDE warm-up, you can jump straight to Part 3 for the concrete model comparison.
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
import torch
import torch.distributions as D
from torch.func import vmap, jacrev
from tqdm import tqdm
import seaborn as sns
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Part 1: Energy landscapes, drift, and stochastic motion¶
Let us make precise the central objects of study: ordinary differential equations (ODEs) and stochastic differential equations (SDEs).
In this notebook, you can interpret the time-dependent vector field $b(x, t)$ as an effective velocity field in configuration space. In chemistry language, this is analogous to saying:
the system is currently at configuration $x$,
at time $t$,
and the local geometry of the landscape tells us which way the configuration wants to move.
An ODE is then
$$\frac{\mathrm{d}X_t}{\mathrm{d}t} = b(X_t, t), \qquad X_0 = x_0.$$
This is the deterministic case: once the initial condition is fixed, the trajectory is fixed. Conceptually, this is like continuous-time relaxation or steepest-descent-style motion on an energy landscape.
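An SDE is
$$\mathrm{d}X_t = b(X_t, t)\,\mathrm{d}t + \sigma(X_t, t)\,\mathrm{d}W_t,$$
which adds a stochastic term driven by Brownian motion $W_t$. Here:
$b(x, t)$ is the drift,
$\sigma(x, t)$ is the diffusion coefficient,
and the stochastic term $\sigma\,\mathrm{d}W_t$ plays the role of thermal agitation / random kicks.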
So the mental model is:
ODE = deterministic relaxation
SDE = deterministic relaxation + thermal/random fluctuations
That same language will become very useful later, because diffusion models can be viewed as learning how to reverse a noising process in exactly this SDE/ODE framework.
Translation guide: math language ↔ chemistry / materials language¶
| Math / generative modeling | Chemistry / materials intuition |
|---|---|
| state $x$ | configuration, coordinate, structure descriptor |
| drift $b(x, t)$ | deterministic velocity / force-induced motion |
| diffusion coefficient $\sigma(x, t)$ | noise strength, temperature-like stochasticity |
| Brownian motion $W_t$ | free diffusion |
| OU process | diffusion in a harmonic basin |
| $\nabla \log p(x)$ (score) | force-like field toward high-probability / low-energy regions |
| Langevin dynamics | noisy force-driven sampling |
| stationary distribution $p$ | equilibrium distribution |
You do not need to think of these as separate worlds. The same equations appear in both.
class ODE(ABC):
    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the drift coefficient of the ODE.
        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
        Returns:
            - drift_coefficient: shape (batch_size, dim)
        """
        pass
class SDE(ABC):
    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the drift coefficient of the SDE.
        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
        Returns:
            - drift_coefficient: shape (batch_size, dim)
        """
        pass
    @abstractmethod
    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the diffusion coefficient of the SDE.
        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
        Returns:
            - diffusion_coefficient: shape (batch_size, dim)
        """
        pass
Note. You can indeed view an ODE as the zero-noise limit of an SDE. We will still treat them separately here, because it keeps the simulation schemes and the physical intuition cleaner.
Part 1A: Numerical integration as repeated local updates¶
In practice, we almost never solve these equations analytically. Instead, we integrate them numerically, exactly as you would do for equations of motion in simulation.
If we think of $X_t$ as the current configuration, then an ODE says:
update the configuration according to the local drift field.
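For the ODE
$$\frac{\mathrm{d}X_t}{\mathrm{d}t} = b(X_t, t),$$
the explicit Euler step is
$$X_{t+h} = X_t + b(X_t, t)\,h,$$
where $h$ is the step size.
For the SDE
$$\mathrm{d}X_t = b(X_t, t)\,\mathrm{d}t + \sigma(X_t, t)\,\mathrm{d}W_t,$$
the Euler--Maruyama step is
$$X_{t+h} = X_t + b(X_t, t)\,h + \sigma(X_t, t)\sqrt{h}\,\xi,$$
where $\xi \sim \mathcal{N}(0, I)$.
This is the stochastic analogue of a basic explicit integrator: deterministic motion from the drift, plus a random increment whose scale grows like $\sqrt{h}$. If you come from Brownian dynamics or overdamped Langevin dynamics, this form should look very familiar.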
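Why $\sqrt{h}$ and not $h$? The following is a minimal NumPy sketch, independent of the notebook's classes. It simulates $\mathrm{d}X = \sigma\,\mathrm{d}W$ from $X_0 = 0$ and compares the correct $\sqrt{h}$ kicks against a deliberately wrong version that scales the kicks by $h$. The function name brownian_em_variance and its noise_scale switch are illustrative only. With $\sqrt{h}$ kicks, the terminal variance stays at $\sigma^2 T$ however fine the grid is. With $h$ kicks, the variance becomes $\sigma^2 T h$, so the noise vanishes as the grid is refined.

```python
import numpy as np

def brownian_em_variance(sigma, total_time, n_steps, n_paths, noise_scale="sqrt", seed=0):
    """Euler--Maruyama for dX = sigma dW started at 0; returns the sample variance of X_T.

    noise_scale="sqrt" multiplies each Gaussian kick by sqrt(h) (correct);
    noise_scale="linear" multiplies it by h (deliberately wrong, for comparison).
    """
    rng = np.random.default_rng(seed)
    h = total_time / n_steps
    kick_scale = np.sqrt(h) if noise_scale == "sqrt" else h
    x = np.zeros(n_paths)
    for _ in range(n_steps):
        x = x + sigma * kick_scale * rng.standard_normal(n_paths)
    return x.var()

sigma, total_time, n_paths = 1.0, 5.0, 20_000
results = {}
for n_steps in (50, 500):
    results[n_steps] = (
        brownian_em_variance(sigma, total_time, n_steps, n_paths, noise_scale="sqrt"),
        brownian_em_variance(sigma, total_time, n_steps, n_paths, noise_scale="linear"),
    )
    print(f"n_steps={n_steps:4d}  sqrt(h) kicks: {results[n_steps][0]:.3f}  "
          f"h kicks: {results[n_steps][1]:.4f}  exact sigma^2 T: {sigma**2 * total_time:.3f}")
```

Each $\sqrt{h}$ kick contributes variance $\sigma^2 h$, and $T/h$ of them add up to $\sigma^2 T$. This is the property the Euler--Maruyama scheme is built to preserve.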
class Simulator(ABC):
    @abstractmethod
    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):
        """
        Takes one simulation step
        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
            - h: step size, shape ()
        Returns:
            - nxt: state at time t + h
        """
        pass
    @torch.no_grad()
    def simulate(self, x: torch.Tensor, ts: torch.Tensor):
        """
        Simulates using the discretization given by ts
        Args:
            - x: initial state at time ts[0], shape (batch_size, dim)
            - ts: timesteps, shape (num_timesteps,)
        Returns:
            - x_final: final state at time ts[-1], shape (batch_size, dim)
        """
        for t_idx in range(len(ts) - 1):
            t = ts[t_idx]
            h = ts[t_idx + 1] - ts[t_idx]
            x = self.step(x, t, h)
        return x
    @torch.no_grad()
    def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor):
        """
        Simulates using the discretization given by ts and records every state
        Args:
            - x: initial state at time ts[0], shape (batch_size, dim)
            - ts: timesteps, shape (num_timesteps,)
        Returns:
            - xs: trajectory of xts over ts, shape (batch_size, num_timesteps, dim)
        """
        xs = [x.clone()]
        for t_idx in tqdm(range(len(ts) - 1)):
            t = ts[t_idx]
            h = ts[t_idx + 1] - ts[t_idx]
            x = self.step(x, t, h)
            xs.append(x.clone())
        return torch.stack(xs, dim=1)
Question 1.1: Implement explicit integrators for deterministic and stochastic dynamics¶
Try first: before running the next two code cells, write down the Euler and Euler-Maruyama updates in a scratch cell or on paper.
Connect each term to the physical picture:
deterministic displacement from the drift,
a Gaussian kick for the stochastic case,
and the $\sqrt{h}$ scaling, because a Wiener increment over a step of length $h$ has variance $h$.
The worked solution is already present in the code cells below so the notebook still runs cleanly.
Worked answer
For an ODE, the explicit Euler update is
$$X_{t+h} = X_t + b(X_t, t)\,h.$$
For an SDE, the Euler--Maruyama update is
$$X_{t+h} = X_t + b(X_t, t)\,h + \sigma(X_t, t)\sqrt{h}\,\xi, \qquad \xi \sim \mathcal{N}(0, I).$$
That is exactly what the next two code cells implement.
class EulerSimulator(Simulator):
    def __init__(self, ode: ODE):
        self.ode = ode
    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):
        # Explicit Euler: X_{t+h} = X_t + b(X_t, t) h
        return xt + self.ode.drift_coefficient(xt, t) * h
class EulerMaruyamaSimulator(Simulator):
    def __init__(self, sde: SDE):
        self.sde = sde
    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):
        # Euler--Maruyama: X_{t+h} = X_t + b(X_t, t) h + sigma(X_t, t) sqrt(h) xi, with xi ~ N(0, I)
        return xt + self.sde.drift_coefficient(xt, t) * h + self.sde.diffusion_coefficient(xt, t) * torch.sqrt(h) * torch.randn_like(xt)
Note. When the diffusion coefficient is zero, Euler--Maruyama reduces to ordinary Euler. In other words, the stochastic simulator collapses back to a purely deterministic integrator.
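The zero-noise statement in the note can be checked directly. The following is a small, self-contained NumPy sketch that mirrors the two simulators above without importing them. It uses a hypothetical restoring drift $b(x) = -0.5\,x$ and a zero diffusion coefficient. The Euler--Maruyama trajectory should then match the Euler trajectory bit for bit, because the Gaussian kick is multiplied by exactly zero.

```python
import numpy as np

def drift(x, t, theta=0.5):
    # Hypothetical OU-style restoring drift b(x, t) = -theta * x
    return -theta * x

def zero_diffusion(x, t):
    # sigma(x, t) = 0: no thermal kicks at all
    return np.zeros_like(x)

def euler_step(x, t, h):
    return x + drift(x, t) * h

def make_euler_maruyama_step(rng):
    def step(x, t, h):
        return x + drift(x, t) * h + zero_diffusion(x, t) * np.sqrt(h) * rng.standard_normal(x.shape)
    return step

def integrate(step, x0, ts):
    # Apply `step` on the grid ts, exactly like Simulator.simulate
    x = x0.copy()
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = step(x, t, t_next - t)
    return x

x0 = np.linspace(-2.0, 2.0, 5)
ts = np.linspace(0.0, 1.0, 101)
x_euler = integrate(euler_step, x0, ts)
x_em = integrate(make_euler_maruyama_step(np.random.default_rng(0)), x0, ts)
print(np.array_equal(x_euler, x_em))                   # True
print(np.max(np.abs(x_euler - x0 * np.exp(-0.5))))     # small Euler error vs the exact solution x0 * exp(-0.5 t)
```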
Part 1B: Brownian motion and Ornstein--Uhlenbeck intuition¶
We now build intuition for two especially important SDEs.
From a chemistry/materials viewpoint, these are good toy models for motion on simple landscapes:
Brownian motion = diffusion on a flat landscape with no restoring force,
Ornstein--Uhlenbeck (OU) = diffusion in a harmonic basin with a linear restoring force.
These two examples already contain most of the qualitative ingredients we care about later: random exploration, drift toward preferred regions, and approach to an equilibrium distribution.
Question 2.1: Brownian motion as free diffusion¶
Brownian motion is the case with no drift and constant diffusion:
$$\mathrm{d}X_t = \sigma\,\mathrm{d}W_t.$$
You can think of this as motion on a perfectly flat potential energy surface. There is no force pulling the particle anywhere in particular; it only wanders because of random kicks.
Try first: before running the Brownian-motion code, predict the trajectories qualitatively.
How should the paths look when sigma is very large? What about when sigma is close to zero?
Suggested answer
Large sigma gives rough, rapidly spreading trajectories because every timestep receives a stronger random kick. When sigma is close to zero, the paths barely move and stay near the initial condition.
Try first: pause and decide what the drift and diffusion should be for Brownian motion.
In physical language there is no restoring force, so the drift is zero, while the noise strength is spatially uniform.
The worked implementation is already in the next code cell so the notebook remains runnable.
Worked answer
Brownian motion satisfies
$$\mathrm{d}X_t = 0\,\mathrm{d}t + \sigma\,\mathrm{d}W_t,$$
so the implementation returns torch.zeros_like(xt) for the drift and self.sigma * torch.ones_like(xt) for the diffusion.
class BrownianMotion(SDE):
    def __init__(self, sigma: float):
        self.sigma = sigma
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the drift coefficient of the SDE (zero: a flat landscape, no restoring force).
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - drift: shape (bs, dim)
        """
        return torch.zeros_like(xt)
    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the diffusion coefficient of the SDE (constant sigma everywhere).
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - diffusion: shape (bs, dim)
        """
        return self.sigma * torch.ones_like(xt)
Now let us plot trajectories. As you inspect the figures, try to think simultaneously in two ways:
as individual paths in time, and
as an evolving ensemble of configurations.
That ensemble viewpoint is the one that later becomes central for generative modeling.
def plot_trajectories_1d(x0: torch.Tensor, simulator: Simulator, timesteps: torch.Tensor, ax: Optional[Axes] = None, show_hist: bool = False, decouple_hist_axis: bool = False):
    """
    Graphs the trajectories of a one-dimensional SDE with given initial values (x0) and simulation timesteps (timesteps).
    Args:
        - x0: state at time timesteps[0], shape (num_trajectories, 1)
        - simulator: Simulator object used to simulate
        - timesteps: timesteps to simulate along, shape (num_timesteps,)
        - ax: pyplot Axes object to plot on
        - show_hist: if True, add a histogram of the terminal states to the right of the trajectories
        - decouple_hist_axis: if True, do not share y-axis between trajectories and histogram
    """
    if ax is None:
        ax = plt.gca()
    trajectories = simulator.simulate_with_trajectory(x0, timesteps)  # (num_trajectories, num_timesteps, 1)
    line_color = sns.color_palette("crest", 1)[0]
    hist_color = sns.color_palette("flare", 1)[0]
    label_size = 12
    tick_size = 10
    timesteps_cpu = timesteps.detach().cpu().numpy()
    for trajectory_idx in range(trajectories.shape[0]):
        trajectory = trajectories[trajectory_idx, :, 0].detach().cpu().numpy()  # (num_timesteps,)
        sns.lineplot(
            x=timesteps_cpu,
            y=trajectory,
            ax=ax,
            color=line_color,
            alpha=0.45,
            linewidth=1.1,
            legend=False,
        )
    ax.set_xlabel(r"time ($t$)", fontsize=label_size)
    ax.set_ylabel(r"$X_t$", fontsize=label_size)
    ax.tick_params(axis='both', labelsize=tick_size)
    ax.grid(alpha=0.2, linewidth=0.6)
    if show_hist:
        # Histogram of the terminal states X_T in an extra axis to the right
        terminal_points = trajectories[:, -1, 0].detach().cpu().numpy()
        data_range = float(terminal_points.max() - terminal_points.min()) if terminal_points.size else 1.0
        binwidth = max(data_range / 25.0, 0.05)
        from mpl_toolkits.axes_grid1 import make_axes_locatable
        divider = make_axes_locatable(ax)
        sharey = None if decouple_hist_axis else ax
        hist_ax = divider.append_axes("right", size="22%", pad=0.45, sharey=sharey)
        sns.histplot(
            y=terminal_points,
            ax=hist_ax,
            binwidth=binwidth,
            color=hist_color,
            alpha=0.7,
            edgecolor="white",
            linewidth=0.5,
        )
        hist_ax.set_xlabel("count", fontsize=label_size)
        hist_ax.set_ylabel("")
        hist_ax.tick_params(axis='both', labelsize=tick_size)
        if decouple_hist_axis:
            hist_ax.tick_params(axis='y', left=True, labelleft=True)
        else:
            hist_ax.tick_params(axis='y', left=False, labelleft=False)
        hist_ax.grid(axis='x', alpha=0.2, linewidth=0.6)
    fig = ax.figure
    if fig is not None:
        # Re-center an existing title over the trajectory axis (and the histogram axis, if any)
        title = ax.get_title()
        if title:
            title_size = ax.title.get_fontsize()
            ax.set_title("")
            axes = [ax]
            if show_hist:
                axes.append(hist_ax)
            fig.canvas.draw()
            bboxes = [a.get_position() for a in axes]
            left = min(b.x0 for b in bboxes)
            right = max(b.x1 for b in bboxes)
            top = max(b.y1 for b in bboxes)
            x_center = 0.5 * (left + right)
            y = top + 0.005
            fig.text(
                x_center,
                y,
                title,
                ha="center",
                va="bottom",
                fontsize=title_size,
            )
sigma = 1.0
n_traj = 500
brownian_motion = BrownianMotion(sigma)
simulator = EulerMaruyamaSimulator(sde=brownian_motion)
x0 = torch.zeros(n_traj, 1).to(device)  # Initial values - let's start at zero
ts = torch.linspace(0.0, 5.0, 500).to(device)  # simulation timesteps
plt.figure(figsize=(9, 6))
ax = plt.gca()
ax.set_title(r'Trajectories of Brownian Motion with $\sigma=$' + str(sigma), fontsize=18)
plot_trajectories_1d(x0, simulator, ts, ax, show_hist=True)  # also sets the axis labels
plt.show()

Your job: What happens when you vary the value of sigma?
Your answer:
Answer
Increasing sigma increases the roughness of each trajectory and broadens the distribution of terminal positions. In other words, stronger noise explores configuration space more aggressively.
Question 2.2: Ornstein--Uhlenbeck as diffusion in a harmonic well¶
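The OU process is
$$\mathrm{d}X_t = -\theta X_t\,\mathrm{d}t + \sigma\,\mathrm{d}W_t, \qquad \theta > 0.$$
This is one of the cleanest SDEs to interpret physically. The drift term is a linear restoring force that pulls the system back toward the origin, so you can view it as motion in the quadratic potential
$$U(x) = \tfrac{1}{2}\theta x^2, \qquad -\nabla U(x) = -\theta x.$$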
Brownian motion wandered on a flat landscape; the OU process wanders in a basin.
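The OU process is one of the few SDEs with a closed-form transition law: started at $x_0$, $X_t$ is Gaussian with mean $x_0 e^{-\theta t}$ and variance $\frac{\sigma^2}{2\theta}\left(1 - e^{-2\theta t}\right)$. The following is a short NumPy sketch, not part of the original notebook, with illustrative parameters. It runs an Euler--Maruyama ensemble and compares its moments with these formulas.

```python
import numpy as np

def ou_exact_moments(x0, theta, sigma, t):
    # Closed-form mean and variance of X_t | X_0 = x0 for dX = -theta X dt + sigma dW
    mean = x0 * np.exp(-theta * t)
    var = sigma**2 / (2 * theta) * (1 - np.exp(-2 * theta * t))
    return mean, var

def ou_euler_maruyama(x0, theta, sigma, t_end, n_steps, n_paths, seed=0):
    rng = np.random.default_rng(seed)
    h = t_end / n_steps
    x = np.full(n_paths, x0, dtype=float)
    for _ in range(n_steps):
        x = x - theta * x * h + sigma * np.sqrt(h) * rng.standard_normal(n_paths)
    return x

x0, theta, sigma, t_end = 5.0, 0.25, 0.5, 4.0
samples = ou_euler_maruyama(x0, theta, sigma, t_end, n_steps=400, n_paths=20_000)
mean_exact, var_exact = ou_exact_moments(x0, theta, sigma, t_end)
print(f"mean: simulated {samples.mean():.3f}, exact {mean_exact:.3f}")
print(f"var:  simulated {samples.var():.3f}, exact {var_exact:.3f}")
```

The mean shows the restoring force: it relaxes toward the bottom of the well at rate $\theta$. The variance shows the competition between that force and the thermal kicks.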
Try first: predict the qualitative behavior before you run the OU-process code.
What should happen when theta is very small? What about when theta is very large?
Suggested answer
When theta is very small, the restoring force is weak, so the process behaves more like wandering Brownian motion. When theta is large, the harmonic well is stiff, so trajectories are pulled back toward the origin much more aggressively.
Try first: write down the OU drift and diffusion before you inspect the next code cell.
Interpretation:
theta controls how stiff the harmonic well is,
sigma controls how strongly thermal noise kicks the system around.
The worked implementation is already given below so execution stays intact.
Worked answer
The Ornstein-Uhlenbeck process is
$$\mathrm{d}X_t = -\theta X_t\,\mathrm{d}t + \sigma\,\mathrm{d}W_t,$$
so the implementation returns -self.theta * xt for the drift and self.sigma * torch.ones_like(xt) for the diffusion.
class OUProcess(SDE):
    def __init__(self, theta: float, sigma: float):
        self.theta = theta
        self.sigma = sigma
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the drift coefficient of the SDE (linear restoring force toward 0).
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - drift: shape (bs, dim)
        """
        return - self.theta * xt
    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the diffusion coefficient of the SDE (constant sigma).
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - diffusion: shape (bs, dim)
        """
        return self.sigma * torch.ones_like(xt)
# Try comparing multiple choices side-by-side
thetas_and_sigmas = [
    (0.25, 0.0),
    (0.25, 0.5),
    (0.25, 2.0),
]
simulation_time = 10.0
num_plots = len(thetas_and_sigmas)
fig, axes = plt.subplots(2, num_plots, figsize=(10.5 * num_plots, 15))
# Top row: dynamics of a few trajectories
n_traj = 10
for idx, (theta, sigma) in enumerate(thetas_and_sigmas):
    ou_process = OUProcess(theta, sigma)
    simulator = EulerMaruyamaSimulator(sde=ou_process)
    x0 = torch.linspace(-10.0, 10.0, n_traj).view(-1, 1).to(device)  # initial values spread evenly over [-10, 10]
    ts = torch.linspace(0.0, simulation_time, 1000).to(device)  # simulation timesteps
    ax = axes[0, idx]
    ax.set_title(f'Trajectories of OU Process with $\\sigma = ${sigma}, $\\theta = ${theta}', fontsize=15)
    plot_trajectories_1d(x0, simulator, ts, ax, show_hist=False)
# Bottom row: distribution of many trajectories
n_traj = 500
for idx, (theta, sigma) in enumerate(thetas_and_sigmas):
    ou_process = OUProcess(theta, sigma)
    simulator = EulerMaruyamaSimulator(sde=ou_process)
    x0 = torch.linspace(-10.0, 10.0, n_traj).view(-1, 1).to(device)  # initial values spread evenly over [-10, 10]
    ts = torch.linspace(0.0, simulation_time, 1000).to(device)  # simulation timesteps
    ax = axes[1, idx]
    ax.set_title(f'Trajectories of OU Process with $\\sigma = ${sigma}, $\\theta = ${theta}', fontsize=15)
    plot_trajectories_1d(x0, simulator, ts, ax, show_hist=True, decouple_hist_axis=True)
plt.show()

Your job: What do you notice about the long-time behavior? Are the trajectories converging to a single point, or to a distribution?
Give two qualitative sentences of the form:
“When ($\theta$ or $\sigma$) goes (up or down), we see ...”
Hint. Keep an eye on the ratio
$$D = \frac{\sigma^2}{2\theta}.$$
From a statistical mechanics perspective, this ratio plays the role of the equilibrium variance of the stationary Gaussian.
Your answer:
Answer
When $\theta$ goes up (resp. down), the restoring force becomes stronger (resp. weaker), so the stationary distribution narrows (resp. broadens). When $\sigma$ goes up (resp. down), the thermal kicks become stronger (resp. weaker), so the stationary distribution broadens (resp. narrows).
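One subtlety, which the original notebook does not discuss: $D = \sigma^2/(2\theta)$ is the stationary variance of the continuous-time process, whereas the simulation above runs the discrete recursion $X_{k+1} = (1-\theta h)X_k + \sigma\sqrt{h}\,\xi_k$. Setting $V = (1-\theta h)^2 V + \sigma^2 h$ gives its stationary variance $V = \sigma^2 / \big(\theta(2 - \theta h)\big)$, which tends to $D$ as $h \to 0$. Below is a minimal NumPy sketch checking this, with illustrative parameter values.

```python
import numpy as np

def em_stationary_variance(theta, sigma, h):
    # Fixed point of V = (1 - theta*h)^2 V + sigma^2 h; valid when |1 - theta*h| < 1
    return sigma**2 / (theta * (2.0 - theta * h))

theta, sigma = 0.25, 2.0
d_continuous = sigma**2 / (2 * theta)   # continuous-time stationary variance D = 8.0
for h in (1.0, 0.1, 0.01):
    print(f"h={h:<5} Euler--Maruyama stationary variance {em_stationary_variance(theta, sigma, h):.4f}   D = {d_continuous:.4f}")

# Monte Carlo check: many independent EM chains, run long enough to forget the start
rng = np.random.default_rng(0)
h_mc, n_paths, n_steps = 0.1, 20_000, 1_000
x = np.zeros(n_paths)
for _ in range(n_steps):
    x = (1 - theta * h_mc) * x + sigma * np.sqrt(h_mc) * rng.standard_normal(n_paths)
mc_var = x.var()
print(f"Monte Carlo variance at h={h_mc}: {mc_var:.3f}")
```

The discretization slightly inflates the equilibrium width. This is a useful reminder that a finite step size changes the sampled distribution, not just the accuracy of individual paths.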
# Let's compare various OU processes!
sigmas = [1.0, 2.0, 10.0]
ds = [0.25, 1.0, 4.0]  # D = sigma**2 / (2 * theta), the stationary variance
simulation_time = 15.0
n_traj = 500
fig, axes = plt.subplots(len(ds), len(sigmas), figsize=(8 * len(sigmas), 8 * len(ds)))
axes = axes.reshape((len(ds), len(sigmas)))
for d_idx, d in enumerate(ds):
    for s_idx, sigma in enumerate(sigmas):
        theta = sigma**2 / 2 / d
        ou_process = OUProcess(theta, sigma)
        simulator = EulerMaruyamaSimulator(sde=ou_process)
        x0 = torch.linspace(-20.0, 20.0, n_traj).view(-1, 1).to(device)
        # Rescale time by sigma**2 so that theta * t_final = simulation_time / (2 * D) is the same within a row
        time_scale = sigma**2
        ts = torch.linspace(0.0, simulation_time / time_scale, 1000).to(device)  # simulation timesteps
        ax = axes[d_idx, s_idx]
        ax.set_title(f'OU Trajectories with Sigma={sigma}, Theta={theta}, D={d}')
        plot_trajectories_1d(x0=x0, simulator=simulator, timesteps=ts, ax=ax, show_hist=True, decouple_hist_axis=True)
        ax.set_xlabel(r'$t$')
        ax.set_ylabel(r'$X_t$')
plt.show()

Your job: What conclusion can we draw from the figure above? One qualitative sentence is fine. We will revisit this in Section 3.2.
Your answer:
Answer
For fixed $D=\frac{\sigma^2}{2\theta}$, changing $\sigma$ mainly changes the mixing speed toward equilibrium.
For fixed $\sigma$, increasing $D$ broadens the equilibrium distribution, i.e. the basin is explored more widely at stationarity.
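The grid above measured time in units of $1/\sigma^2$. A closed-form check shows why each row looks alike. With $\theta = \sigma^2/(2D)$ and rescaled time $\tau = \sigma^2 t$, the OU mean is $x_0 e^{-\tau/(2D)}$ and the variance is $D(1 - e^{-\tau/D})$, and neither depends on $\sigma$. In physical time, $\sigma$ only sets the relaxation time $1/\theta = 2D/\sigma^2$. This is a NumPy sketch with illustrative values, not code from the notebook.

```python
import numpy as np

def ou_moments_in_rescaled_time(x0, sigma, d, tau):
    """Exact OU mean/variance at physical time t = tau / sigma**2, with theta = sigma**2 / (2 d)."""
    theta = sigma**2 / (2.0 * d)
    t = tau / sigma**2
    mean = x0 * np.exp(-theta * t)
    var = sigma**2 / (2.0 * theta) * (1.0 - np.exp(-2.0 * theta * t))
    return mean, var

d_ratio, x0, tau = 1.0, 10.0, 3.0
moments = {s: ou_moments_in_rescaled_time(x0, s, d_ratio, tau) for s in (1.0, 2.0, 10.0)}
relaxation_time = {s: 2.0 * d_ratio / s**2 for s in moments}   # 1 / theta, in physical time
for s in moments:
    print(f"sigma={s:>4}: mean={moments[s][0]:.4f}, var={moments[s][1]:.4f}, 1/theta={relaxation_time[s]:.4f}")
```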
Temperature intuition: particles in a double-well landscape¶
Before we jump to learned diffusion models, it helps to watch particles move on a simple energy surface.
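We use an overdamped Langevin update in a double-well potential
$$U(x, y) = \tfrac{1}{4}\left(x^2 - 1.5\right)^2 + 0.7\,y^2,$$
so there are two preferred basins, one on the left and one on the right (at $x = \pm\sqrt{1.5}$, $y = 0$).
The update is
$$X_{k+1} = X_k - \nabla U(X_k)\,\Delta t + \sqrt{2T\,\Delta t}\;\xi_k, \qquad \xi_k \sim \mathcal{N}(0, I),$$
where $T$ plays the role of a temperature-like noise strength.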
What to notice¶
Low temperature means the particle mostly rattles inside one basin.
Higher temperature means barrier-crossing becomes more common.
This is the same drift-plus-noise picture we will later reuse in reverse-time diffusion sampling.
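For small $\Delta t$, this update approximately samples the Boltzmann distribution $p(x, y) \propto \exp(-U(x, y)/T)$. Since $U$ splits into an $x$-part and a harmonic $y$-part $0.7\,y^2$, the $y$-marginal should be Gaussian with variance $T/(2 \cdot 0.7) = T/1.4$. The following is a NumPy re-implementation of the update, a sketch rather than the notebook's torch code, that checks this prediction at one temperature.

```python
import numpy as np

def grad_U(xy):
    # Gradient of U(x, y) = 0.25 (x^2 - 1.5)^2 + 0.7 y^2
    x, y = xy[:, 0], xy[:, 1]
    return np.stack([x * (x**2 - 1.5), 1.4 * y], axis=1)

def run_overdamped_langevin(temperature, n_particles=20_000, n_steps=1_000, dt=0.01, seed=0):
    rng = np.random.default_rng(seed)
    xy = np.zeros((n_particles, 2))
    xy[:, 0] = -np.sqrt(1.5)   # start every particle at the bottom of the left well
    for _ in range(n_steps):
        xy = xy - dt * grad_U(xy) + np.sqrt(2.0 * temperature * dt) * rng.standard_normal(xy.shape)
    return xy

T = 0.12
final = run_overdamped_langevin(T)
var_y = final[:, 1].var()
print(f"Var[y] from simulation: {var_y:.4f}   Boltzmann prediction T / 1.4: {T / 1.4:.4f}")
```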
import math
def double_well_potential(xy: torch.Tensor) -> torch.Tensor:
    # U(x, y) = 0.25 (x^2 - 1.5)^2 + 0.7 y^2
    x = xy[..., 0]
    y = xy[..., 1]
    return 0.25 * (x ** 2 - 1.5) ** 2 + 0.7 * y ** 2
def double_well_grad(xy: torch.Tensor) -> torch.Tensor:
    # Analytic gradient of double_well_potential
    x = xy[..., 0]
    y = xy[..., 1]
    dUx = x * (x ** 2 - 1.5)
    dUy = 1.4 * y
    return torch.stack([dUx, dUy], dim=-1)
@torch.no_grad()
def simulate_double_well(
    temperature: float = 0.10,
    n_particles: int = 24,
    n_steps: int = 180,
    dt: float = 0.03,
    seed: int = 0,
):
    torch.manual_seed(seed)
    # Start all particles near the left basin
    start = torch.randn(n_particles, 2, device=device) * 0.18
    start[:, 0] -= 1.1
    x = start.clone()
    traj = [x.detach().cpu()]
    for _ in range(n_steps):
        # Overdamped Langevin step: x <- x - grad U(x) dt + sqrt(2 T dt) * N(0, I)
        drift = -double_well_grad(x)
        noise = math.sqrt(2.0 * temperature * dt) * torch.randn_like(x)
        x = x + dt * drift + noise
        traj.append(x.detach().cpu())
    return torch.stack(traj, dim=0)  # (n_steps + 1, n_particles, 2)
def plot_double_well_demo(
    temperature: float = 0.10,
    n_particles: int = 24,
    n_steps: int = 180,
    dt: float = 0.03,
    seed: int = 0,
):
    traj = simulate_double_well(
        temperature=temperature,
        n_particles=n_particles,
        n_steps=n_steps,
        dt=dt,
        seed=seed,
    )
    # Evaluate U on a grid for the background contour plot
    grid_x = torch.linspace(-2.2, 2.2, 180)
    grid_y = torch.linspace(-1.8, 1.8, 180)
    xx, yy = torch.meshgrid(grid_x, grid_y, indexing='xy')
    grid = torch.stack([xx, yy], dim=-1)
    energy = double_well_potential(grid).numpy()
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.8))
    axes[0].contourf(grid_x.numpy(), grid_y.numpy(), energy, levels=30, cmap='cividis')
    for idx in range(min(n_particles, traj.shape[1])):
        path = traj[:, idx, :]
        axes[0].plot(path[:, 0], path[:, 1], alpha=0.45, linewidth=1.0)
    axes[0].scatter(traj[0, :, 0], traj[0, :, 1], s=18, color='white', edgecolor='black', linewidth=0.4, label='start')
    axes[0].scatter(traj[-1, :, 0], traj[-1, :, 1], s=22, color='tab:red', alpha=0.8, label='end')
    axes[0].set_title(f'Double-well trajectories at T={temperature:.2f}')
    axes[0].set_xlabel('$x$')
    axes[0].set_ylabel('$y$')
    axes[0].legend(loc='upper right')
    time = np.arange(traj.shape[0]) * dt
    for idx in range(min(8, traj.shape[1])):
        axes[1].plot(time, traj[:, idx, 0], alpha=0.75, linewidth=1.1)
    axes[1].axhline(0.0, color='black', linestyle='--', linewidth=0.8, alpha=0.5)
    axes[1].set_title('The $x$ coordinate shows barrier-hopping directly')
    axes[1].set_xlabel('time')
    axes[1].set_ylabel('$x_t$')
    axes[1].grid(alpha=0.2)
    plt.tight_layout()
    plt.show()
def plot_temperature_sweep(temperatures=(0.03, 0.12, 0.30), n_particles: int = 24, n_steps: int = 180, dt: float = 0.03):
    grid_x = torch.linspace(-2.2, 2.2, 180)
    grid_y = torch.linspace(-1.8, 1.8, 180)
    xx, yy = torch.meshgrid(grid_x, grid_y, indexing='xy')
    grid = torch.stack([xx, yy], dim=-1)
    energy = double_well_potential(grid).numpy()
    fig, axes = plt.subplots(1, len(temperatures), figsize=(4.6 * len(temperatures), 4.2))
    if len(temperatures) == 1:
        axes = [axes]
    for ax, temp in zip(axes, temperatures):
        traj = simulate_double_well(temp, n_particles=n_particles, n_steps=n_steps, dt=dt, seed=0)
        ax.contourf(grid_x.numpy(), grid_y.numpy(), energy, levels=30, cmap='cividis')
        for idx in range(min(n_particles, traj.shape[1])):
            path = traj[:, idx, :]
            ax.plot(path[:, 0], path[:, 1], alpha=0.35, linewidth=0.9)
        ax.scatter(traj[-1, :, 0], traj[-1, :, 1], s=18, color='tab:red', alpha=0.75)
        ax.set_title(f'T = {temp:.2f}')
        ax.set_xticks([])
        ax.set_yticks([])
    plt.suptitle('As temperature rises, trajectories explore both basins more aggressively', y=1.03)
    plt.tight_layout()
    plt.show()
# @title Explore the double-well demo
energy_demo_temperature = 0.10  # @param {type:"slider", min:0.02, max:0.40, step:0.02}
energy_demo_particles = 24  # @param {type:"slider", min:8, max:48, step:4}
energy_demo_steps = 180  # @param {type:"slider", min:80, max:260, step:20}
energy_demo_dt = 0.03  # @param {type:"slider", min:0.01, max:0.08, step:0.01}
energy_demo_seed = 0  # @param {type:"integer"}
plot_double_well_demo(
    temperature=energy_demo_temperature,
    n_particles=energy_demo_particles,
    n_steps=energy_demo_steps,
    dt=energy_demo_dt,
    seed=energy_demo_seed,
)
plot_temperature_sweep(temperatures=(0.03, 0.12, 0.30), n_particles=24, n_steps=180, dt=0.03)
Exercise: temperature and barrier crossing¶
Task. In one or two sentences, explain why higher temperature produces more transitions between the two wells.
Answer
The deterministic drift still pulls particles downhill, but the stochastic term grows like $\sqrt{T}$, so larger thermal kicks make it easier to cross the barrier at $x=0$ and visit the other basin.
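A rough quantitative version of this answer: the barrier height is $\Delta U = U(0, 0) - U(\pm\sqrt{1.5}, 0) = 0.25 \cdot 1.5^2 = 0.5625$, and in the Arrhenius/Kramers picture the crossing rate scales roughly like $e^{-\Delta U / T}$. Because $U$ separates in $x$ and $y$, only the $x$-coordinate matters for crossing. The following NumPy sketch, with illustrative step size and run length, counts the fraction of particles that ever reach $x > 0$ from the left minimum.

```python
import numpy as np

def fraction_that_cross(temperature, n_particles=2_000, n_steps=2_000, dt=0.01, seed=0):
    rng = np.random.default_rng(seed)
    x = np.full(n_particles, -np.sqrt(1.5))   # start in the left minimum
    crossed = np.zeros(n_particles, dtype=bool)
    for _ in range(n_steps):
        # Overdamped Langevin step for the x-part of U: dU/dx = x (x^2 - 1.5)
        x = x - dt * x * (x**2 - 1.5) + np.sqrt(2 * temperature * dt) * rng.standard_normal(n_particles)
        crossed |= x > 0.0
    return crossed.mean()

barrier = 0.25 * 1.5**2   # U(0) - U(minimum) = 0.5625
temps = (0.03, 0.12, 0.30)
fractions = {T: fraction_that_cross(T) for T in temps}
for T in temps:
    print(f"T={T:.2f}  exp(-dU/T)={np.exp(-barrier / T):.2e}  fraction that reached x>0: {fractions[T]:.3f}")
```

The Boltzmann factor changes by many orders of magnitude across these temperatures, which is why the lowest-temperature particles essentially never leave their basin within the run.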
Why this notebook matters for diffusion models¶
In score-based diffusion models, we deliberately define a forward noising process that gradually destroys structure, and then learn a reverse-time dynamics that reconstructs it.
For someone in computational chemistry or materials science, a good first mental model is:
the forward process pushes configurations toward an easy reference ensemble,
the reverse process uses a learned force-like field to guide them back toward realistic structures,
and conditioning corresponds to biasing that dynamics toward a desired composition, property, or environment.
This notebook gives you the simulation language that those models are built from.
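To make the forward-noising picture concrete, the following is a minimal sketch assuming the forward process is the OU process $\mathrm{d}X = -X\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}W$. This is the variance-preserving SDE of Song et al. with a constant $\beta = 2$; the constant schedule is our simplification. Its closed-form marginal is $X_t = e^{-t}X_0 + \sqrt{1 - e^{-2t}}\,\varepsilon$, so any starting "data" ensemble is pushed toward the easy reference ensemble $\mathcal{N}(0, 1)$. The bimodal toy data are hypothetical.

```python
import numpy as np

def forward_noise(x0, t, rng):
    """Sample X_t given X_0 for dX = -X dt + sqrt(2) dW (an OU process whose equilibrium is N(0, 1))."""
    alpha = np.exp(-t)
    return alpha * x0 + np.sqrt(1.0 - alpha**2) * rng.standard_normal(x0.shape)

rng = np.random.default_rng(0)
n = 50_000
x0 = rng.choice([-3.0, 3.0], size=n) + 0.3 * rng.standard_normal(n)   # toy bimodal "data"

stats = {}
for t in (0.0, 0.5, 1.0, 5.0):
    xt = forward_noise(x0, t, rng)
    stats[t] = (xt.mean(), xt.var(), np.mean(np.abs(xt) < 1.0))
    print(f"t={t:>3}: mean={stats[t][0]:+.3f}  var={stats[t][1]:.3f}  fraction with |x|<1: {stats[t][2]:.3f}")
```

At $t = 0$, almost no samples lie near the origin because the two modes sit at $\pm 3$. By $t = 5$, the ensemble is statistically indistinguishable from a standard Gaussian, with about 68% of the mass in $|x| < 1$. A generative model has to learn to run this process backwards.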
Part 2: Langevin dynamics and equilibrium sampling¶
class Density(ABC):
    """
    Distribution with tractable density
    """
    @abstractmethod
    def log_density(self, x: torch.Tensor) -> torch.Tensor:
        """
        Returns the log density at x.
        Args:
            - x: shape (batch_size, dim)
        Returns:
            - log_density: shape (batch_size, 1)
        """
        pass
    def score(self, x: torch.Tensor) -> torch.Tensor:
        """
        Returns the score grad_x log density(x), computed by automatic differentiation
        Args:
            - x: (batch_size, dim)
        Returns:
            - score: (batch_size, dim)
        """
        x = x.unsqueeze(1)  # (batch_size, 1, dim)
        score = vmap(jacrev(self.log_density))(x)  # (batch_size, 1, 1, 1, dim)
        return score.squeeze((1, 2, 3))  # (batch_size, dim)
class Sampleable(ABC):
    """
    Distribution which can be sampled from
    """
    @abstractmethod
    def sample(self, num_samples: int) -> torch.Tensor:
        """
        Returns samples from the distribution.
        Args:
            - num_samples: the desired number of samples
        Returns:
            - samples: shape (num_samples, dim)
        """
        pass
# Several plotting utility functions
def hist2d_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):
    if ax is None:
        ax = plt.gca()
    samples = sampleable.sample(num_samples)  # (ns, 2)
    ax.hist2d(samples[:, 0].cpu(), samples[:, 1].cpu(), **kwargs)
def scatter_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):
    if ax is None:
        ax = plt.gca()
    samples = sampleable.sample(num_samples)  # (ns, 2)
    ax.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), **kwargs)
def imshow_density(density: Density, bins: int, scale: float, ax: Optional[Axes] = None, **kwargs):
    if ax is None:
        ax = plt.gca()
    x = torch.linspace(-scale, scale, bins).to(device)
    y = torch.linspace(-scale, scale, bins).to(device)
    X, Y = torch.meshgrid(x, y, indexing='ij')  # explicit indexing avoids the torch.meshgrid UserWarning
    xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
    log_p = density.log_density(xy).reshape(bins, bins).T  # transpose so that rows index y
    ax.imshow(log_p.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)
def contour_density(density: Density, bins: int, scale: float, ax: Optional[Axes] = None, **kwargs):
    if ax is None:
        ax = plt.gca()
    x = torch.linspace(-scale, scale, bins).to(device)
    y = torch.linspace(-scale, scale, bins).to(device)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
    log_p = density.log_density(xy).reshape(bins, bins).T  # transpose so that rows index y
    ax.contour(log_p.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)
class Gaussian(torch.nn.Module, Sampleable, Density):
    """
    Two-dimensional Gaussian. Is a Density and a Sampleable. Wrapper around torch.distributions.MultivariateNormal
    """
    def __init__(self, mean, cov):
        """
        mean: shape (2,)
        cov: shape (2,2)
        """
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("cov", cov)
    @property
    def distribution(self):
        return D.MultivariateNormal(self.mean, self.cov, validate_args=False)
    def sample(self, num_samples) -> torch.Tensor:
        return self.distribution.sample((num_samples,))
    def log_density(self, x: torch.Tensor):
        return self.distribution.log_prob(x).view(-1, 1)
class GaussianMixture(torch.nn.Module, Sampleable, Density):
    """
    Two-dimensional Gaussian mixture model, and is a Density and a Sampleable. Wrapper around torch.distributions.MixtureSameFamily.
    """
    def __init__(
        self,
        means: torch.Tensor,  # nmodes x data_dim
        covs: torch.Tensor,  # nmodes x data_dim x data_dim
        weights: torch.Tensor,  # nmodes
    ):
        """
        means: shape (nmodes, 2)
        covs: shape (nmodes, 2, 2)
        weights: shape (nmodes,), need not be normalized
        """
        super().__init__()
        self.nmodes = means.shape[0]
        self.register_buffer("means", means)
        self.register_buffer("covs", covs)
        self.register_buffer("weights", weights)
    @property
    def dim(self) -> int:
        return self.means.shape[1]
    @property
    def distribution(self):
        return D.MixtureSameFamily(
            mixture_distribution=D.Categorical(probs=self.weights, validate_args=False),
            component_distribution=D.MultivariateNormal(
                loc=self.means,
                covariance_matrix=self.covs,
                validate_args=False,
            ),
            validate_args=False,
        )
    def log_density(self, x: torch.Tensor) -> torch.Tensor:
        return self.distribution.log_prob(x).view(-1, 1)
    def sample(self, num_samples: int) -> torch.Tensor:
        return self.distribution.sample(torch.Size((num_samples,)))
    @classmethod
    def random_2D(
        cls, nmodes: int, std: float, scale: float = 10.0, seed: int = 0
    ) -> "GaussianMixture":
        torch.manual_seed(seed)
        means = (torch.rand(nmodes, 2) - 0.5) * scale
        covs = torch.diag_embed(torch.ones(nmodes, 2)) * std ** 2
        weights = torch.ones(nmodes)
        return cls(means, covs, weights)
    @classmethod
    def symmetric_2D(
        cls, nmodes: int, std: float, scale: float = 10.0,
    ) -> "GaussianMixture":
        # Place the modes evenly on a circle of radius `scale`
        angles = torch.linspace(0, 2 * np.pi, nmodes + 1)[:nmodes]
        means = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1) * scale
        covs = torch.diag_embed(torch.ones(nmodes, 2) * std ** 2)
        weights = torch.ones(nmodes) / nmodes
        return cls(means, covs, weights)
# Visualize densities
densities = {
    "Gaussian": Gaussian(mean=torch.zeros(2), cov=10 * torch.eye(2)).to(device),
    "Random Mixture": GaussianMixture.random_2D(nmodes=5, std=1.0, scale=20.0, seed=3).to(device),
    "Symmetric Mixture": GaussianMixture.symmetric_2D(nmodes=5, std=1.0, scale=8.0).to(device),
}
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
bins = 100
scale = 15
for idx, (name, density) in enumerate(densities.items()):
    ax = axes[idx]
    ax.set_title(name)
    imshow_density(density, bins, scale, ax, vmin=-15, cmap=plt.get_cmap('Blues'))
    contour_density(density, bins, scale, ax, colors='grey', linestyles='solid', alpha=0.25, levels=20)
plt.show()

Question 3.1: Implement overdamped Langevin dynamics¶
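In this section, we simulate the overdamped Langevin dynamics
$$dX_t = \tfrac{1}{2}\sigma^2 \nabla \log p(X_t)\,dt + \sigma\,dW_t.$$
If $p(x) \propto e^{-U(x)}$ for a potential energy $U$ (measured in units of $k_B T$), then
$$\nabla \log p(x) = -\nabla U(x) = F(x),$$
so the drift term is proportional to a force. This is why Langevin dynamics is such a natural meeting point between statistical physics and modern generative modeling.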
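To make the force picture concrete, here is a minimal NumPy sketch that is independent of the notebook's SDE and Simulator classes. It runs Euler--Maruyama on the overdamped Langevin equation for a hypothetical one-dimensional double-well energy $U(x) = (x^2 - 1)^2$, using the force $F = -U'$ in place of $\nabla \log p$, and compares the long-run second moment with the Boltzmann density $p \propto e^{-U}$ evaluated by quadrature. The energy, step size, and particle count are illustrative assumptions, not values from the notebook.

```python
import numpy as np

def force(x):
    # Hypothetical double-well energy U(x) = (x^2 - 1)^2, so F(x) = -U'(x) = -4x(x^2 - 1)
    return -4.0 * x * (x**2 - 1.0)

def langevin_em(x0, sigma=1.0, dt=5e-3, num_steps=4000, seed=0):
    """Euler-Maruyama for dX = 0.5 * sigma^2 * F(X) dt + sigma dW, where F = grad log p."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for _ in range(num_steps):
        x = x + 0.5 * sigma**2 * force(x) * dt + sigma * np.sqrt(dt) * rng.standard_normal(x.shape)
    return x

# Start out of equilibrium: a broad cloud centred on the barrier top at x = 0
init = np.random.default_rng(1).standard_normal(4000)
samples = langevin_em(init)

# Reference: second moment of p(x) proportional to exp(-U(x)), by quadrature on a fine grid
grid = np.linspace(-3.0, 3.0, 6001)
weights = np.exp(-((grid**2 - 1.0) ** 2))
target_second_moment = np.sum(grid**2 * weights) / np.sum(weights)

print("Langevin  E[x^2]:", np.mean(samples**2))
print("Boltzmann E[x^2]:", target_second_moment)
print("fraction in right-hand well:", np.mean(samples > 0))
```

The barrier at $x = 0$ has height 1 in units where $p \propto e^{-U}$, so particles cross it often enough in this run for both wells to fill; for much deeper wells, plain Langevin dynamics can stay stuck in one basin for a long time.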
Try first: before running the class below, identify the drift and diffusion coefficients from the Langevin equation.
The worked implementation is already present in the next code cell so the notebook still runs properly.
Worked answer
Comparing with the general SDE $dX_t = b(X_t,t)\,dt + a(X_t,t)\,dW_t$ gives
$$b(x,t) = \frac{1}{2}\sigma^2 \nabla \log p(x), \qquad a(x,t) = \sigma.$$

class LangevinSDE(SDE):
    def __init__(self, sigma: float, density: Density):
        self.sigma = sigma
        self.density = density

    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the drift coefficient of the SDE, b(x, t) = 0.5 * sigma^2 * score(x).
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - drift: shape (bs, dim)
        """
        return 0.5 * self.sigma ** 2 * self.density.score(xt)

    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the diffusion coefficient of the SDE, a(x, t) = sigma (constant).
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - diffusion: shape (bs, dim)
        """
        return self.sigma * torch.ones_like(xt)

Now let us graph how a whole ensemble evolves under these dynamics.
As you look at the plots, keep the following question in mind:
If the score points toward high-density regions, how does the combination of drift + noise reshape an initial cloud of samples into the target distribution?
# First, let's define two utility functions...
def every_nth_index(num_timesteps: int, n: int) -> torch.Tensor:
    """
    Compute the indices to record in the trajectory given a record_every parameter
    """
    if n == 1:
        return torch.arange(num_timesteps)
    return torch.cat(
        [
            torch.arange(0, num_timesteps - 1, n),
            torch.tensor([num_timesteps - 1]),
        ]
    )

def graph_dynamics(
    num_samples: int,
    source_distribution: Sampleable,
    simulator: Simulator,
    density: Density,
    timesteps: torch.Tensor,
    plot_every: int,
    bins: int,
    scale: float
):
    """
    Plot the evolution of samples from source under the simulation scheme given by simulator (itself a discretization of an ODE or SDE).
    Args:
        - num_samples: the number of samples to simulate
        - source_distribution: distribution from which we draw initial samples at t=0
        - simulator: the discretized simulation scheme used to simulate the dynamics
        - density: the target density
        - timesteps: the timesteps used by the simulator
        - plot_every: number of timesteps between consecutive plots
        - bins: number of bins for imshow
        - scale: scale for imshow
    """
    # Simulate
    x0 = source_distribution.sample(num_samples)
    xts = simulator.simulate_with_trajectory(x0, timesteps)
    indices_to_plot = every_nth_index(len(timesteps), plot_every)
    plot_timesteps = timesteps[indices_to_plot]
    plot_xts = xts[:, indices_to_plot]

    # Graph
    fig, axes = plt.subplots(2, len(plot_timesteps), figsize=(8 * len(plot_timesteps), 16))
    axes = axes.reshape((2, len(plot_timesteps)))
    for t_idx in range(len(plot_timesteps)):
        t = plot_timesteps[t_idx].item()
        xt = plot_xts[:, t_idx]
        # Scatter axes
        scatter_ax = axes[0, t_idx]
        imshow_density(density, bins, scale, scatter_ax, vmin=-15, alpha=0.25, cmap=plt.get_cmap('Blues'))
        scatter_ax.scatter(xt[:, 0].cpu(), xt[:, 1].cpu(), marker='x', color='black', alpha=0.75, s=15)
        scatter_ax.set_title(f'Samples at t={t:.1f}', fontsize=15)
        scatter_ax.set_xticks([])
        scatter_ax.set_yticks([])
        # Kdeplot axes
        kdeplot_ax = axes[1, t_idx]
        imshow_density(density, bins, scale, kdeplot_ax, vmin=-15, alpha=0.5, cmap=plt.get_cmap('Blues'))
        sns.kdeplot(x=xt[:, 0].cpu(), y=xt[:, 1].cpu(), alpha=0.5, ax=kdeplot_ax, color='grey')
        kdeplot_ax.set_title(f'Density of Samples at t={t:.1f}', fontsize=15)
        kdeplot_ax.set_xticks([])
        kdeplot_ax.set_yticks([])
        kdeplot_ax.set_xlabel("")
        kdeplot_ax.set_ylabel("")
    plt.show()

# Construct the simulator
target = GaussianMixture.random_2D(nmodes=5, std=0.75, scale=15.0, seed=3.0).to(device)
sde = LangevinSDE(sigma = 0.6, density = target)
simulator = EulerMaruyamaSimulator(sde)
# Graph the results!
graph_dynamics(
num_samples = 1000,
source_distribution = Gaussian(mean=torch.zeros(2), cov=20 * torch.eye(2)).to(device),
simulator=simulator,
density=target,
timesteps=torch.linspace(0,5.0,1000).to(device),
plot_every=334,
bins=200,
scale=15
)

Your job: Try varying the value of sigma, the number and range of the simulation steps, the source distribution, and the target density. What do you notice? Why?
Your answer:
Answer
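The sample cloud tends to relax toward the distribution used to define the score. In physics language, Langevin dynamics mixes toward its stationary distribution; in generative-model language, the learned score field steers noisy samples back toward the data manifold / high-probability region. Because $\sigma$ multiplies both the drift (as $\sigma^2/2$) and the noise, changing $\sigma$ only rescales time: larger $\sigma$ equilibrates faster (until the Euler--Maruyama step becomes too coarse) and smaller $\sigma$ more slowly, while the target density stays $p$. With too short a time horizon, or a source distribution far from some modes, the cloud does not finish mixing, and well-separated modes can end up with the wrong relative weights.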
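A small NumPy check of how $\sigma$ enters, written as a sketch independent of the notebook's simulator classes: for the target $p = \mathcal{N}(0, 1)$, whose score is $-x$, the Langevin SDE becomes $dX_t = -\tfrac{1}{2}\sigma^2 X_t\,dt + \sigma\,dW_t$, so the mean decays like $x_0 e^{-\sigma^2 t/2}$ while the stationary variance is 1 for every $\sigma$. The starting point $x_0 = 3$, step size, and particle count are illustrative assumptions.

```python
import numpy as np

def simulate(sigma, t_final, dt=2e-3, x0=3.0, num_particles=5000, seed=0):
    """Euler-Maruyama for dX = 0.5 * sigma^2 * score(X) dt + sigma dW, with score(x) = -x for N(0, 1)."""
    rng = np.random.default_rng(seed)
    x = np.full(num_particles, x0)
    for _ in range(int(round(t_final / dt))):
        x = x + 0.5 * sigma**2 * (-x) * dt + sigma * np.sqrt(dt) * rng.standard_normal(num_particles)
    return x

for sigma in (0.6, 2.0):
    early = simulate(sigma, t_final=1.0)
    late = simulate(sigma, t_final=30.0)
    print(f"sigma={sigma}: mean at t=1 is {early.mean():.2f} "
          f"(theory {3.0 * np.exp(-0.5 * sigma**2):.2f}); variance at t=30 is {late.var():.2f}")
```

Both runs should settle near variance 1 by $t = 30$, while at $t = 1$ the $\sigma = 2$ cloud has already moved much closer to the origin than the $\sigma = 0.6$ cloud.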
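Optional. The next two cells make an animation. They need the celluloid package. If ffmpeg is also available, the animation is saved to disk and embedded as an HTML5 video; otherwise it falls back to an interactive JavaScript player. If either is missing from your environment, you may need to install it separately (for example through conda-forge).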
try:
    from celluloid import Camera
except ImportError:
    Camera = None
    print('Optional animation skipped: `celluloid` is not installed in this environment.')

from IPython.display import HTML
from matplotlib import animation as mpl_animation

def animate_dynamics(
    num_samples: int,
    source_distribution: Sampleable,
    simulator: Simulator,
    density: Density,
    timesteps: torch.Tensor,
    animate_every: int,
    bins: int,
    scale: float,
    save_path: str = 'dynamics_animation.gif'
):
    """Plot the evolution of samples from source under the simulation scheme given by simulator."""
    if Camera is None:
        return HTML('<b>Optional animation skipped:</b> install <code>celluloid</code> to render this cell.')
    x0 = source_distribution.sample(num_samples)
    xts = simulator.simulate_with_trajectory(x0, timesteps)
    indices_to_animate = every_nth_index(len(timesteps), animate_every)
    animate_timesteps = timesteps[indices_to_animate]
    animate_xts = xts[:, indices_to_animate]
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    camera = Camera(fig)
    for t_idx in range(len(animate_timesteps)):
        xt = animate_xts[:, t_idx]
        scatter_ax = axes[0]
        imshow_density(density, bins, scale, scatter_ax, vmin=-15, alpha=0.25, cmap=plt.get_cmap('Blues'))
        scatter_ax.scatter(xt[:, 0].cpu(), xt[:, 1].cpu(), marker='x', color='black', alpha=0.75, s=15)
        scatter_ax.set_title('Samples')
        scatter_ax.set_xticks([])
        scatter_ax.set_yticks([])
        kdeplot_ax = axes[1]
        imshow_density(density, bins, scale, kdeplot_ax, vmin=-15, alpha=0.5, cmap=plt.get_cmap('Blues'))
        sns.kdeplot(x=xt[:, 0].cpu(), y=xt[:, 1].cpu(), alpha=0.5, ax=kdeplot_ax, color='grey')
        kdeplot_ax.set_title('Density of Samples', fontsize=15)
        kdeplot_ax.set_xticks([])
        kdeplot_ax.set_yticks([])
        kdeplot_ax.set_xlabel('')
        kdeplot_ax.set_ylabel('')
        camera.snap()
    animation = camera.animate()
    # Save and embed as HTML5 video when ffmpeg is available; otherwise use a JavaScript player
    if mpl_animation.writers.is_available('ffmpeg'):
        animation.save(save_path)
        result = HTML(animation.to_html5_video())
    else:
        result = HTML(animation.to_jshtml())
    plt.close()
    return result

# OPTIONAL CELL
# Construct the simulator
target = GaussianMixture.random_2D(nmodes=5, std=0.75, scale=15.0, seed=3.0).to(device)
sde = LangevinSDE(sigma=0.6, density=target)
simulator = EulerMaruyamaSimulator(sde)
animate_dynamics(
num_samples=1000,
source_distribution=Gaussian(mean=torch.zeros(2), cov=20 * torch.eye(2)).to(device),
simulator=simulator,
density=target,
timesteps=torch.linspace(0, 5.0, 1000).to(device),
bins=200,
scale=15,
animate_every=100,
)
Question 3.2: Ornstein--Uhlenbeck as a special case of Langevin dynamics¶
We now make the connection completely explicit.
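Recall:
Langevin dynamics:
$$dX_t = \tfrac{1}{2}\sigma^2 \nabla \log p(X_t)\,dt + \sigma\,dW_t$$
Ornstein--Uhlenbeck:
$$dX_t = -\theta X_t\,dt + \sigma\,dW_t$$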
This exercise shows that the OU process is exactly Langevin dynamics for a Gaussian target distribution. That is the cleanest possible example of the more general principle:
score = force-like field that defines how samples flow toward equilibrium.
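Your job: Show that when
$$p(x) = \mathcal{N}\!\left(0, \tfrac{\sigma^2}{2\theta} I_d\right),$$
the score is
$$\nabla \log p(x) = -\frac{2\theta}{\sigma^2}\,x.$$
Hint. The Gaussian density is
$$p(x) = \left(\frac{\theta}{\pi\sigma^2}\right)^{d/2} \exp\!\left(-\frac{\theta\,\|x\|^2}{\sigma^2}\right).$$
Your answer:
Answer
From the hint,
$$\log p(x) = -\frac{\theta\,\|x\|^2}{\sigma^2} + \text{const} \quad\Longrightarrow\quad \nabla \log p(x) = -\frac{2\theta}{\sigma^2}\,x.$$
This is exactly the linear score field associated with a harmonic basin.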
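Your job: Conclude that when $p(x) = \mathcal{N}\!\left(0, \tfrac{\sigma^2}{2\theta} I_d\right)$,
the Langevin dynamics
$$dX_t = \tfrac{1}{2}\sigma^2 \nabla \log p(X_t)\,dt + \sigma\,dW_t$$
is equivalent to the OU process
$$dX_t = -\theta X_t\,dt + \sigma\,dW_t.$$
Your answer:
Answer
Substitute the score from the previous part into the Langevin drift:
$$\tfrac{1}{2}\sigma^2 \nabla \log p(x) = \tfrac{1}{2}\sigma^2 \left(-\frac{2\theta}{\sigma^2}\,x\right) = -\theta x,$$
so the Langevin SDE becomes $dX_t = -\theta X_t\,dt + \sigma\,dW_t$, which is the OU process.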
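Both steps of the exercise can be sanity-checked numerically. This NumPy sketch (the values of $\theta$, $\sigma$, the dimension, and the test point are arbitrary assumptions) differentiates $\log p$ for $p = \mathcal{N}(0, \tfrac{\sigma^2}{2\theta} I_d)$ by central finite differences, then confirms that the Langevin drift $\tfrac{1}{2}\sigma^2 \nabla \log p(x)$ matches the OU drift $-\theta x$.

```python
import numpy as np

theta, sigma, d = 0.7, 1.3, 3
var = sigma**2 / (2.0 * theta)  # stationary variance of the OU process

def log_p(x):
    # log N(0, var * I_d), including the normalizing constant
    return -0.5 * np.sum(x**2) / var - 0.5 * d * np.log(2.0 * np.pi * var)

def numerical_score(x, h=1e-5):
    """Central finite-difference gradient of log_p."""
    grad = np.zeros_like(x)
    for i in range(len(x)):
        e = np.zeros_like(x)
        e[i] = h
        grad[i] = (log_p(x + e) - log_p(x - e)) / (2.0 * h)
    return grad

x = np.array([0.5, -1.2, 2.0])
score = numerical_score(x)
langevin_drift = 0.5 * sigma**2 * score

print("finite-difference score:", score)
print("-2 theta / sigma^2 * x: ", -2.0 * theta / sigma**2 * x)
print("Langevin drift:         ", langevin_drift)
print("OU drift -theta * x:    ", -theta * x)
```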
Takeaway for diffusion models¶
In chemistry, one often starts from an energy and obtains forces by differentiation.
In score-based generative modeling, we reverse the perspective:
we learn a force-like field (the score),
use it inside an SDE or ODE,
and thereby steer noise into structured samples.
That is the conceptual bridge from Brownian motion and Langevin dynamics to diffusion models for molecules, crystals, and materials configurations.
Transition: from Langevin intuition to diffusion models¶
Everything above was about dynamics on known vector fields or known densities. Diffusion models flip the problem around:
we define a simple forward corruption process,
we learn the reverse denoising direction from data,
and we sample by integrating that learned reverse-time dynamics.
The next sections compare three closely related views of that idea.
Why the names get confusing¶
These families overlap, so it is easy for them to blur together:
NCSN / score-based models learn the score field directly and sample with annealed Langevin updates.
DDPM learns a denoiser or noise predictor on a discrete forward diffusion chain and samples stochastically.
DDIM usually uses the same trained DDPM network, but swaps in a deterministic reverse sampler.
So the key differences are not just in the network, but in the forward process, the training target, and the sampler.
Part 3: A single toy dataset for all three model families¶
To keep the geometry visible, we will use one small 2D dataset throughout: a five-mode Gaussian mixture arranged like a cross.
That gives us a nice teaching setup:
it is simple enough to plot directly,
it trains quickly on CPU or Colab,
and bad samplers fail in a very visible way.
What to notice¶
If a model trains well, samples should cluster around the five modes.
If the reverse process is poor, samples either blur into the middle or explode away from the modes.
Because the dataset is shared, differences in the plots mostly come from the model family rather than from the data.
import math

toy_centers = torch.tensor(
    [
        [-2.5, 0.0],
        [2.5, 0.0],
        [0.0, 2.5],
        [0.0, -2.5],
        [0.0, 0.0],
    ],
    dtype=torch.float32,
    device=device,
)

# DDPM: short linear beta schedule
NUM_DDPM_STEPS = 50
DDPM_BETAS = torch.linspace(1e-4, 0.035, NUM_DDPM_STEPS, device=device)
DDPM_ALPHAS = 1.0 - DDPM_BETAS
DDPM_ALPHA_BAR = torch.cumprod(DDPM_ALPHAS, dim=0)

# NCSN: noise levels ordered from largest (index 0) to smallest (index -1)
NCSN_SIGMAS = torch.tensor([1.0, 0.6, 0.35, 0.2, 0.12, 0.07], dtype=torch.float32, device=device)

def sample_cross_data(num_samples: int, std: float = 0.25, seed: int | None = None):
    """Draw points from the five-mode cross. Passing `seed` reseeds the global torch RNG."""
    if seed is not None:
        torch.manual_seed(seed)
    idx = torch.randint(0, len(toy_centers), (num_samples,), device=device)
    return toy_centers[idx] + std * torch.randn(num_samples, 2, device=device)

def ddpm_q_sample(x0: torch.Tensor, t_idx: torch.Tensor, noise: torch.Tensor | None = None):
    """Closed-form forward corruption: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    if noise is None:
        noise = torch.randn_like(x0)
    alpha_bar_t = DDPM_ALPHA_BAR[t_idx].unsqueeze(1)
    xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1.0 - alpha_bar_t) * noise
    return xt, noise

class SinusoidalTimeEmbedding(torch.nn.Module):
    """Maps a time/noise input of shape (bs, 1) to sin/cos features of shape (bs, dim)."""
    def __init__(self, dim: int = 32):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half_dim = self.dim // 2
        # Geometrically spaced frequencies from 1 to 40
        freqs = torch.exp(torch.linspace(math.log(1.0), math.log(40.0), half_dim, device=t.device))
        angles = t * freqs.view(1, -1)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

class TinyDiffusionMLP(torch.nn.Module):
    """Small MLP shared by all three families: input (x, embedded t), output a 2D vector (score or noise)."""
    def __init__(self, hidden: int = 96, time_dim: int = 32):
        super().__init__()
        self.time_embed = SinusoidalTimeEmbedding(time_dim)
        self.net = torch.nn.Sequential(
            torch.nn.Linear(2 + time_dim, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, 2),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([x, self.time_embed(t)], dim=-1))

@torch.no_grad()
def plot_point_cloud(ax, points: torch.Tensor, title: str, centers: torch.Tensor | None = toy_centers):
    # Pass centers=None to hide the mode markers (e.g. for pure noise)
    pts = points.detach().cpu()
    ax.scatter(pts[:, 0], pts[:, 1], s=9, alpha=0.45)
    if centers is not None:
        c = centers.detach().cpu()
        ax.scatter(c[:, 0], c[:, 1], s=50, marker='x', color='black', linewidth=1.2)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

@torch.no_grad()
def plot_score_quiver(ax, model: torch.nn.Module, noise_value: float, title: str, grid_scale: float = 4.5):
    # noise_value is passed to the network as its t input; for the NCSN model,
    # t = 0 corresponds to the largest sigma and t = 1 to the smallest
    grid = torch.linspace(-grid_scale, grid_scale, 17, device=device)
    xx, yy = torch.meshgrid(grid, grid, indexing='xy')
    pts = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=-1)
    t = torch.full((pts.shape[0], 1), noise_value, device=device)
    scores = model(pts, t).detach().cpu()
    ax.quiver(
        pts[:, 0].cpu().numpy(),
        pts[:, 1].cpu().numpy(),
        scores[:, 0].numpy(),
        scores[:, 1].numpy(),
        angles='xy',
        scale_units='xy',
        scale=10,
        width=0.003,
        alpha=0.75,
        color='tab:blue',
    )
    background = sample_cross_data(700, seed=0).detach().cpu()
    ax.scatter(background[:, 0], background[:, 1], s=5, alpha=0.08, color='black')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

@torch.no_grad()
def plot_reverse_paths(ax, history, title: str, max_paths: int = 10):
    for idx in range(min(max_paths, history[0].shape[0])):
        path = torch.stack([frame[idx] for frame in history], dim=0).cpu()
        ax.plot(path[:, 0], path[:, 1], linewidth=1.0, alpha=0.7)
        ax.scatter(path[0, 0], path[0, 1], color='black', s=14)
        ax.scatter(path[-1, 0], path[-1, 1], color='tab:red', s=16)
    ax.scatter(toy_centers[:, 0].cpu(), toy_centers[:, 1].cpu(), marker='x', color='black', s=45)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

Forward corruption preview¶
Before training anything, let us look at what the DDPM-style forward process actually does.
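For a discrete timestep $t$, the closed-form corruption is
$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \qquad \bar{\alpha}_t = \prod_{s \le t} (1 - \beta_s).$$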
What to notice¶
Early timesteps keep most of the original structure.
Late timesteps look much closer to an isotropic Gaussian cloud.
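This is why the reverse sampler can start from simple noise: the forward process is meant to end close to $\mathcal{N}(0, I)$. With the deliberately short 50-step schedule used here, $\bar{\alpha}_{49} \approx 0.41$, so the last step still keeps part of the cross structure, and starting the reverse sampler from $\mathcal{N}(0, I)$ is an approximation.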
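A quick way to see how close a schedule gets to pure noise is to compute $\bar{\alpha}_t$ directly. The NumPy sketch below rebuilds the notebook's linear schedule (same endpoints, 50 steps) outside PyTorch and, for comparison only, a hypothetical steeper linear schedule ending at $\beta = 0.2$ that the notebook does not train with.

```python
import numpy as np

def alpha_bar(beta_start, beta_end, num_steps):
    """Cumulative product abar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)

notebook_abar = alpha_bar(1e-4, 0.035, 50)  # same endpoints as DDPM_BETAS
steeper_abar = alpha_bar(1e-4, 0.2, 50)     # hypothetical alternative, for comparison only

for name, abar in [("notebook", notebook_abar), ("steeper", steeper_abar)]:
    # Coefficients of x0 (signal) and eps (noise) in x_T = sqrt(abar_T) x0 + sqrt(1 - abar_T) eps
    print(f"{name:8s} abar_T={abar[-1]:.3f}  signal={np.sqrt(abar[-1]):.3f}  noise={np.sqrt(1 - abar[-1]):.3f}")
```

The steeper schedule leaves almost no signal at the last step, at the cost of larger jumps between consecutive noise levels, which a 50-step reverse chain has to undo.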
torch.manual_seed(0)
forward_data = sample_cross_data(900)
preview_steps = [0, 10, 25, 49]
fig, axes = plt.subplots(1, len(preview_steps) + 1, figsize=(15, 3.6))
plot_point_cloud(axes[0], forward_data, 'clean data')
for ax, step_idx in zip(axes[1:], preview_steps):
    t_idx = torch.full((forward_data.shape[0],), step_idx, dtype=torch.long, device=device)
    xt, _ = ddpm_q_sample(forward_data, t_idx)
    plot_point_cloud(ax, xt, f't = {step_idx}')
plt.suptitle('The same data distribution becomes progressively easier as noise increases', y=1.05)
plt.tight_layout()
plt.show()
Part 3A: NCSN and score-based models¶
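A noise-conditional score network learns the score of progressively noisier versions of the data:
$$s_\theta(x, \sigma) \approx \nabla_x \log p_\sigma(x), \qquad p_\sigma = p_{\text{data}} * \mathcal{N}(0, \sigma^2 I).$$
In practice we sample a clean point $x_0 \sim p_{\text{data}}$ and a noise level $\sigma$, add Gaussian noise,
$$\tilde{x} = x_0 + \sigma\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),$$
and train the network to predict the denoising direction
$$\nabla_{\tilde{x}} \log \mathcal{N}(\tilde{x}; x_0, \sigma^2 I) = -\frac{\tilde{x} - x_0}{\sigma^2} = -\frac{\epsilon}{\sigma}.$$
The squared error is weighted by $\sigma^2$, which puts every noise level on a comparable scale.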
How it works¶
The network is told the noise level explicitly.
At large $\sigma$, the score field is broad and global.
At small $\sigma$, the score field becomes local and sharp.
Sampling uses annealed Langevin dynamics: alternate score steps and fresh noise injections while gradually lowering $\sigma$.
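Because the toy data are a Gaussian mixture, the noise-perturbed score $\nabla_x \log p_\sigma(x)$ is available in closed form, so annealed Langevin dynamics can be tried before any network is trained. This NumPy sketch swaps the learned network for that exact score; it assumes the same five centres, data standard deviation (0.25), and $\sigma$ ladder as the cells below, while the step-size constant and steps per level are illustrative choices. The update $x \leftarrow x + \delta_i\, s(x, \sigma_i) + \sqrt{2\delta_i}\, z$ with $\delta_i \propto \sigma_i^2$ follows the convention of the notebook's own sampler.

```python
import numpy as np

CENTERS = np.array([[-2.5, 0.0], [2.5, 0.0], [0.0, 2.5], [0.0, -2.5], [0.0, 0.0]])
DATA_STD = 0.25
SIGMAS = np.array([1.0, 0.6, 0.35, 0.2, 0.12, 0.07])  # largest to smallest, like NCSN_SIGMAS

def mixture_score(x, sigma):
    """Exact grad log p_sigma(x): equal-weight mixture of N(c_k, DATA_STD^2 I) convolved with N(0, sigma^2 I)."""
    var = DATA_STD**2 + sigma**2
    diff = CENTERS[None, :, :] - x[:, None, :]        # (n, K, 2), points from x to each centre
    logits = -0.5 * np.sum(diff**2, axis=-1) / var    # (n, K) log-responsibilities up to a constant
    logits -= logits.max(axis=1, keepdims=True)       # stabilize before exponentiating
    resp = np.exp(logits)
    resp /= resp.sum(axis=1, keepdims=True)
    return np.einsum("nk,nkd->nd", resp, diff) / var

def annealed_langevin(num_samples=2000, base_step=1.5e-3, steps_per_level=100, seed=0):
    rng = np.random.default_rng(seed)
    x = SIGMAS[0] * rng.standard_normal((num_samples, 2))
    for sigma in SIGMAS:
        step = base_step * (sigma / SIGMAS[-1]) ** 2  # larger steps at larger noise levels
        for _ in range(steps_per_level):
            x = x + step * mixture_score(x, sigma) + np.sqrt(2.0 * step) * rng.standard_normal(x.shape)
    return x

samples = annealed_langevin()
dist = np.linalg.norm(samples[:, None, :] - CENTERS[None, :, :], axis=-1)
nearest = dist.argmin(axis=1)
print("fraction within 0.75 of a centre:", np.mean(dist.min(axis=1) < 0.75))
print("fraction per mode:", np.bincount(nearest, minlength=len(CENTERS)) / len(samples))
```

With the exact score, failures of the notebook's NCSN run can be attributed to the learned field rather than to the annealing schedule itself.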
How it differs from DDPM¶
NCSN is usually described as learning the score field directly.
DDPM is usually described as learning a noise predictor or denoiser on a discrete forward chain.
The two views are mathematically close, but the training objective and the sampler are presented differently.
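One way to make "mathematically close" concrete: under the DDPM corruption, the score of the noisy marginal and the optimal noise predictor satisfy the standard identity $\nabla_{x_t} \log q(x_t) = -\mathbb{E}[\epsilon \mid x_t] / \sqrt{1-\bar{\alpha}_t}$ (closely related to Tweedie's formula). The NumPy sketch below checks it where everything is analytic, assuming hypothetical one-dimensional Gaussian data $x_0 \sim \mathcal{N}(0, v)$; a least-squares fit of $\epsilon$ on $x_t$ plays the role of a perfectly trained noise predictor, which is linear in this case.

```python
import numpy as np

rng = np.random.default_rng(0)
v, abar = 4.0, 0.3   # hypothetical data variance and noise level abar_t
n = 200_000

x0 = np.sqrt(v) * rng.standard_normal(n)
eps = rng.standard_normal(n)
xt = np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * eps   # DDPM forward corruption

# Best linear noise predictor eps_hat(x_t) = w * x_t, fitted by least squares
w = np.dot(xt, eps) / np.dot(xt, xt)

# Exact score of q(x_t) = N(0, V) is -x_t / V; the identity predicts slope -w / sqrt(1 - abar)
V = abar * v + 1.0 - abar
print("score slope from noise predictor:", -w / np.sqrt(1.0 - abar))
print("exact score slope -1/V:          ", -1.0 / V)
```

So a network trained to predict $\epsilon$ implicitly provides a score estimate, up to the factor $-1/\sqrt{1-\bar{\alpha}_t}$.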
def train_ncsn_model(num_steps: int = 1000, batch_size: int = 384, lr: float = 1.2e-3):
    model = TinyDiffusionMLP().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for _ in range(num_steps):
        x0 = sample_cross_data(batch_size)
        # Pick a random noise level per example and perturb the data
        idx = torch.randint(0, len(NCSN_SIGMAS), (batch_size,), device=device)
        sigma = NCSN_SIGMAS[idx].unsqueeze(1)
        noise = torch.randn_like(x0)
        xt = x0 + sigma * noise
        target = -noise / sigma  # score of the Gaussian perturbation kernel
        # Noise level given to the network as a normalized index: 0 = largest sigma, 1 = smallest
        t = idx.float().unsqueeze(1) / (len(NCSN_SIGMAS) - 1)
        pred = model(xt, t)
        # sigma^2 weighting keeps every noise level on a comparable scale
        loss = ((pred - target) ** 2 * sigma ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return model, losses

@torch.no_grad()
def sample_ncsn(model: torch.nn.Module, num_samples: int = 1024, base_step: float = 1.6e-4, steps_per_level: int = 60):
    x = NCSN_SIGMAS[0] * torch.randn(num_samples, 2, device=device)
    for idx, sigma in enumerate(NCSN_SIGMAS):  # anneal from largest to smallest sigma
        t = torch.full((num_samples, 1), idx / (len(NCSN_SIGMAS) - 1), device=device)
        # Step size scales with sigma^2, relative to the smallest noise level
        step_size = base_step * float((sigma / NCSN_SIGMAS[-1]) ** 2)
        for _ in range(steps_per_level):
            score = model(x, t)
            x = x + step_size * score + math.sqrt(2.0 * step_size) * torch.randn_like(x)
    return x

torch.manual_seed(7)
ncsn_model, ncsn_losses = train_ncsn_model()
ncsn_samples = sample_ncsn(ncsn_model)

fig, axes = plt.subplots(2, 2, figsize=(11, 8))
axes[0, 0].plot(ncsn_losses)
axes[0, 0].set_title('NCSN training loss')
axes[0, 0].set_xlabel('training step')
axes[0, 0].set_ylabel('weighted DSM loss')
axes[0, 0].grid(alpha=0.25)
# The NCSN network input t = 1.0 is the smallest sigma and t = 0.0 the largest
plot_score_quiver(axes[0, 1], ncsn_model, noise_value=1.0, title='score field at the lowest noise level')
plot_score_quiver(axes[1, 0], ncsn_model, noise_value=0.0, title='score field at the highest noise level')
plot_point_cloud(axes[1, 1], ncsn_samples, 'NCSN samples after annealed Langevin')
plt.tight_layout()
plt.show()
Exercise: why condition the score network on noise level?¶
Task. In one sentence, explain why the NCSN input is $(x, \sigma)$ rather than only $x$.
Answer
Because the optimal score field changes with the noise level: low-noise data needs a sharp local denoising field, while high-noise data needs a broader global direction field.
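A one-dimensional example shows how different the field can be at two noise levels. Assuming a hypothetical two-mode dataset with centres at $\pm 2$ and standard deviation 0.25 (not the notebook's cross), this NumPy sketch evaluates the exact score of the noise-perturbed density at $x = 1$: at small $\sigma$ it points toward the nearby mode at $+2$, while at large $\sigma$ it points back toward the overall centre of mass. A network that saw only $x$ would have to give both answers for the same input.

```python
import numpy as np

CENTERS = np.array([-2.0, 2.0])
DATA_STD = 0.25

def perturbed_score(x, sigma):
    """Exact d/dx log p_sigma(x) for an equal-weight two-Gaussian mixture convolved with N(0, sigma^2)."""
    var = DATA_STD**2 + sigma**2
    logits = -0.5 * (x - CENTERS) ** 2 / var
    resp = np.exp(logits - logits.max())   # responsibilities of each mode for the point x
    resp /= resp.sum()
    return np.sum(resp * (CENTERS - x)) / var

for sigma in (0.1, 3.0):
    print(f"sigma={sigma}: score at x=1 is {perturbed_score(1.0, sigma):+.3f}")
```

The sign of the score flips between the two noise levels, and its magnitude changes by orders of magnitude as well.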
Part 3B: DDPM¶
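DDPM starts from a discrete forward Markov chain
$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(\sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right),$$
which implies the useful closed-form corruption rule
$$q(x_t \mid x_0) = \mathcal{N}\!\left(\sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t) I\right), \qquad \bar{\alpha}_t = \prod_{s \le t} (1-\beta_s).$$
Instead of predicting the score directly, the standard DDPM network is trained to predict the injected noise:
$$\epsilon_\theta(x_t, t) \approx \epsilon, \qquad x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,$$
with the mean-squared error objective
$$\mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0, t, \epsilon}\left[\left\|\epsilon - \epsilon_\theta(x_t, t)\right\|^2\right].$$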
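The step from the Markov chain to the closed-form rule can be checked by simulation. This NumPy sketch reuses the notebook's linear schedule outside PyTorch, starts many copies of a single hypothetical clean point $x_0 = 2$, applies the one-step transition 50 times, and compares the empirical mean and variance of the result with $\sqrt{\bar{\alpha}_T}\,x_0$ and $1-\bar{\alpha}_T$.

```python
import numpy as np

rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 0.035, 50)   # same endpoints as DDPM_BETAS
alpha_bar = np.cumprod(1.0 - betas)

x0 = 2.0
x = np.full(100_000, x0)
for beta in betas:
    # One step of q(x_t | x_{t-1}) = N(sqrt(1 - beta) x_{t-1}, beta)
    x = np.sqrt(1.0 - beta) * x + np.sqrt(beta) * rng.standard_normal(x.shape)

print("mean: empirical", x.mean(), " closed form", np.sqrt(alpha_bar[-1]) * x0)
print("var:  empirical", x.var(), " closed form", 1.0 - alpha_bar[-1])
```

Training can therefore jump straight to any $t$ with one Gaussian draw instead of simulating the chain step by step, which is what ddpm_q_sample does.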
How it works¶
The forward process is a discrete chain with many small steps.
The network predicts the noise at a randomly chosen step.
Sampling walks backward through the chain and reintroduces a stochastic term at each reverse step.
How it differs from NCSN¶
DDPM is tied to a specific discrete noising schedule.
The network is usually phrased as a noise predictor rather than a direct score predictor.
The reverse process is a learned stochastic reverse diffusion chain rather than annealed Langevin dynamics.
def train_ddpm_model(num_steps: int = 900, batch_size: int = 384, lr: float = 1.5e-3):
    model = TinyDiffusionMLP().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for _ in range(num_steps):
        x0 = sample_cross_data(batch_size)
        # Random timestep per example, then closed-form corruption
        t_idx = torch.randint(0, NUM_DDPM_STEPS, (batch_size,), device=device)
        xt, noise = ddpm_q_sample(x0, t_idx)
        t = t_idx.float().unsqueeze(1) / (NUM_DDPM_STEPS - 1)  # normalized timestep in [0, 1]
        pred = model(xt, t)
        loss = torch.mean((pred - noise) ** 2)  # noise-prediction MSE
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return model, losses

@torch.no_grad()
def sample_ddpm_or_ddim(
    model: torch.nn.Module,
    num_samples: int = 1024,
    eta: float = 1.0,
    x_init: torch.Tensor | None = None,
    return_history: bool = False,
    record_every: int = 5,
):
    """Generalized DDIM sampler: eta = 1 gives DDPM ancestral sampling, eta = 0 gives deterministic DDIM."""
    if x_init is None:
        x = torch.randn(num_samples, 2, device=device)
    else:
        x = x_init.clone().to(device)
    num_samples = x.shape[0]
    history = []
    if return_history:
        history.append(x.detach().cpu())
    for step_idx in reversed(range(NUM_DDPM_STEPS)):
        t = torch.full((num_samples, 1), step_idx / (NUM_DDPM_STEPS - 1), device=device)
        alpha_bar_t = DDPM_ALPHA_BAR[step_idx]
        eps_pred = model(x, t)
        # Clean-point estimate implied by the predicted noise
        x0_pred = (x - torch.sqrt(1.0 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)
        if step_idx == 0:
            x = x0_pred
        else:
            alpha_bar_prev = DDPM_ALPHA_BAR[step_idx - 1]
            # eta = 1 gives the DDPM posterior std; eta = 0 removes the noise term entirely
            sigma_t = eta * torch.sqrt((1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * DDPM_BETAS[step_idx])
            direction = torch.sqrt(torch.clamp(1.0 - alpha_bar_prev - sigma_t ** 2, min=1e-8)) * eps_pred
            x = torch.sqrt(alpha_bar_prev) * x0_pred + direction
            if eta > 0:
                x = x + sigma_t * torch.randn_like(x)
        if return_history and (step_idx % record_every == 0 or step_idx == 0):
            history.append(x.detach().cpu())
    if return_history:
        return x, history
    return x

torch.manual_seed(11)
ddpm_model, ddpm_losses = train_ddpm_model()
ddpm_samples = sample_ddpm_or_ddim(ddpm_model, eta=1.0)

fig, axes = plt.subplots(1, 2, figsize=(10.5, 4.2))
axes[0].plot(ddpm_losses)
axes[0].set_title('DDPM training loss')
axes[0].set_xlabel('training step')
axes[0].set_ylabel('noise-prediction MSE')
axes[0].grid(alpha=0.25)
plot_point_cloud(axes[1], ddpm_samples, 'DDPM samples (stochastic reverse process)')
plt.tight_layout()
plt.show()
Exercise: what does a DDPM network actually predict?¶
Task. When people say that DDPM predicts the noise, what quantity are they referring to?
Answer
They mean the Gaussian perturbation $\epsilon$ used in the forward corruption rule $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$.
Part 3C: DDIM¶
DDIM is easiest to understand as a different sampler built on top of the same DDPM-trained denoiser.
The network is still trained with the DDPM objective. What changes is the reverse update: DDIM removes the extra stochastic term and follows a deterministic path from the same initial noise.
What changes¶
Training: unchanged from DDPM.
Network: unchanged from DDPM.
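Sampler: stochastic when $\eta > 0$ ($\eta = 1$ recovers the DDPM sampler), deterministic when $\eta = 0$.
Removing the noise makes each sample a fixed function of its starting noise, which is why DDIM is often described as trading some stochasticity for cleaner, more reproducible sampling. Its speed advantage comes from the fact that the same update can also be run on a short subsequence of the training timesteps; the cells below keep all 50 steps, so the comparison isolates the effect of $\eta$.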
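DDIM's update can be applied on a strided subsequence of the training timesteps, whereas the notebook's sampler always visits all 50. This NumPy sketch isolates that idea: it replaces the trained network with the exact noise predictor for hypothetical 1D Gaussian data $x_0 \sim \mathcal{N}(0, 4)$, reuses the notebook's 50-step schedule, and starts from the true marginal of $x_T$ rather than $\mathcal{N}(0, 1)$, so only the sampler's own discretization error remains. Because the $\eta = 0$ map is linear for Gaussian data, pushing $x_T = 1$ through it gives its gain, and $\text{gain}^2 \cdot \mathrm{Var}(x_T) / \mathrm{Var}(x_0)$ measures how well the output variance matches the data.

```python
import numpy as np

T = 50
betas = np.linspace(1e-4, 0.035, T)   # same endpoints as DDPM_BETAS
alpha_bar = np.cumprod(1.0 - betas)
DATA_VAR = 4.0                        # hypothetical data: x0 ~ N(0, 4)

def optimal_eps(x, t):
    """Exact E[eps | x_t] for Gaussian data; stands in for a trained eps_theta."""
    marginal_var = alpha_bar[t] * DATA_VAR + 1.0 - alpha_bar[t]
    return np.sqrt(1.0 - alpha_bar[t]) * x / marginal_var

def ddim_sample(x_T, timesteps):
    """Deterministic DDIM (eta = 0) over a decreasing list of timesteps."""
    x = np.asarray(x_T, dtype=float)
    for i, t in enumerate(timesteps):
        eps = optimal_eps(x, t)
        x0_pred = (x - np.sqrt(1.0 - alpha_bar[t]) * eps) / np.sqrt(alpha_bar[t])
        if i == len(timesteps) - 1:
            return x0_pred            # last visited timestep: return the clean estimate
        t_prev = timesteps[i + 1]
        x = np.sqrt(alpha_bar[t_prev]) * x0_pred + np.sqrt(1.0 - alpha_bar[t_prev]) * eps
    return x

full = list(range(T - 1, -1, -1))      # 49, 48, ..., 0 (as in the notebook)
strided = list(range(T - 1, -1, -5))   # 49, 44, ..., 4 (10 predictor calls)
var_T = alpha_bar[-1] * DATA_VAR + 1.0 - alpha_bar[-1]

for name, seq in [("50 steps", full), ("10 steps", strided)]:
    gain = float(ddim_sample(1.0, seq))
    print(f"{name}: output variance / data variance = {gain**2 * var_T / DATA_VAR:.4f}")
```

Both ratios stay slightly below 1 because every finite deterministic step loses a little variance even with an exact noise predictor; the 10-step run loses somewhat more but remains close, which is the sense in which DDIM trades a small accuracy cost for far fewer network evaluations.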
torch.manual_seed(11)
shared_start_noise = torch.randn(256, 2, device=device)
shared_ddpm_samples, ddpm_history = sample_ddpm_or_ddim(
ddpm_model,
eta=1.0,
x_init=shared_start_noise,
return_history=True,
record_every=6,
)
shared_ddim_samples, ddim_history = sample_ddpm_or_ddim(
ddpm_model,
eta=0.0,
x_init=shared_start_noise,
return_history=True,
record_every=6,
)

fig, axes = plt.subplots(2, 2, figsize=(10.5, 8))
plot_point_cloud(axes[0, 0], shared_start_noise, 'shared starting noise', centers=None)
plot_point_cloud(axes[0, 1], shared_ddpm_samples, 'DDPM endpoints from that noise')
plot_reverse_paths(axes[1, 0], ddpm_history, 'DDPM reverse paths (noise injected each step)')
plot_reverse_paths(axes[1, 1], ddim_history, 'DDIM reverse paths (deterministic once initialized)')
plt.tight_layout()
plt.show()
Exercise: why can DDPM and DDIM share the same trained model?¶
Task. In one or two sentences, explain why DDIM does not need a separately trained network.
Answer
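Because DDIM changes the reverse-time integration rule, not the training target. DDIM's forward processes are built to have the same marginals $q(x_t \mid x_0)$ as DDPM, and the DDPM noise-prediction loss depends only on those marginals, so the same denoiser or noise predictor can be plugged into a deterministic reverse update.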
Part 3D: What actually differs?¶
Here is the cleanest way to separate the three families.
| Model family | What the network learns | Forward process | Reverse sampler | Good mental model |
|---|---|---|---|---|
| NCSN / score-based | score | direct Gaussian perturbations at chosen noise levels | annealed Langevin dynamics | learn the vector field directly |
| DDPM | noise predictor or denoiser | discrete Markov diffusion chain | stochastic reverse diffusion | many tiny denoising steps |
| DDIM | same network as DDPM | same forward training setup as DDPM | deterministic reverse path | DDPM without the random jitter |
A practical way to remember it¶
If you want the most direct score-field story, think NCSN.
If you want the standard discrete diffusion training recipe, think DDPM.
If you want to reuse a DDPM model but sample deterministically, think DDIM.
Final comparison exercise¶
Task. Which pair differs only in the sampler, and which pair differs in both training language and sampler?
Answer
DDPM and DDIM differ only in the sampler. NCSN differs from DDPM/DDIM in both the training framing and the reverse-time sampler.
Part 4: Bridge to crystals¶
The crystal notebooks reuse the same core diffusion template:
define a forward corruption process on a crystal representation,
train a network to predict the reverse denoising direction,
and sample by walking backward from an easy noise distribution.
What changes in the crystal setting is not the broad logic of diffusion, but the representation: atomic species, fractional coordinates, and lattice geometry replace the small 2D point clouds in this notebook.
Final exercise¶
Task. Write one sentence explaining why this 2D comparison is still useful before moving to crystals.
Answer
It isolates the model-family differences between NCSN, DDPM, and DDIM, so the crystal notebook can focus on representation and conditioning instead of basic diffusion mechanics.