
Diffusion Fundamentals: DDPM, Score-Based Models, NCSN, and DDIM

Open in Colab

This notebook introduces diffusion models from the ground up for readers who want both intuition and working implementations.

Aims

  • build physical intuition from energy landscapes, stochastic motion, and Langevin dynamics,

  • connect that intuition to the forward-noising / reverse-denoising recipe,

  • compare three concrete model families on the same toy dataset: NCSN / score-based models, DDPM, and DDIM.

Learning outcomes

By the end you should be able to:

  1. explain what the score is and why Langevin dynamics matters,

  2. describe what NCSN, DDPM, and DDIM learn and how they sample,

  3. say clearly what DDPM and DDIM share and what changes between them,

  4. recognize how the same logic carries over to crystals and materials.

If you want the applied crystal notebook next, open crystal-diffusion-from-scratch.ipynb.

Runtime expectations

  • The warm-up sections and plots run quickly on CPU.

  • The toy diffusion-model training sections are still real optimization loops, so expect roughly 10--20 minutes on CPU or faster on GPU for a full top-to-bottom first pass.

  • If you only want the conceptual comparison, you can stop after the plots and short model runs in Part 3.

Key references for this notebook

This notebook is pedagogical, but the model families and notation come from a small set of canonical papers on NCSN / score-based models, DDPM, and DDIM.

The code here is intentionally small and CPU-friendly, so it illustrates the core ideas rather than reproducing the full training setups from the papers.

from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes._axes import Axes
import torch
import torch.distributions as D
from torch.func import vmap, jacrev
from tqdm import tqdm
import seaborn as sns

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Part 1: Energy landscapes, drift, and stochastic motion

Let us make precise the central objects of study: ordinary differential equations (ODEs) and stochastic differential equations (SDEs).

In this notebook, you can interpret the time-dependent vector field

u:\mathbb{R}^d\times [0,1]\to \mathbb{R}^d,\qquad (x,t)\mapsto u_t(x)

as an effective velocity field in configuration space. In chemistry language, this is analogous to saying:

  • the system is currently at configuration $x$,

  • at time $t$,

  • and the local geometry of the landscape tells us which way the configuration wants to move.

An ODE is then

dX_t = u_t(X_t)\,dt, \qquad X_0=x_0.

This is the deterministic case: once the initial condition is fixed, the trajectory is fixed. Conceptually, this is like continuous-time relaxation or steepest-descent-style motion on an energy landscape.

An SDE is

dX_t = u_t(X_t)\,dt + \sigma_t\, dW_t,\qquad X_0=x_0,

which adds a stochastic term driven by Brownian motion $(W_t)_{0\le t\le 1}$. Here:

  • $u_t(X_t)$ is the drift,

  • $\sigma_t$ is the diffusion coefficient,

  • and the stochastic term plays the role of thermal agitation / random kicks.

So the mental model is:

ODE = deterministic relaxation
SDE = deterministic relaxation + thermal/random fluctuations

That same language will become very useful later, because diffusion models can be viewed as learning how to reverse a noising process in exactly this SDE/ODE framework.

Translation guide: math language ↔ chemistry / materials language

| Math / generative modeling | Chemistry / materials intuition |
| --- | --- |
| state $x$ | configuration, coordinate, structure descriptor |
| drift $u_t(x)$ | deterministic velocity / force-induced motion |
| diffusion coefficient $\sigma_t$ | noise strength, temperature-like stochasticity |
| Brownian motion | free diffusion |
| OU process | diffusion in a harmonic basin |
| score $\nabla \log p(x)$ | force-like field toward high-probability / low-energy regions |
| Langevin dynamics | noisy force-driven sampling |
| stationary distribution | equilibrium distribution |

You do not need to think of these as separate worlds. The same equations appear in both.

class ODE(ABC):
    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the drift coefficient of the ODE.
        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
        Returns:
            - drift_coefficient: shape (batch_size, dim)
        """
        pass

class SDE(ABC):
    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the drift coefficient of the SDE.
        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
        Returns:
            - drift_coefficient: shape (batch_size, dim)
        """
        pass

    @abstractmethod
    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the diffusion coefficient of the SDE.
        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
        Returns:
            - diffusion_coefficient: shape (batch_size, dim)
        """
        pass

Note. You can indeed view an ODE as the zero-noise limit of an SDE. We will still treat them separately here, because it keeps the simulation schemes and the physical intuition cleaner.

Part 1A: Numerical integration as repeated local updates

In practice, we almost never solve these equations analytically. Instead, we integrate them numerically, exactly as you would do for equations of motion in simulation.

If we think of XtX_t as the current configuration, then an ODE says:

update the configuration according to the local drift field.

For the ODE

dX_t = u_t(X_t)\,dt,

the explicit Euler step is

X_{t+h} = X_t + h\,u_t(X_t),

where $h=\Delta t$ is the step size.

For the SDE

dX_t = u_t(X_t)\,dt + \sigma_t\,dW_t,

the Euler--Maruyama step is

X_{t+h} = X_t + h\,u_t(X_t) + \sigma_t \sqrt{h}\,\xi_t,

where $\xi_t\sim\mathcal{N}(0,I)$.

This is the stochastic analogue of a basic explicit integrator: deterministic motion from the drift, plus a random increment whose scale grows like $\sqrt{h}$. If you come from Brownian dynamics or overdamped Langevin dynamics, this form should look very familiar.
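
As a quick sanity check on the deterministic half of this picture, here is a minimal standalone sketch (the toy drift $u_t(x) = -x$ and the step sizes are assumptions chosen for illustration; the exact solution is $x_0 e^{-t}$) showing that explicit Euler is first-order accurate:

```python
import numpy as np

def euler_integrate(x0: float, h: float, t_end: float) -> float:
    # repeated explicit Euler updates X_{t+h} = X_t + h * u_t(X_t)
    # for the toy drift u_t(x) = -x (exact solution: x0 * exp(-t))
    x = x0
    for _ in range(int(round(t_end / h))):
        x = x + h * (-x)
    return x

exact = np.exp(-2.0)
err_coarse = abs(euler_integrate(1.0, h=0.1, t_end=2.0) - exact)
err_fine = abs(euler_integrate(1.0, h=0.01, t_end=2.0) - exact)
# Shrinking h shrinks the global error roughly in proportion (first order).
print(err_coarse, err_fine)
```

Halving the step size roughly halves the global error, which is the practical meaning of "first-order" for this integrator.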

class Simulator(ABC):
    @abstractmethod
    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor):
        """
        Takes one simulation step
        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
            - dt: time, shape ()
        Returns:
            - nxt: state at time t + dt
        """
        pass

    @torch.no_grad()
    def simulate(self, x: torch.Tensor, ts: torch.Tensor):
        """
        Simulates using the discretization given by ts
        Args:
            - x: initial state at time ts[0], shape (batch_size, dim)
            - ts: timesteps, shape (num_timesteps,)
        Returns:
            - x_final: final state at time ts[-1], shape (batch_size, dim)
        """
        for t_idx in range(len(ts) - 1):
            t = ts[t_idx]
            h = ts[t_idx + 1] - ts[t_idx]
            x = self.step(x, t, h)
        return x

    @torch.no_grad()
    def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor):
        """
        Simulates using the discretization given by ts, recording the state at every step
        Args:
            - x: initial state at time ts[0], shape (batch_size, dim)
            - ts: timesteps, shape (num_timesteps,)
        Returns:
            - xs: trajectory of states over ts, shape (batch_size, num_timesteps, dim)
        """
        xs = [x.clone()]
        for t_idx in tqdm(range(len(ts) - 1)):
            t = ts[t_idx]
            h = ts[t_idx + 1] - ts[t_idx]
            x = self.step(x, t, h)
            xs.append(x.clone())
        return torch.stack(xs, dim=1)

Question 1.1: Implement explicit integrators for deterministic and stochastic dynamics

Try first: before running the next two code cells, write down the Euler and Euler-Maruyama updates in a scratch cell or on paper.

Connect each term to the physical picture:

  • deterministic displacement from the drift,

  • a Gaussian kick for the stochastic case,

  • and the $\sqrt{\Delta t}$ scaling because a Wiener increment has variance $\Delta t$.

The worked solution is already present in the code cells below so the notebook still runs cleanly.

Worked answer

For an ODE, the explicit Euler update is

x_{t+h} = x_t + f(x_t,t)h.

For an SDE

dX_t = b(X_t,t)\,dt + \sigma(X_t,t)\,dW_t,

Euler-Maruyama becomes

X_{t+h} = X_t + b(X_t,t)h + \sigma(X_t,t)\sqrt{h}\,\xi, \qquad \xi \sim \mathcal{N}(0, I).

That is exactly what the next two code cells implement.

class EulerSimulator(Simulator):
    def __init__(self, ode: ODE):
        self.ode = ode

    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):
        return xt + self.ode.drift_coefficient(xt,t) * h

class EulerMaruyamaSimulator(Simulator):
    def __init__(self, sde: SDE):
        self.sde = sde

    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):
        return xt + self.sde.drift_coefficient(xt,t) * h + self.sde.diffusion_coefficient(xt,t) * torch.sqrt(h) * torch.randn_like(xt)

Note. When the diffusion coefficient is zero, Euler--Maruyama reduces to ordinary Euler. In other words, the stochastic simulator collapses back to a purely deterministic integrator.
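
That collapse is easy to verify numerically. A minimal standalone sketch (the toy drift $b(x) = -x$ is an assumption for this check, not one of the classes above):

```python
import torch

torch.manual_seed(0)

def euler_step(x, h):
    # explicit Euler with the toy drift b(x) = -x
    return x + h * (-x)

def euler_maruyama_step(x, h, sigma):
    # same drift plus the sqrt(h)-scaled Gaussian increment
    return x + h * (-x) + sigma * torch.sqrt(h) * torch.randn_like(x)

x = torch.ones(4, 1)
h = torch.tensor(0.05)
# With sigma = 0 the noise term vanishes and the two steps agree exactly.
assert torch.allclose(euler_step(x, h), euler_maruyama_step(x, h, sigma=0.0))
```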

Part 1B: Brownian motion and Ornstein--Uhlenbeck intuition

We now build intuition for two especially important SDEs.

From a chemistry/materials viewpoint, these are good toy models for motion on simple landscapes:

  • Brownian motion = diffusion on a flat landscape with no restoring force,

  • Ornstein--Uhlenbeck (OU) = diffusion in a harmonic basin with a linear restoring force.

These two examples already contain most of the qualitative ingredients we care about later: random exploration, drift toward preferred regions, and approach to an equilibrium distribution.

Question 2.1: Brownian motion as free diffusion

Brownian motion is the case with no drift and constant diffusion:

dX_t = \sigma\,dW_t,\qquad X_0=0.

You can think of this as motion on a perfectly flat potential energy surface. There is no force pulling the particle anywhere in particular; it only wanders because of random kicks.

Try first: before running the Brownian-motion code, predict the trajectories qualitatively.

How should the paths look when sigma is very large? What about when sigma is close to zero?

Suggested answer

Large sigma gives rough, rapidly spreading trajectories because every timestep receives a stronger random kick. When sigma is close to zero, the paths barely move and stay near the initial condition.
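
The spreading can also be checked quantitatively. For Brownian motion started at zero, $\mathrm{Var}(X_t) = \sigma^2 t$; a standalone sketch (step size, horizon, and ensemble size are arbitrary choices for this check):

```python
import torch

torch.manual_seed(0)
sigma, h = 1.0, 0.01
n_steps, n_paths = 500, 20_000
x = torch.zeros(n_paths)
for _ in range(n_steps):
    x = x + sigma * (h ** 0.5) * torch.randn(n_paths)  # pure-noise EM step
t_end = h * n_steps                                    # total time: 5.0
# Free diffusion spreads linearly in time: Var(X_t) = sigma**2 * t = 5.0 here.
print(x.var().item())
```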

Try first: pause and decide what the drift and diffusion should be for Brownian motion.

In physical language there is no restoring force, so the drift is zero, while the noise strength is spatially uniform.

The worked implementation is already in the next code cell so the notebook remains runnable.

Worked answer

Brownian motion satisfies

dX_t = \sigma\,dW_t.

So the drift coefficient is

b(x,t)=0,

and the diffusion coefficient is the constant field

\sigma(x,t)=\sigma.

In tensor code that means torch.zeros_like(xt) for the drift and self.sigma * torch.ones_like(xt) for the diffusion.

class BrownianMotion(SDE):
    def __init__(self, sigma: float):
        self.sigma = sigma

    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the drift coefficient of the SDE.
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - drift: shape (bs, dim)
        """
        return torch.zeros_like(xt)

    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the diffusion coefficient of the SDE.
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - diffusion: shape (bs, dim)
        """
        return self.sigma * torch.ones_like(xt)

Now let us plot trajectories. As you inspect the figures, try to think simultaneously in two ways:

  1. as individual paths in time, and

  2. as an evolving ensemble of configurations.

That ensemble viewpoint is the one that later becomes central for generative modeling.

def plot_trajectories_1d(x0: torch.Tensor, simulator: Simulator, timesteps: torch.Tensor, ax: Optional[Axes] = None, show_hist: bool = False, decouple_hist_axis: bool = False):
        """
        Graphs the trajectories of a one-dimensional SDE with given initial values (x0) and simulation timesteps (timesteps).
        Args:
            - x0: initial states, shape (num_trajectories, 1)
            - simulator: Simulator object used to simulate
            - timesteps: timesteps to simulate along, shape (num_timesteps,)
            - ax: pyplot Axes object to plot on
            - show_hist: if True, attach a histogram of terminal states
            - decouple_hist_axis: if True, do not share y-axis between trajectories and histogram
        """
        if ax is None:
            ax = plt.gca()
        trajectories = simulator.simulate_with_trajectory(x0, timesteps) # (num_trajectories, num_timesteps, ...)

        line_color = sns.color_palette("crest", 1)[0]
        hist_color = sns.color_palette("flare", 1)[0]
        label_size = 12
        tick_size = 10

        timesteps_cpu = timesteps.detach().cpu().numpy()
        for trajectory_idx in range(trajectories.shape[0]):
            trajectory = trajectories[trajectory_idx, :, 0].detach().cpu().numpy() # (num_timesteps,)
            sns.lineplot(
                x=timesteps_cpu,
                y=trajectory,
                ax=ax,
                color=line_color,
                alpha=0.45,
                linewidth=1.1,
                legend=False,
            )

        ax.set_xlabel(r"time ($t$)", fontsize=label_size)
        ax.set_ylabel(r"$X_t$", fontsize=label_size)
        ax.tick_params(axis='both', labelsize=tick_size)
        ax.grid(alpha=0.2, linewidth=0.6)

        if show_hist:
            terminal_points = trajectories[:, -1, 0].detach().cpu().numpy()
            data_range = float(terminal_points.max() - terminal_points.min()) if terminal_points.size else 1.0
            binwidth = max(data_range / 25.0, 0.05)

            from mpl_toolkits.axes_grid1 import make_axes_locatable
            divider = make_axes_locatable(ax)
            sharey = None if decouple_hist_axis else ax
            hist_ax = divider.append_axes("right", size="22%", pad=0.45, sharey=sharey)
            sns.histplot(
                y=terminal_points,
                ax=hist_ax,
                binwidth=binwidth,
                color=hist_color,
                alpha=0.7,
                edgecolor="white",
                linewidth=0.5,
            )
            hist_ax.set_xlabel("count", fontsize=label_size)
            hist_ax.set_ylabel("")
            hist_ax.tick_params(axis='both', labelsize=tick_size)
            if decouple_hist_axis:
                hist_ax.tick_params(axis='y', left=True, labelleft=True)
            else:
                hist_ax.tick_params(axis='y', left=False, labelleft=False)
            hist_ax.grid(axis='x', alpha=0.2, linewidth=0.6)

        fig = ax.figure
        if fig is not None:
            title = ax.get_title()
            if title:
                title_size = ax.title.get_fontsize()
                ax.set_title("")

                axes = [ax]
                if show_hist:
                    axes.append(hist_ax)

                fig.canvas.draw()
                bboxes = [a.get_position() for a in axes]

                left = min(b.x0 for b in bboxes)
                right = max(b.x1 for b in bboxes)
                top = max(b.y1 for b in bboxes)

                x_center = 0.5 * (left + right)
                y = top + 0.005

                fig.text(
                    x_center,
                    y,
                    title,
                    ha="center",
                    va="bottom",
                    fontsize=title_size,
                )

sigma = 1.0
n_traj = 500
brownian_motion = BrownianMotion(sigma)
simulator = EulerMaruyamaSimulator(sde=brownian_motion)
x0 = torch.zeros(n_traj,1).to(device) # Initial values - let's start at zero
ts = torch.linspace(0.0,5.0,500).to(device) # simulation timesteps

plt.figure(figsize=(9, 6))
ax = plt.gca()
ax.set_title(r'Trajectories of Brownian Motion with $\sigma=$' + str(sigma), fontsize=18)
ax.set_xlabel(r'time ($t$)', fontsize=18)
ax.set_ylabel(r'$x_t$', fontsize=18)
plot_trajectories_1d(x0, simulator, ts, ax, show_hist=True)
plt.show()

Your job: What happens when you vary the value of sigma?

Your answer:

Answer

Increasing sigma increases the roughness of each trajectory and broadens the distribution of terminal positions. In other words, stronger noise explores configuration space more aggressively.

Question 2.2: Ornstein--Uhlenbeck as diffusion in a harmonic well

The OU process is

dX_t = -\theta X_t\,dt + \sigma\,dW_t,\qquad X_0=x_0.

This is one of the cleanest SDEs to interpret physically. The drift term $-\theta X_t$ is a linear restoring force that pulls the system back toward the origin, so you can view it as motion in a quadratic potential

U(x)\propto x^2.

Brownian motion wandered on a flat landscape; the OU process wanders in a basin.
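
One concrete consequence of the restoring force can be checked before plotting anything: the ensemble mean relaxes exponentially, $\mathbb{E}[X_t] = x_0 e^{-\theta t}$. A standalone sketch (the parameter values are arbitrary assumptions for this check):

```python
import torch

torch.manual_seed(0)
theta, sigma, h = 2.0, 0.5, 0.002
n_steps, n_paths = 500, 20_000
x = torch.full((n_paths,), 3.0)  # start displaced from the origin
for _ in range(n_steps):
    # Euler-Maruyama step for dX_t = -theta X_t dt + sigma dW_t
    x = x - theta * x * h + sigma * (h ** 0.5) * torch.randn(n_paths)
t_end = h * n_steps              # total simulated time: 1.0
# The restoring force relaxes the mean exponentially:
# E[X_t] = x_0 * exp(-theta * t) = 3 * exp(-2) ~ 0.406 here.
print(x.mean().item())
```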

Try first: predict the qualitative behavior before you run the OU-process code.

What should happen when theta is very small? What about when theta is very large?

Suggested answer

When theta is very small, the restoring force is weak, so the process behaves more like wandering Brownian motion. When theta is large, the harmonic well is stiff, so trajectories are pulled back toward the origin much more aggressively.

Try first: write down the OU drift and diffusion before you inspect the next code cell.

Interpretation:

  • theta controls how stiff the harmonic well is,

  • sigma controls how strongly thermal noise kicks the system around.

The worked implementation is already given below so execution stays intact.

Worked answer

The Ornstein-Uhlenbeck process is

dX_t = -\theta X_t\,dt + \sigma\,dW_t.

So the drift coefficient is

b(x,t) = -\theta x,

and the diffusion coefficient is again spatially uniform:

\sigma(x,t)=\sigma.

That is why the code returns -self.theta * xt for the drift and self.sigma * torch.ones_like(xt) for the diffusion.

class OUProcess(SDE):
    def __init__(self, theta: float, sigma: float):
        self.theta = theta
        self.sigma = sigma

    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the drift coefficient of the SDE.
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - drift: shape (bs, dim)
        """
        return - self.theta * xt

    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the diffusion coefficient of the SDE.
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - diffusion: shape (bs, dim)
        """
        return self.sigma * torch.ones_like(xt)
# Try comparing multiple choices side-by-side
thetas_and_sigmas = [
    (0.25, 0.0),
    (0.25, 0.5),
    (0.25, 2.0),
]
simulation_time = 10.0

num_plots = len(thetas_and_sigmas)
fig, axes = plt.subplots(2, num_plots, figsize=(10.5 * num_plots, 15))

# Top row: dynamics
n_traj = 10
for idx, (theta, sigma) in enumerate(thetas_and_sigmas):
    ou_process = OUProcess(theta, sigma)
    simulator = EulerMaruyamaSimulator(sde=ou_process)
    x0 = torch.linspace(-10.0,10.0,n_traj).view(-1,1).to(device) # Initial values spread across [-10, 10]
    ts = torch.linspace(0.0,simulation_time,1000).to(device) # simulation timesteps

    ax = axes[0,idx]
    ax.set_title(f'Trajectories of OU Process with $\\sigma = ${sigma}, $\\theta = ${theta}', fontsize=15)
    plot_trajectories_1d(x0, simulator, ts, ax, show_hist=False)

# Bottom row: distribution
n_traj = 500
for idx, (theta, sigma) in enumerate(thetas_and_sigmas):
    ou_process = OUProcess(theta, sigma)
    simulator = EulerMaruyamaSimulator(sde=ou_process)
    x0 = torch.linspace(-10.0,10.0,n_traj).view(-1,1).to(device) # Initial values spread across [-10, 10]
    ts = torch.linspace(0.0,simulation_time,1000).to(device) # simulation timesteps

    ax = axes[1,idx]
    ax.set_title(f'Trajectories of OU Process with $\\sigma = ${sigma}, $\\theta = ${theta}', fontsize=15)
    plot_trajectories_1d(x0, simulator, ts, ax, show_hist=True, decouple_hist_axis=True)
plt.show()

Your job: What do you notice about the long-time behavior? Are the trajectories converging to a single point, or to a distribution?

Give two qualitative sentences of the form:

“When ($\theta$ or $\sigma$) goes (up or down), we see ...”

Hint. Keep an eye on the ratio

D \triangleq \frac{\sigma^2}{2\theta}.

From a statistical mechanics perspective, this ratio plays the role of the equilibrium variance of the stationary Gaussian.
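
That claim is easy to confirm numerically with a standalone Euler-Maruyama loop (the parameters are arbitrary assumptions, chosen so that $D = \sigma^2/(2\theta) = 0.5$ and the simulated time covers many relaxation times $1/\theta$):

```python
import torch

torch.manual_seed(0)
theta, sigma, h = 1.0, 1.0, 0.005
n_steps, n_paths = 4000, 20_000
x = torch.zeros(n_paths)
for _ in range(n_steps):
    # Euler-Maruyama step for the OU process
    x = x - theta * x * h + sigma * (h ** 0.5) * torch.randn(n_paths)
# After ~20 relaxation times the ensemble is effectively stationary;
# its variance should be close to D = sigma**2 / (2 * theta) = 0.5.
print(x.var().item())
```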

Your answer:

Answer

When $\theta$ goes up (resp. down), the restoring force becomes stronger (resp. weaker), so the stationary distribution narrows (resp. broadens). When $\sigma$ goes up (resp. down), the thermal kicks become stronger (resp. weaker), so the stationary distribution broadens (resp. narrows).

# Let's compare various OU processes!
sigmas = [1.0, 2.0, 10.0]
ds = [0.25, 1.0, 4.0] # D = sigma**2 / (2 * theta)
simulation_time = 15.0
n_traj = 500

fig, axes = plt.subplots(len(ds), len(sigmas), figsize=(8 * len(sigmas), 8 * len(ds)))
axes = axes.reshape((len(ds), len(sigmas)))
for d_idx, d in enumerate(ds):
    for s_idx, sigma in enumerate(sigmas):
        theta = sigma**2 / 2 / d
        ou_process = OUProcess(theta, sigma)
        simulator = EulerMaruyamaSimulator(sde=ou_process)
        x0 = torch.linspace(-20.0,20.0,n_traj).view(-1,1).to(device)
        time_scale = sigma**2
        ts = torch.linspace(0.0,simulation_time / time_scale,1000).to(device) # simulation timesteps
        ax = axes[d_idx, s_idx]
        ax.set_title(f'OU Trajectories with Sigma={sigma}, Theta={theta}, D={d}')
        plot_trajectories_1d(x0=x0, simulator=simulator, timesteps=ts, ax=ax, show_hist=True, decouple_hist_axis=True)
        ax.set_xlabel(r'$t$')
        ax.set_ylabel(r'$X_t$')
plt.show()

Your job: What conclusion can we draw from the figure above? One qualitative sentence is fine. We will revisit this in Section 3.2.

Your answer:

Answer
  1. For fixed $D=\frac{\sigma^2}{2\theta}$, changing $\sigma$ mainly changes the mixing speed toward equilibrium.

  2. For fixed $\sigma$, increasing $D$ broadens the equilibrium distribution, i.e. the basin is explored more widely at stationarity.

Temperature intuition: particles in a double-well landscape

Before we jump to learned diffusion models, it helps to watch particles move on a simple energy surface.

We use an overdamped Langevin update in a double-well potential

U(x, y) = \tfrac{1}{4}(x^2 - 1.5)^2 + 0.7 y^2,

so there are two preferred basins, one on the left and one on the right.

The update is

X_{k+1} = X_k - \nabla U(X_k)\,\Delta t + \sqrt{2T\,\Delta t}\,\xi_k,

where $T$ plays the role of a temperature-like noise strength.

What to notice

  • Low temperature means the particle mostly rattles inside one basin.

  • Higher temperature means barrier-crossing becomes more common.

  • This is the same drift-plus-noise picture we will later reuse in reverse-time diffusion sampling.

import math

def double_well_potential(xy: torch.Tensor) -> torch.Tensor:
    x = xy[..., 0]
    y = xy[..., 1]
    return 0.25 * (x ** 2 - 1.5) ** 2 + 0.7 * y ** 2


def double_well_grad(xy: torch.Tensor) -> torch.Tensor:
    x = xy[..., 0]
    y = xy[..., 1]
    dUx = x * (x ** 2 - 1.5)
    dUy = 1.4 * y
    return torch.stack([dUx, dUy], dim=-1)


@torch.no_grad()
def simulate_double_well(
    temperature: float = 0.10,
    n_particles: int = 24,
    n_steps: int = 180,
    dt: float = 0.03,
    seed: int = 0,
):
    torch.manual_seed(seed)
    start = torch.randn(n_particles, 2, device=device) * 0.18
    start[:, 0] -= 1.1
    x = start.clone()
    traj = [x.detach().cpu()]
    for _ in range(n_steps):
        drift = -double_well_grad(x)
        noise = math.sqrt(2.0 * temperature * dt) * torch.randn_like(x)
        x = x + dt * drift + noise
        traj.append(x.detach().cpu())
    return torch.stack(traj, dim=0)


def plot_double_well_demo(
    temperature: float = 0.10,
    n_particles: int = 24,
    n_steps: int = 180,
    dt: float = 0.03,
    seed: int = 0,
):
    traj = simulate_double_well(
        temperature=temperature,
        n_particles=n_particles,
        n_steps=n_steps,
        dt=dt,
        seed=seed,
    )
    grid_x = torch.linspace(-2.2, 2.2, 180)
    grid_y = torch.linspace(-1.8, 1.8, 180)
    xx, yy = torch.meshgrid(grid_x, grid_y, indexing='xy')
    grid = torch.stack([xx, yy], dim=-1)
    energy = double_well_potential(grid).numpy()

    fig, axes = plt.subplots(1, 2, figsize=(12, 4.8))

    axes[0].contourf(grid_x.numpy(), grid_y.numpy(), energy, levels=30, cmap='cividis')
    for idx in range(min(n_particles, traj.shape[1])):
        path = traj[:, idx, :]
        axes[0].plot(path[:, 0], path[:, 1], alpha=0.45, linewidth=1.0)
    axes[0].scatter(traj[0, :, 0], traj[0, :, 1], s=18, color='white', edgecolor='black', linewidth=0.4, label='start')
    axes[0].scatter(traj[-1, :, 0], traj[-1, :, 1], s=22, color='tab:red', alpha=0.8, label='end')
    axes[0].set_title(f'Double-well trajectories at T={temperature:.2f}')
    axes[0].set_xlabel('$x$')
    axes[0].set_ylabel('$y$')
    axes[0].legend(loc='upper right')

    time = np.arange(traj.shape[0]) * dt
    for idx in range(min(8, traj.shape[1])):
        axes[1].plot(time, traj[:, idx, 0], alpha=0.75, linewidth=1.1)
    axes[1].axhline(0.0, color='black', linestyle='--', linewidth=0.8, alpha=0.5)
    axes[1].set_title('The $x$ coordinate shows barrier-hopping directly')
    axes[1].set_xlabel('time')
    axes[1].set_ylabel('$x_t$')
    axes[1].grid(alpha=0.2)
    plt.tight_layout()
    plt.show()


def plot_temperature_sweep(temperatures=(0.03, 0.12, 0.30), n_particles: int = 24, n_steps: int = 180, dt: float = 0.03):
    grid_x = torch.linspace(-2.2, 2.2, 180)
    grid_y = torch.linspace(-1.8, 1.8, 180)
    xx, yy = torch.meshgrid(grid_x, grid_y, indexing='xy')
    grid = torch.stack([xx, yy], dim=-1)
    energy = double_well_potential(grid).numpy()

    fig, axes = plt.subplots(1, len(temperatures), figsize=(4.6 * len(temperatures), 4.2))
    if len(temperatures) == 1:
        axes = [axes]
    for ax, temp in zip(axes, temperatures):
        traj = simulate_double_well(temp, n_particles=n_particles, n_steps=n_steps, dt=dt, seed=0)
        ax.contourf(grid_x.numpy(), grid_y.numpy(), energy, levels=30, cmap='cividis')
        for idx in range(min(n_particles, traj.shape[1])):
            path = traj[:, idx, :]
            ax.plot(path[:, 0], path[:, 1], alpha=0.35, linewidth=0.9)
        ax.scatter(traj[-1, :, 0], traj[-1, :, 1], s=18, color='tab:red', alpha=0.75)
        ax.set_title(f'T = {temp:.2f}')
        ax.set_xticks([])
        ax.set_yticks([])
    plt.suptitle('As temperature rises, trajectories explore both basins more aggressively', y=1.03)
    plt.tight_layout()
    plt.show()
# @title Explore the double-well demo
energy_demo_temperature = 0.10  # @param {type:"slider", min:0.02, max:0.40, step:0.02}
energy_demo_particles = 24  # @param {type:"slider", min:8, max:48, step:4}
energy_demo_steps = 180  # @param {type:"slider", min:80, max:260, step:20}
energy_demo_dt = 0.03  # @param {type:"slider", min:0.01, max:0.08, step:0.01}
energy_demo_seed = 0  # @param {type:"integer"}
plot_double_well_demo(
    temperature=energy_demo_temperature,
    n_particles=energy_demo_particles,
    n_steps=energy_demo_steps,
    dt=energy_demo_dt,
    seed=energy_demo_seed,
)
plot_temperature_sweep(temperatures=(0.03, 0.12, 0.30), n_particles=24, n_steps=180, dt=0.03)

Exercise: temperature and barrier crossing

Task. In one or two sentences, explain why higher temperature produces more transitions between the two wells.

Answer

The deterministic drift still pulls particles downhill, but the stochastic term grows like $\sqrt{T}$, so larger thermal kicks make it easier to cross the barrier at $x=0$ and visit the other basin.
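
You can see this scaling directly by counting barrier crossings at two temperatures. A standalone 1D sketch (this uses the $x$-slice of the double well above; the particle count, step count, and the two temperatures are arbitrary choices for the comparison):

```python
import torch

torch.manual_seed(0)

def count_crossings(temperature: float, n_particles: int = 32,
                    n_steps: int = 3000, dt: float = 0.03) -> int:
    # Overdamped Langevin in the 1D double well U(x) = 0.25 * (x**2 - 1.5)**2
    # (the x-slice of the 2D potential above); a sign change of x means the
    # particle hopped over the barrier at x = 0.
    x = torch.full((n_particles,), -1.2)
    total = 0
    for _ in range(n_steps):
        grad = x * (x ** 2 - 1.5)  # dU/dx
        x_new = x - dt * grad + (2.0 * temperature * dt) ** 0.5 * torch.randn(n_particles)
        total += int((torch.sign(x_new) != torch.sign(x)).sum())
        x = x_new
    return total

cold, hot = count_crossings(0.03), count_crossings(0.30)
print(cold, hot)  # the hot ensemble hops between basins far more often
```

At the low temperature the barrier height is many multiples of $T$, so crossings are exponentially suppressed; at the high temperature they become routine.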

Why this notebook matters for diffusion models

In score-based diffusion models, we deliberately define a forward noising process that gradually destroys structure, and then learn a reverse-time dynamics that reconstructs it.

For someone in computational chemistry or materials science, a good first mental model is:

  • the forward process pushes configurations toward an easy reference ensemble,

  • the reverse process uses a learned force-like field to guide them back toward realistic structures,

  • and conditioning corresponds to biasing that dynamics toward a desired composition, property, or environment.

This notebook gives you the simulation language that those models are built from.

Part 2: Langevin dynamics and equilibrium sampling

class Density(ABC):
    """
    Distribution with tractable density
    """
    @abstractmethod
    def log_density(self, x: torch.Tensor) -> torch.Tensor:
        """
        Returns the log density at x.
        Args:
            - x: shape (batch_size, dim)
        Returns:
            - log_density: shape (batch_size, 1)
        """
        pass

    def score(self, x: torch.Tensor) -> torch.Tensor:
        """
        Returns the score, i.e. the gradient of the log density at x
        Args:
            - x: (batch_size, dim)
        Returns:
            - score: (batch_size, dim)
        """
        x = x.unsqueeze(1)  # (batch_size, 1, ...)
        score = vmap(jacrev(self.log_density))(x)  # (batch_size, 1, 1, 1, ...)
        return score.squeeze((1, 2, 3))  # (batch_size, ...)
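
The autodiff score machinery is easy to sanity-check against a case with a known answer. For a standard normal, $\log p(x) = -\tfrac{1}{2}\|x\|^2 + \text{const}$, so the score is exactly $-x$. A standalone sketch using torch.func.grad (equivalent to jacrev for a scalar-valued function):

```python
import torch
from torch.func import grad, vmap

def log_density(x: torch.Tensor) -> torch.Tensor:
    # standard-normal log-density up to an additive constant: -||x||^2 / 2
    return -0.5 * (x ** 2).sum()

score_fn = vmap(grad(log_density))  # batched gradient of the log-density
x = torch.randn(8, 2)
# For a standard normal the score is exactly -x.
assert torch.allclose(score_fn(x), -x, atol=1e-5)
```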

class Sampleable(ABC):
    """
    Distribution which can be sampled from
    """
    @abstractmethod
    def sample(self, num_samples: int) -> torch.Tensor:
        """
        Returns num_samples samples from the distribution.
        Args:
            - num_samples: the desired number of samples
        Returns:
            - samples: shape (num_samples, dim)
        """
        pass
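The autodiff route in `Density.score` can be sanity-checked against a case where the score is known in closed form. Below is a minimal sketch, assuming a standard 2D normal target (whose score is exactly `-x`); the `log_density` helper here is illustrative, not part of the classes above.

```python
import math
import torch
from torch.func import vmap, jacrev

# Sanity check for the autodiff-score pattern used in Density.score:
# for p = N(0, I) in 2D, the score is exactly -x.
def log_density(x):  # x: (2,) -> scalar log N(x; 0, I)
    return -0.5 * (x ** 2).sum() - math.log(2 * math.pi)

x = torch.randn(5, 2)
score = vmap(jacrev(log_density))(x)  # per-sample gradient, shape (5, 2)
assert torch.allclose(score, -x, atol=1e-5)
```

If this assertion passes, the same `jacrev`-under-`vmap` pattern will recover the score of any `log_density` you implement.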
# Several plotting utility functions
def hist2d_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):
    if ax is None:
        ax = plt.gca()
    samples = sampleable.sample(num_samples) # (ns, 2)
    ax.hist2d(samples[:,0].cpu(), samples[:,1].cpu(), **kwargs)

def scatter_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):
    if ax is None:
        ax = plt.gca()
    samples = sampleable.sample(num_samples) # (ns, 2)
    ax.scatter(samples[:,0].cpu(), samples[:,1].cpu(), **kwargs)

def imshow_density(density: Density, bins: int, scale: float, ax: Optional[Axes] = None, **kwargs):
    if ax is None:
        ax = plt.gca()
    x = torch.linspace(-scale, scale, bins).to(device)
    y = torch.linspace(-scale, scale, bins).to(device)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
    density = density.log_density(xy).reshape(bins, bins).T
    im = ax.imshow(density.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)

def contour_density(density: Density, bins: int, scale: float, ax: Optional[Axes] = None, **kwargs):
    if ax is None:
        ax = plt.gca()
    x = torch.linspace(-scale, scale, bins).to(device)
    y = torch.linspace(-scale, scale, bins).to(device)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
    density = density.log_density(xy).reshape(bins, bins).T
    im = ax.contour(density.cpu(), extent=[-scale, scale, -scale, scale], origin='lower', **kwargs)
import torch.distributions as D

class Gaussian(torch.nn.Module, Sampleable, Density):
    """
    Two-dimensional Gaussian. Is a Density and a Sampleable. Wrapper around torch.distributions.MultivariateNormal.
    """
    def __init__(self, mean, cov):
        """
        mean: shape (2,)
        cov: shape (2,2)
        """
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("cov", cov)

    @property
    def distribution(self):
        return D.MultivariateNormal(self.mean, self.cov, validate_args=False)

    def sample(self, num_samples) -> torch.Tensor:
        return self.distribution.sample((num_samples,))

    def log_density(self, x: torch.Tensor):
        return self.distribution.log_prob(x).view(-1, 1)

class GaussianMixture(torch.nn.Module, Sampleable, Density):
    """
    Two-dimensional Gaussian mixture model, and is a Density and a Sampleable. Wrapper around torch.distributions.MixtureSameFamily.
    """
    def __init__(
        self,
        means: torch.Tensor,  # nmodes x data_dim
        covs: torch.Tensor,  # nmodes x data_dim x data_dim
        weights: torch.Tensor,  # nmodes
    ):
        """
        means: shape (nmodes, 2)
        covs: shape (nmodes, 2, 2)
        weights: shape (nmodes, 1)
        """
        super().__init__()
        self.nmodes = means.shape[0]
        self.register_buffer("means", means)
        self.register_buffer("covs", covs)
        self.register_buffer("weights", weights)

    @property
    def dim(self) -> int:
        return self.means.shape[1]

    @property
    def distribution(self):
        return D.MixtureSameFamily(
                mixture_distribution=D.Categorical(probs=self.weights, validate_args=False),
                component_distribution=D.MultivariateNormal(
                    loc=self.means,
                    covariance_matrix=self.covs,
                    validate_args=False,
                ),
                validate_args=False,
            )

    def log_density(self, x: torch.Tensor) -> torch.Tensor:
        return self.distribution.log_prob(x).view(-1, 1)

    def sample(self, num_samples: int) -> torch.Tensor:
        return self.distribution.sample(torch.Size((num_samples,)))

    @classmethod
    def random_2D(
        cls, nmodes: int, std: float, scale: float = 10.0, seed: int = 0
    ) -> "GaussianMixture":
        torch.manual_seed(seed)
        means = (torch.rand(nmodes, 2) - 0.5) * scale
        covs = torch.diag_embed(torch.ones(nmodes, 2)) * std ** 2
        weights = torch.ones(nmodes)
        return cls(means, covs, weights)

    @classmethod
    def symmetric_2D(
        cls, nmodes: int, std: float, scale: float = 10.0,
    ) -> "GaussianMixture":
        angles = torch.linspace(0, 2 * np.pi, nmodes + 1)[:nmodes]
        means = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1) * scale
        covs = torch.diag_embed(torch.ones(nmodes, 2) * std ** 2)
        weights = torch.ones(nmodes) / nmodes
        return cls(means, covs, weights)
# Visualize densities
densities = {
    "Gaussian": Gaussian(mean=torch.zeros(2), cov=10 * torch.eye(2)).to(device),
    "Random Mixture": GaussianMixture.random_2D(nmodes=5, std=1.0, scale=20.0, seed=3.0).to(device),
    "Symmetric Mixture": GaussianMixture.symmetric_2D(nmodes=5, std=1.0, scale=8.0).to(device),
}

fig, axes = plt.subplots(1,3, figsize=(18, 6))
bins = 100
scale = 15
for idx, (name, density) in enumerate(densities.items()):
    ax = axes[idx]
    ax.set_title(name)
    imshow_density(density, bins, scale, ax, vmin=-15, cmap=plt.get_cmap('Blues'))
    contour_density(density, bins, scale, ax, colors='grey', linestyles='solid', alpha=0.25, levels=20)
plt.show()
<Figure size 1800x600 with 3 Axes>

Question 3.1: Implement overdamped Langevin dynamics

In this section, we simulate the overdamped Langevin dynamics

dX_t = \frac{1}{2}\sigma^2 \nabla \log p(X_t)\,dt + \sigma\,dW_t.

If p(x) \propto e^{-\beta U(x)}, then

\nabla \log p(x) = -\beta \nabla U(x) = \beta F(x),

so the drift term is proportional to a force. This is why Langevin dynamics is such a natural meeting point between statistical physics and modern generative modeling.

Try first: before running the class below, identify the drift and diffusion coefficients from the Langevin equation.

The worked implementation is already present in the next code cell so the notebook still runs properly.

Worked answer

Comparing

dX_t = b(X_t,t)\,dt + a(X_t,t)\,dW_t
with the Langevin form gives
b(x,t) = \frac{1}{2}\sigma^2 \nabla \log p(x),
\qquad
a(x,t) = \sigma.
So the drift uses the score of the density and the diffusion is a constant isotropic noise level.
class LangevinSDE(SDE):
    def __init__(self, sigma: float, density: Density):
        self.sigma = sigma
        self.density = density

    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the drift coefficient of the SDE.
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - drift: shape (bs, dim)
        """
        return 0.5 * self.sigma ** 2 * self.density.score(xt)

    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Returns the diffusion coefficient of the SDE.
        Args:
            - xt: state at time t, shape (bs, dim)
            - t: time, shape ()
        Returns:
            - diffusion: shape (bs, dim)
        """
        return self.sigma * torch.ones_like(xt)
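The drift/diffusion split above maps directly onto one Euler--Maruyama update. Here is a minimal self-contained sketch, using the analytic score of a standard normal (`score(x) = -x`) in place of a `Density` object; `euler_maruyama_step` is an illustrative helper, not the notebook's `EulerMaruyamaSimulator`.

```python
import math
import torch

# One Euler--Maruyama update for the Langevin SDE above, with the analytic
# score of N(0, I) standing in for a learned/tractable Density.
def euler_maruyama_step(x, score_fn, sigma, h):
    drift = 0.5 * sigma ** 2 * score_fn(x)              # b(x_t) dt term
    noise = sigma * math.sqrt(h) * torch.randn_like(x)  # a(x_t) dW term
    return x + h * drift + noise

torch.manual_seed(0)
x = 5.0 * torch.ones(4, 2)  # start the ensemble far from the mode at the origin
for _ in range(2000):
    x = euler_maruyama_step(x, lambda y: -y, sigma=1.0, h=0.01)
# After many steps the ensemble fluctuates near the origin (stationary N(0, I)).
```

Total simulated time here is 20, about ten relaxation times of this harmonic basin, so the initial offset has decayed away.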

Now let us graph how a whole ensemble evolves under these dynamics.

As you look at the plots, keep the following question in mind:

If the score points toward high-density regions, how does the combination of drift + noise reshape an initial cloud of samples into the target distribution?

import seaborn as sns  # used below for the kernel-density overlays

# First, let's define two utility functions...
def every_nth_index(num_timesteps: int, n: int) -> torch.Tensor:
    """
    Compute the indices to record in the trajectory given a record_every parameter
    """
    if n == 1:
        return torch.arange(num_timesteps)
    return torch.cat(
        [
            torch.arange(0, num_timesteps - 1, n),
            torch.tensor([num_timesteps - 1]),
        ]
    )

def graph_dynamics(
    num_samples: int,
    source_distribution: Sampleable,
    simulator: Simulator,
    density: Density,
    timesteps: torch.Tensor,
    plot_every: int,
    bins: int,
    scale: float
):
    """
    Plot the evolution of samples from source under the simulation scheme given by simulator (itself a discretization of an ODE or SDE).
    Args:
        - num_samples: the number of samples to simulate
        - source_distribution: distribution from which we draw initial samples at t=0
        - simulator: the discretized simulation scheme used to simulate the dynamics
        - density: the target density
        - timesteps: the timesteps used by the simulator
        - plot_every: number of timesteps between consecutive plots
        - bins: number of bins for imshow
        - scale: scale for imshow
    """
    # Simulate
    x0 = source_distribution.sample(num_samples)
    xts = simulator.simulate_with_trajectory(x0, timesteps)
    indices_to_plot = every_nth_index(len(timesteps), plot_every)
    plot_timesteps = timesteps[indices_to_plot]
    plot_xts = xts[:,indices_to_plot]

    # Graph
    fig, axes = plt.subplots(2, len(plot_timesteps), figsize=(8*len(plot_timesteps), 16))
    axes = axes.reshape((2,len(plot_timesteps)))
    for t_idx in range(len(plot_timesteps)):
        t = plot_timesteps[t_idx].item()
        xt = plot_xts[:,t_idx]
        # Scatter axes
        scatter_ax = axes[0, t_idx]
        imshow_density(density, bins, scale, scatter_ax, vmin=-15, alpha=0.25, cmap=plt.get_cmap('Blues'))
        scatter_ax.scatter(xt[:,0].cpu(), xt[:,1].cpu(), marker='x', color='black', alpha=0.75, s=15)
        scatter_ax.set_title(f'Samples at t={t:.1f}', fontsize=15)
        scatter_ax.set_xticks([])
        scatter_ax.set_yticks([])

        # Kdeplot axes
        kdeplot_ax = axes[1, t_idx]
        imshow_density(density, bins, scale, kdeplot_ax, vmin=-15, alpha=0.5, cmap=plt.get_cmap('Blues'))
        sns.kdeplot(x=xt[:,0].cpu(), y=xt[:,1].cpu(), alpha=0.5, ax=kdeplot_ax,color='grey')
        kdeplot_ax.set_title(f'Density of Samples at t={t:.1f}', fontsize=15)
        kdeplot_ax.set_xticks([])
        kdeplot_ax.set_yticks([])
        kdeplot_ax.set_xlabel("")
        kdeplot_ax.set_ylabel("")

    plt.show()
# Construct the simulator
target = GaussianMixture.random_2D(nmodes=5, std=0.75, scale=15.0, seed=3).to(device)
sde = LangevinSDE(sigma = 0.6, density = target)
simulator = EulerMaruyamaSimulator(sde)

# Graph the results!
graph_dynamics(
    num_samples = 1000,
    source_distribution = Gaussian(mean=torch.zeros(2), cov=20 * torch.eye(2)).to(device),
    simulator=simulator,
    density=target,
    timesteps=torch.linspace(0,5.0,1000).to(device),
    plot_every=334,
    bins=200,
    scale=15
)
100%|██████████| 999/999 [00:04<00:00, 237.20it/s]
<Figure size 3200x1600 with 8 Axes>

Your job: Try varying the value of sigma, the number and range of the simulation steps, the source distribution, and the target density. What do you notice? Why?

Your answer:

Answer

The sample cloud tends to relax toward the distribution used to define the score. In physics language, Langevin dynamics mixes toward its stationary distribution; in generative-model language, the learned score field steers noisy samples back toward the data manifold / high-probability region.

Optional. The next two cells make an animation. They require a working ffmpeg installation. If your environment does not already have it, you may need to install it separately (for example through conda-forge).

try:
    from celluloid import Camera
except ImportError:
    Camera = None
    print('Optional animation skipped: `celluloid` is not installed in this environment.')

from IPython.display import HTML
from matplotlib import animation as mpl_animation


def animate_dynamics(
    num_samples: int,
    source_distribution: Sampleable,
    simulator: Simulator,
    density: Density,
    timesteps: torch.Tensor,
    animate_every: int,
    bins: int,
    scale: float,
    save_path: str = 'dynamics_animation.gif'
):
    """Plot the evolution of samples from source under the simulation scheme given by simulator."""
    if Camera is None:
        return HTML('<b>Optional animation skipped:</b> install <code>celluloid</code> to render this cell.')

    x0 = source_distribution.sample(num_samples)
    xts = simulator.simulate_with_trajectory(x0, timesteps)
    indices_to_animate = every_nth_index(len(timesteps), animate_every)
    animate_timesteps = timesteps[indices_to_animate]
    animate_xts = xts[:, indices_to_animate]

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    camera = Camera(fig)
    for t_idx in range(len(animate_timesteps)):
        xt = animate_xts[:, t_idx]
        scatter_ax = axes[0]
        imshow_density(density, bins, scale, scatter_ax, vmin=-15, alpha=0.25, cmap=plt.get_cmap('Blues'))
        scatter_ax.scatter(xt[:, 0].cpu(), xt[:, 1].cpu(), marker='x', color='black', alpha=0.75, s=15)
        scatter_ax.set_title('Samples')
        scatter_ax.set_xticks([])
        scatter_ax.set_yticks([])

        kdeplot_ax = axes[1]
        imshow_density(density, bins, scale, kdeplot_ax, vmin=-15, alpha=0.5, cmap=plt.get_cmap('Blues'))
        sns.kdeplot(x=xt[:, 0].cpu(), y=xt[:, 1].cpu(), alpha=0.5, ax=kdeplot_ax, color='grey')
        kdeplot_ax.set_title('Density of Samples', fontsize=15)
        kdeplot_ax.set_xticks([])
        kdeplot_ax.set_yticks([])
        kdeplot_ax.set_xlabel('')
        kdeplot_ax.set_ylabel('')
        camera.snap()

    animation = camera.animate()
    if mpl_animation.writers.is_available('ffmpeg'):
        animation.save(save_path)
        result = HTML(animation.to_html5_video())
    else:
        result = HTML(animation.to_jshtml())
    plt.close()
    return result
# OPTIONAL CELL
# Construct the simulator
target = GaussianMixture.random_2D(nmodes=5, std=0.75, scale=15.0, seed=3).to(device)
sde = LangevinSDE(sigma=0.6, density=target)
simulator = EulerMaruyamaSimulator(sde)

animate_dynamics(
    num_samples=1000,
    source_distribution=Gaussian(mean=torch.zeros(2), cov=20 * torch.eye(2)).to(device),
    simulator=simulator,
    density=target,
    timesteps=torch.linspace(0, 5.0, 1000).to(device),
    bins=200,
    scale=15,
    animate_every=100,
)
100%|██████████| 999/999 [00:05<00:00, 196.65it/s]

Question 3.2: Ornstein--Uhlenbeck as a special case of Langevin dynamics

We now make the connection completely explicit.

Recall:

  • Langevin dynamics:

    dX_t = \frac{1}{2}\sigma^2\nabla \log p(X_t)\,dt + \sigma\,dW_t,\qquad X_0=x_0,
  • Ornstein--Uhlenbeck:

    dX_t = -\theta X_t\,dt + \sigma\,dW_t,\qquad X_0=x_0.

This exercise shows that the OU process is exactly Langevin dynamics for a Gaussian target distribution. That is the cleanest possible example of the more general principle:

score = force-like field that defines how samples flow toward equilibrium.

Your job: Show that when

p(x)=\mathcal{N}\!\left(0,\frac{\sigma^2}{2\theta}\right),

the score is

\nabla \log p(x) = -\frac{2\theta}{\sigma^2}x.

Hint. The Gaussian density is

p(x)=\frac{\sqrt{\theta}}{\sigma\sqrt{\pi}}\exp\!\left(-\frac{x^2\theta}{\sigma^2}\right).

Your answer:

Answer

From the hint,

\log p(x)= -\frac{\theta}{\sigma^2}x^2 + C.
Therefore
\nabla \log p(x)=\frac{d}{dx}\log p(x)= -\frac{2\theta}{\sigma^2}x.

This is exactly the linear score field associated with a harmonic basin.
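The derivative above can also be checked numerically with autograd; the values of θ and σ below are arbitrary illustrative choices.

```python
import math
import torch

# Numerical check of the score derived above for p = N(0, sigma^2 / (2 theta)):
# autograd of log p should give -2 theta / sigma^2 * x.
theta, sigma = 0.7, 1.3  # arbitrary illustrative values
var = sigma ** 2 / (2 * theta)

x = torch.linspace(-2.0, 2.0, 9, requires_grad=True)
log_p = -0.5 * x ** 2 / var - 0.5 * math.log(2 * math.pi * var)
(score,) = torch.autograd.grad(log_p.sum(), x)
assert torch.allclose(score, -2 * theta / sigma ** 2 * x.detach(), atol=1e-6)
```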

Your job: Conclude that when

p(x)=\mathcal{N}\!\left(0,\frac{\sigma^2}{2\theta}\right),

the Langevin dynamics

dX_t = \frac{1}{2}\sigma^2\nabla \log p(X_t)\,dt + \sigma\,dW_t

is equivalent to the OU process

dX_t = -\theta X_t\,dt + \sigma\,dW_t,\qquad X_0=0.

Your answer:

Answer

Substitute the score from the previous part into the Langevin drift:

\frac{1}{2}\sigma^2\left(-\frac{2\theta}{\sigma^2}x\right) = -\theta x.
So the Langevin SDE becomes exactly the OU process.

Takeaway for diffusion models

In chemistry, one often starts from an energy and obtains forces by differentiation.

In score-based generative modeling, we reverse the perspective:

  • we learn a force-like field (the score),

  • use it inside an SDE or ODE,

  • and thereby steer noise into structured samples.

That is the conceptual bridge from Brownian motion and Langevin dynamics to diffusion models for molecules, crystals, and materials configurations.

Transition: from Langevin intuition to diffusion models

Everything above was about dynamics on known vector fields or known densities. Diffusion models flip the problem around:

  • we define a simple forward corruption process,

  • we learn the reverse denoising direction from data,

  • and we sample by integrating that learned reverse-time dynamics.

The next sections compare three closely related views of that idea.

Why the names get confusing

These families overlap, so it is easy for them to blur together:

  • NCSN / score-based models learn the score field \nabla_x \log p_t(x) directly and sample with annealed Langevin updates.

  • DDPM learns a denoiser or noise predictor on a discrete forward diffusion chain and samples stochastically.

  • DDIM usually uses the same trained DDPM network, but swaps in a deterministic reverse sampler.

So the key differences are not just in the network, but in the forward process, the training target, and the sampler.

Part 3: A single toy dataset for all three model families

To keep the geometry visible, we will use one small 2D dataset throughout: a five-mode Gaussian mixture arranged like a cross.

That gives us a nice teaching setup:

  • it is simple enough to plot directly,

  • it trains quickly on CPU or Colab,

  • and bad samplers fail in a very visible way.

What to notice

  • If a model trains well, samples should cluster around the five modes.

  • If the reverse process is poor, samples either blur into the middle or explode away from the modes.

  • Because the dataset is shared, differences in the plots mostly come from the model family rather than from the data.

import math


toy_centers = torch.tensor(
    [
        [-2.5, 0.0],
        [2.5, 0.0],
        [0.0, 2.5],
        [0.0, -2.5],
        [0.0, 0.0],
    ],
    dtype=torch.float32,
    device=device,
)

NUM_DDPM_STEPS = 50
DDPM_BETAS = torch.linspace(1e-4, 0.035, NUM_DDPM_STEPS, device=device)
DDPM_ALPHAS = 1.0 - DDPM_BETAS
DDPM_ALPHA_BAR = torch.cumprod(DDPM_ALPHAS, dim=0)
NCSN_SIGMAS = torch.tensor([1.0, 0.6, 0.35, 0.2, 0.12, 0.07], dtype=torch.float32, device=device)


def sample_cross_data(num_samples: int, std: float = 0.25, seed: Optional[int] = None):
    if seed is not None:
        torch.manual_seed(seed)
    idx = torch.randint(0, len(toy_centers), (num_samples,), device=device)
    return toy_centers[idx] + std * torch.randn(num_samples, 2, device=device)


def ddpm_q_sample(x0: torch.Tensor, t_idx: torch.Tensor, noise: Optional[torch.Tensor] = None):
    if noise is None:
        noise = torch.randn_like(x0)
    alpha_bar_t = DDPM_ALPHA_BAR[t_idx].unsqueeze(1)
    xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1.0 - alpha_bar_t) * noise
    return xt, noise


class SinusoidalTimeEmbedding(torch.nn.Module):
    def __init__(self, dim: int = 32):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half_dim = self.dim // 2
        freqs = torch.exp(torch.linspace(math.log(1.0), math.log(40.0), half_dim, device=t.device))
        angles = t * freqs.view(1, -1)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class TinyDiffusionMLP(torch.nn.Module):
    def __init__(self, hidden: int = 96, time_dim: int = 32):
        super().__init__()
        self.time_embed = SinusoidalTimeEmbedding(time_dim)
        self.net = torch.nn.Sequential(
            torch.nn.Linear(2 + time_dim, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, 2),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([x, self.time_embed(t)], dim=-1))


@torch.no_grad()
def plot_point_cloud(ax, points: torch.Tensor, title: str, centers: Optional[torch.Tensor] = toy_centers):
    """Scatter a point cloud; pass centers=None to suppress the mode markers."""
    pts = points.detach().cpu()
    ax.scatter(pts[:, 0], pts[:, 1], s=9, alpha=0.45)
    if centers is not None:
        c = centers.detach().cpu()
        ax.scatter(c[:, 0], c[:, 1], s=50, marker='x', color='black', linewidth=1.2)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')


@torch.no_grad()
def plot_score_quiver(ax, model: torch.nn.Module, noise_value: float, title: str, grid_scale: float = 4.5):
    grid = torch.linspace(-grid_scale, grid_scale, 17, device=device)
    xx, yy = torch.meshgrid(grid, grid, indexing='xy')
    pts = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=-1)
    t = torch.full((pts.shape[0], 1), noise_value, device=device)
    scores = model(pts, t).detach().cpu()
    ax.quiver(
        pts[:, 0].cpu().numpy(),
        pts[:, 1].cpu().numpy(),
        scores[:, 0].numpy(),
        scores[:, 1].numpy(),
        angles='xy',
        scale_units='xy',
        scale=10,
        width=0.003,
        alpha=0.75,
        color='tab:blue',
    )
    background = sample_cross_data(700, seed=0).detach().cpu()
    ax.scatter(background[:, 0], background[:, 1], s=5, alpha=0.08, color='black')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')


@torch.no_grad()
def plot_reverse_paths(ax, history, title: str, max_paths: int = 10):
    for idx in range(min(max_paths, history[0].shape[0])):
        path = torch.stack([frame[idx] for frame in history], dim=0).cpu()
        ax.plot(path[:, 0], path[:, 1], linewidth=1.0, alpha=0.7)
        ax.scatter(path[0, 0], path[0, 1], color='black', s=14)
        ax.scatter(path[-1, 0], path[-1, 1], color='tab:red', s=16)
    ax.scatter(toy_centers[:, 0].cpu(), toy_centers[:, 1].cpu(), marker='x', color='black', s=45)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

Forward corruption preview

Before training anything, let us look at what the DDPM-style forward process actually does.

For a discrete timestep t, the closed-form corruption is

x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I).

What to notice

  • Early timesteps keep most of the original structure.

  • Late timesteps look much closer to an isotropic Gaussian cloud.

  • This is why the reverse sampler can start from simple noise: the forward process was designed to end there.
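How close is the final step to pure noise? For the short toy schedule used here you can check the surviving signal fraction directly. The sketch below recomputes the same 50-step linear β schedule rather than reusing the notebook constants, so it is self-contained.

```python
import torch

# Recompute the toy schedule (50 steps, betas from 1e-4 to 0.035) and check
# how much of x0 survives the full chain: x_T = sqrt(ab_T) x0 + sqrt(1-ab_T) eps.
betas = torch.linspace(1e-4, 0.035, 50)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)
signal_fraction = alpha_bar[-1].sqrt()
print(float(signal_fraction))  # well below 1, but this short toy chain is not pure noise
```

Longer schedules with more steps drive this fraction toward zero, which is what makes starting the reverse sampler from an isotropic Gaussian exact rather than approximate.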

torch.manual_seed(0)
forward_data = sample_cross_data(900)
preview_steps = [0, 10, 25, 49]

fig, axes = plt.subplots(1, len(preview_steps) + 1, figsize=(15, 3.6))
plot_point_cloud(axes[0], forward_data, 'clean data')
for ax, step_idx in zip(axes[1:], preview_steps):
    t_idx = torch.full((forward_data.shape[0],), step_idx, dtype=torch.long, device=device)
    xt, _ = ddpm_q_sample(forward_data, t_idx)
    plot_point_cloud(ax, xt, f't = {step_idx}')
plt.suptitle('The same data distribution becomes progressively easier as noise increases', y=1.05)
plt.tight_layout()
plt.show()
<Figure size 1500x360 with 5 Axes>

Part 3A: NCSN and score-based models

A noise-conditional score network learns the score of progressively noisier versions of the data:

s_\theta(x, \sigma) \approx \nabla_x \log p_\sigma(x).

In practice we sample a clean point x_0, add Gaussian noise,

x^{(\sigma)} = x_0 + \sigma \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),

and train the network to predict the denoising direction

\nabla_x \log p(x^{(\sigma)} \mid x_0) = \frac{x_0 - x^{(\sigma)}}{\sigma^2} = -\frac{\epsilon}{\sigma}.
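The last equality is pure algebra, and easy to verify numerically; the values below are arbitrary.

```python
import torch

# Verify the training-target identity: for x_sigma = x0 + sigma * eps,
# the conditional score (x0 - x_sigma) / sigma**2 collapses to -eps / sigma.
torch.manual_seed(0)
x0 = torch.randn(8, 2)
sigma = 0.35
eps = torch.randn_like(x0)
x_sigma = x0 + sigma * eps
assert torch.allclose((x0 - x_sigma) / sigma ** 2, -eps / sigma, atol=1e-6)
```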

How it works

  • The network is told the noise level explicitly.

  • At large \sigma, the score field is broad and global.

  • At small \sigma, the score field becomes local and sharp.

  • Sampling uses annealed Langevin dynamics: alternate score steps and fresh noise injections while gradually lowering \sigma.

How it differs from DDPM

  • NCSN is usually described as learning the score field directly.

  • DDPM is usually described as learning a noise predictor or denoiser on a discrete forward chain.

  • The two views are mathematically close, but the training objective and the sampler are presented differently.

def train_ncsn_model(num_steps: int = 1000, batch_size: int = 384, lr: float = 1.2e-3):
    model = TinyDiffusionMLP().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for _ in range(num_steps):
        x0 = sample_cross_data(batch_size)
        idx = torch.randint(0, len(NCSN_SIGMAS), (batch_size,), device=device)
        sigma = NCSN_SIGMAS[idx].unsqueeze(1)
        noise = torch.randn_like(x0)
        xt = x0 + sigma * noise
        target = -noise / sigma
        t = idx.float().unsqueeze(1) / (len(NCSN_SIGMAS) - 1)
        pred = model(xt, t)
        loss = ((pred - target) ** 2 * sigma ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return model, losses


@torch.no_grad()
def sample_ncsn(model: torch.nn.Module, num_samples: int = 1024, base_step: float = 1.6e-4, steps_per_level: int = 60):
    x = NCSN_SIGMAS[0] * torch.randn(num_samples, 2, device=device)
    for idx, sigma in enumerate(NCSN_SIGMAS):
        t = torch.full((num_samples, 1), idx / (len(NCSN_SIGMAS) - 1), device=device)
        step_size = base_step * float((sigma / NCSN_SIGMAS[-1]) ** 2)
        for _ in range(steps_per_level):
            score = model(x, t)
            x = x + step_size * score + math.sqrt(2.0 * step_size) * torch.randn_like(x)
    return x


torch.manual_seed(7)
ncsn_model, ncsn_losses = train_ncsn_model()
ncsn_samples = sample_ncsn(ncsn_model)
fig, axes = plt.subplots(2, 2, figsize=(11, 8))
axes[0, 0].plot(ncsn_losses)
axes[0, 0].set_title('NCSN training loss')
axes[0, 0].set_xlabel('training step')
axes[0, 0].set_ylabel('weighted DSM loss')
axes[0, 0].grid(alpha=0.25)

plot_score_quiver(axes[0, 1], ncsn_model, noise_value=0.0, title='score field at the lowest noise level')
plot_score_quiver(axes[1, 0], ncsn_model, noise_value=1.0, title='score field at the highest noise level')
plot_point_cloud(axes[1, 1], ncsn_samples, 'NCSN samples after annealed Langevin')
plt.tight_layout()
plt.show()
<Figure size 1100x800 with 4 Axes>

Exercise: why condition the score network on noise level?

Task. In one sentence, explain why the NCSN input is (x, \sigma) rather than only x.

Answer

Because the optimal score field changes with the noise level: low-noise data needs a sharp local denoising field, while high-noise data needs a broader global direction field.

Part 3B: DDPM

DDPM starts from a discrete forward Markov chain

q(x_t \mid x_{t-1}) = \mathcal{N}(\sqrt{\alpha_t}\,x_{t-1},\, \beta_t I), \qquad \alpha_t = 1 - \beta_t,

which implies the useful closed-form corruption rule

x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon.

Instead of predicting the score directly, the standard DDPM network is trained to predict the injected noise:

\epsilon_\theta(x_t, t) \approx \epsilon,

with the mean-squared error objective

\mathcal{L}_{\mathrm{DDPM}} = \mathbb{E}\bigl[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\bigr].

How it works

  • The forward process is a discrete chain with many small steps.

  • The network predicts the noise at a randomly chosen step.

  • Sampling walks backward through the chain and reintroduces a stochastic term at each reverse step.

How it differs from NCSN

  • DDPM is tied to a specific discrete noising schedule.

  • The network is usually phrased as a noise predictor rather than a direct score predictor.

  • The reverse process is a learned stochastic reverse diffusion chain rather than annealed Langevin dynamics.
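Although the parameterizations look different, they are linked by a simple rescaling: under the closed-form corruption, the conditional score of q(x_t | x_0) equals -\epsilon / \sqrt{1 - \bar{\alpha}_t}, so a noise predictor is just a scaled score model. A quick numerical check (the \bar{\alpha} value is arbitrary):

```python
import torch

# For x_t = sqrt(ab) x0 + sqrt(1 - ab) eps, the conditional score of
# q(x_t | x0) = N(sqrt(ab) x0, (1 - ab) I) is -(x_t - sqrt(ab) x0) / (1 - ab),
# which equals -eps / sqrt(1 - ab): a noise predictor is a rescaled score model.
torch.manual_seed(0)
alpha_bar = torch.tensor(0.6)  # arbitrary illustrative value
x0 = torch.randn(16, 2)
eps = torch.randn_like(x0)
xt = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps

score = -(xt - alpha_bar.sqrt() * x0) / (1 - alpha_bar)
assert torch.allclose(score, -eps / (1 - alpha_bar).sqrt(), atol=1e-6)
```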

def train_ddpm_model(num_steps: int = 900, batch_size: int = 384, lr: float = 1.5e-3):
    model = TinyDiffusionMLP().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for _ in range(num_steps):
        x0 = sample_cross_data(batch_size)
        t_idx = torch.randint(0, NUM_DDPM_STEPS, (batch_size,), device=device)
        xt, noise = ddpm_q_sample(x0, t_idx)
        t = t_idx.float().unsqueeze(1) / (NUM_DDPM_STEPS - 1)
        pred = model(xt, t)
        loss = torch.mean((pred - noise) ** 2)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return model, losses


@torch.no_grad()
def sample_ddpm_or_ddim(
    model: torch.nn.Module,
    num_samples: int = 1024,
    eta: float = 1.0,
    x_init: Optional[torch.Tensor] = None,
    return_history: bool = False,
    record_every: int = 5,
):
    if x_init is None:
        x = torch.randn(num_samples, 2, device=device)
    else:
        x = x_init.clone().to(device)
        num_samples = x.shape[0]

    history = []
    if return_history:
        history.append(x.detach().cpu())

    for step_idx in reversed(range(NUM_DDPM_STEPS)):
        t = torch.full((num_samples, 1), step_idx / (NUM_DDPM_STEPS - 1), device=device)
        alpha_bar_t = DDPM_ALPHA_BAR[step_idx]
        eps_pred = model(x, t)
        x0_pred = (x - torch.sqrt(1.0 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)

        if step_idx == 0:
            x = x0_pred
        else:
            alpha_bar_prev = DDPM_ALPHA_BAR[step_idx - 1]
            sigma_t = eta * torch.sqrt((1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * DDPM_BETAS[step_idx])
            direction = torch.sqrt(torch.clamp(1.0 - alpha_bar_prev - sigma_t ** 2, min=1e-8)) * eps_pred
            x = torch.sqrt(alpha_bar_prev) * x0_pred + direction
            if eta > 0:
                x = x + sigma_t * torch.randn_like(x)

        if return_history and (step_idx % record_every == 0 or step_idx == 0):
            history.append(x.detach().cpu())

    if return_history:
        return x, history
    return x


torch.manual_seed(11)
ddpm_model, ddpm_losses = train_ddpm_model()
ddpm_samples = sample_ddpm_or_ddim(ddpm_model, eta=1.0)
fig, axes = plt.subplots(1, 2, figsize=(10.5, 4.2))
axes[0].plot(ddpm_losses)
axes[0].set_title('DDPM training loss')
axes[0].set_xlabel('training step')
axes[0].set_ylabel('noise-prediction MSE')
axes[0].grid(alpha=0.25)
plot_point_cloud(axes[1], ddpm_samples, 'DDPM samples (stochastic reverse process)')
plt.tight_layout()
plt.show()
<Figure size 1050x420 with 2 Axes>

Exercise: what does a DDPM network actually predict?

Task. When people say that DDPM predicts the noise, what quantity are they referring to?

Answer

They mean the Gaussian perturbation \epsilon used in the forward corruption rule x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon.

Part 3C: DDIM

DDIM is easiest to understand as a different sampler built on top of the same DDPM-trained denoiser.

The network is still trained with the DDPM objective. What changes is the reverse update: DDIM removes the extra stochastic term and follows a deterministic path from the same initial noise.

What changes

  • Training: unchanged from DDPM.

  • Network: unchanged from DDPM.

  • Sampler: stochastic when \eta > 0, deterministic when \eta = 0.

That is why DDIM is often described as a way to trade some diversity for faster, cleaner, or more reproducible sampling.
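A single DDIM step can be written standalone to make the determinism explicit. The function below is a minimal sketch mirroring the $\eta = 0$ branch of `sample_ddpm_or_ddim` above; the function name and the two $\bar{\alpha}$ values are illustrative assumptions.

```python
import torch

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0).

    Predict x0 from the noise estimate, then move it back to the
    previous noise level without injecting fresh randomness.
    """
    x0_pred = (x_t - torch.sqrt(1.0 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)
    direction = torch.sqrt(1.0 - alpha_bar_prev) * eps_pred
    return torch.sqrt(alpha_bar_prev) * x0_pred + direction

# Determinism check: the same inputs always give the same output.
torch.manual_seed(0)
x_t = torch.randn(8, 2)
eps_pred = torch.randn_like(x_t)
a_t, a_prev = torch.tensor(0.5), torch.tensor(0.7)
out1 = ddim_step(x_t, eps_pred, a_t, a_prev)
out2 = ddim_step(x_t, eps_pred, a_t, a_prev)
print(torch.equal(out1, out2))  # True: no randomness enters the update
```

Compare this with the DDPM update, which adds `sigma_t * torch.randn_like(x)` at every step; removing that term is the entire difference.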

torch.manual_seed(11)
shared_start_noise = torch.randn(256, 2, device=device)
shared_ddpm_samples, ddpm_history = sample_ddpm_or_ddim(
    ddpm_model,
    eta=1.0,
    x_init=shared_start_noise,
    return_history=True,
    record_every=6,
)
shared_ddim_samples, ddim_history = sample_ddpm_or_ddim(
    ddpm_model,
    eta=0.0,
    x_init=shared_start_noise,
    return_history=True,
    record_every=6,
)
fig, axes = plt.subplots(2, 2, figsize=(10.5, 8))
plot_point_cloud(axes[0, 0], shared_start_noise, 'shared starting noise', centers=None)
plot_point_cloud(axes[0, 1], shared_ddpm_samples, 'DDPM endpoints from that noise')
plot_reverse_paths(axes[1, 0], ddpm_history, 'DDPM reverse paths (noise injected each step)')
plot_reverse_paths(axes[1, 1], ddim_history, 'DDIM reverse paths (deterministic once initialized)')
plt.tight_layout()
plt.show()

Exercise: why can DDPM and DDIM share the same trained model?

Task. In one or two sentences, explain why DDIM does not need a separately trained network.

Answer

Because DDIM changes the reverse-time integration rule, not the training target. The same denoiser or noise predictor learned by DDPM can be plugged into a deterministic reverse update.

Part 3D: What actually differs?

Here is the cleanest way to separate the three families.

| Model family | What the network learns | Forward process | Reverse sampler | Good mental model |
| --- | --- | --- | --- | --- |
| NCSN / score-based | score $s_\theta(x, \sigma)$ | direct Gaussian perturbations at chosen noise levels | annealed Langevin dynamics | learn the vector field directly |
| DDPM | noise predictor $\epsilon_\theta(x_t, t)$ or denoiser | discrete Markov diffusion chain | stochastic reverse diffusion | many tiny denoising steps |
| DDIM | same network as DDPM | same forward training setup as DDPM | deterministic reverse path | DDPM without the random jitter |
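The first two rows of the table are closer than they look: a DDPM noise predictor is a rescaled score model. For the conditional $p(x_t \mid x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t}\,x_0, (1-\bar{\alpha}_t)I)$, the score is $-\epsilon / \sqrt{1-\bar{\alpha}_t}$. A minimal numerical check of that identity, using an arbitrary assumed $\bar{\alpha}$ value:

```python
import torch

torch.manual_seed(0)
alpha_bar = torch.tensor(0.4)  # illustrative noise level, not from the schedule above
x0 = torch.randn(16, 2)
eps = torch.randn_like(x0)
x_t = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1.0 - alpha_bar) * eps

# Score of the Gaussian conditional p(x_t | x0), computed two ways:
score_direct = -(x_t - torch.sqrt(alpha_bar) * x0) / (1.0 - alpha_bar)
score_from_eps = -eps / torch.sqrt(1.0 - alpha_bar)
print(torch.allclose(score_direct, score_from_eps, atol=1e-6))
```

This is why the score-based and noise-prediction framings train what is, up to scaling, the same vector field.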

A practical way to remember it

  • If you want the most direct score-field story, think NCSN.

  • If you want the standard discrete diffusion training recipe, think DDPM.

  • If you want to reuse a DDPM model but sample deterministically, think DDIM.

Final comparison exercise

Task. Which pair differs only in the sampler, and which pair differs in both training language and sampler?

Answer

DDPM and DDIM differ only in the sampler. NCSN differs from DDPM/DDIM in both the training framing and the reverse-time sampler.

Part 4: Bridge to crystals

The crystal notebooks reuse the same core diffusion template:

  • define a forward corruption process on a crystal representation,

  • train a network to predict the reverse denoising direction,

  • and sample by walking backward from an easy noise distribution.

What changes in the crystal setting is not the broad logic of diffusion, but the representation: atomic species, fractional coordinates, and lattice geometry replace the small 2D point clouds in this notebook.

Final exercise

Task. Write one sentence explaining why this 2D comparison is still useful before moving to crystals.

Answer

It isolates the model-family differences between NCSN, DDPM, and DDIM, so the crystal notebook can focus on representation and conditioning instead of basic diffusion mechanics.