{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ddmms/camml-tutorials/blob/main/notebooks/05-generative/simple-diffusion.ipynb)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# A Minimal Diffusion Model from Scratch\n", "\n", "## Outline\n", "\n", "* Setup & Imports\n", "* Generate Simple Data (2D Gaussian blobs)\n", "* Forward Process (Adding Noise)\n", "* Reverse Process (Learning to Denoise)\n", "* Sampling from Noise\n", "* Visualization & Insights" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from tqdm import tqdm\n", "import matplotlib.animation as animation" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Generate Simple Data\n", "\n", "We use two 2D Gaussian clusters so that the data and the diffusion process are easy to visualize:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Generate 2D Gaussian data (2 clusters)\n", "def generate_data(n=1000):\n", "    mean1 = [2, 2]\n", "    mean2 = [-2, -2]\n", "    cov = [[0.1, 0], [0, 0.1]]\n", "    data1 = np.random.multivariate_normal(mean1, cov, n // 2)\n", "    data2 = np.random.multivariate_normal(mean2, cov, n // 2)\n", "    data = np.vstack([data1, data2])\n", "    np.random.shuffle(data)\n", "    return torch.tensor(data, dtype=torch.float32)\n", "\n", "data = generate_data()\n", "plt.scatter(data[:, 0], data[:, 1], alpha=0.5)\n", "plt.title(\"Toy 2D Data\")\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Forward Process (Diffusion)\n", "\n", "We define a linear noise schedule $\\beta_t$ and sample the noisy point $x_t$ directly from $x_0$ in closed form."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "T = 100  # number of diffusion steps\n", "beta = torch.linspace(1e-4, 0.02, T)  # linear noise schedule\n", "alpha = 1. - beta\n", "alpha_hat = torch.cumprod(alpha, dim=0)  # cumulative product of alpha (alpha-bar in the equations)\n", "\n", "def q_sample(x0, t):\n", "    # Sample x_t from x_0 in closed form; t is a 1D tensor of timestep indices\n", "    noise = torch.randn_like(x0)\n", "    sqrt_alpha_hat = alpha_hat[t].sqrt().unsqueeze(1)\n", "    sqrt_one_minus = (1 - alpha_hat[t]).sqrt().unsqueeze(1)\n", "    return sqrt_alpha_hat * x0 + sqrt_one_minus * noise, noise" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "sc = ax.scatter(data[:, 0], data[:, 1], alpha=0.5)\n", "ax.set_xlim(-4, 4)\n", "ax.set_ylim(-4, 4)\n", "\n", "# Noise the data at every 5th timestep and keep each result as one animation frame\n", "frames = []\n", "for t in range(0, T, 5):\n", "    xt, _ = q_sample(data, torch.tensor([t] * data.shape[0]))\n", "    frames.append(xt.numpy())\n", "\n", "def update(frame):\n", "    sc.set_offsets(frame)\n", "    return sc,\n", "\n", "ani = animation.FuncAnimation(fig, update, frames=frames, interval=100)\n", "plt.close()\n", "from IPython.display import HTML\n", "HTML(ani.to_jshtml())" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Reverse Process (Train a Denoiser)\n", "\n", "A small MLP takes the noisy point $x_t$ together with the normalized timestep $t/T$ and predicts the noise that was added:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class DenoiseMLP(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        # Input: 2D point plus scalar timestep; output: predicted 2D noise\n", "        self.net = nn.Sequential(\n", "            nn.Linear(2 + 1, 64),\n", "            nn.ReLU(),\n", "            nn.Linear(64, 64),\n", "            nn.ReLU(),\n", "            nn.Linear(64, 2),\n", "        )\n", "\n", "    def forward(self, x, t):\n", "        # Normalize the timestep to [0, 1) and append it as an extra input feature\n", "        t_embed = t.float().unsqueeze(1) / T\n", "        x_input = torch.cat([x, t_embed], dim=1)\n", "        return self.net(x_input)\n", "\n", "model = DenoiseMLP()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Training Loop" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [],
"source": [ "dataset = TensorDataset(data)\n", "loader = DataLoader(dataset, batch_size=128, shuffle=True)\n", "\n", "for epoch in range(100):\n", "    for batch, in loader:\n", "        # Draw a random timestep per sample, noise the batch, and regress the added noise\n", "        t = torch.randint(0, T, (batch.shape[0],))\n", "        x_t, noise = q_sample(batch, t)\n", "        pred_noise = model(x_t, t)\n", "        loss = F.mse_loss(pred_noise, noise)\n", "\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "    if epoch % 10 == 0:\n", "        # Loss of the last batch in this epoch\n", "        print(f\"Epoch {epoch}: loss = {loss.item():.4f}\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Sample\n", "\n", "The function `p_sample` implements one reverse diffusion step.\n", "\n", "Given a noisy input $x_t$ at timestep $t$, the model predicts the noise $\\epsilon_\\theta(x_t, t)$, which yields an estimate $\\hat{x}_0$ of the original clean sample $x_0$.\n", "\n", "From the forward diffusion process:\n", "\n", "$$\n", "x_t = \\sqrt{\\bar{\\alpha}_t} x_0 + \\sqrt{1 - \\bar{\\alpha}_t} \\epsilon \\tag{0}\n", "$$\n", "\n", "Solving for $x_0$ and replacing $\\epsilon$ with the model prediction:\n", "\n", "$$\n", "\\hat{x}_0 = \\frac{1}{\\sqrt{\\bar{\\alpha}_t}} (x_t - \\sqrt{1 - \\bar{\\alpha}_t} \\cdot \\epsilon_\\theta(x_t, t)) \\tag{1}\n", "$$\n", "\n", "Substituting $\\hat{x}_0$ into the mean of the forward-process posterior $q(x_{t-1} | x_t, x_0)$ gives the mean of the reverse distribution $p(x_{t-1} | x_t)$:\n", "\n", "$$\n", "\\mu_t(x_t, \\epsilon_\\theta) = \\frac{1}{\\sqrt{\\alpha_t}} \\left(x_t - \\frac{\\beta_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\cdot \\epsilon_\\theta(x_t, t)\\right) \\tag{2}\n", "$$\n", "\n", "We then sample:\n", "\n", "$$\n", "x_{t-1} \\sim \\mathcal{N}(\\mu_t, \\sigma_t^2 I), \\quad \\text{where } \\sigma_t^2 = \\beta_t \\tag{3}\n", "$$\n", "\n", "No noise is added at the final step ($t = 0$), so the last update returns the mean directly. Together, equations (2) and (3) make up the denoising step for one timestep." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@torch.no_grad()\n", "def p_sample(model, x, t):\n", "    beta_t = beta[t]\n", "    sqrt_one_minus = (1 - alpha_hat[t]).sqrt()\n", "    sqrt_recip_alpha = (1. / alpha[t]).sqrt()\n", "\n", "    pred_noise = model(x, torch.tensor([t] * x.shape[0]))\n", "    # Eqn 1 above (shown for reference; not used in the update below)\n", "    x0_pred = (x - sqrt_one_minus * pred_noise) / alpha_hat[t].sqrt()\n", "    mean = sqrt_recip_alpha * (x - beta_t / sqrt_one_minus * pred_noise)  # Eqn 2 above\n", "    if t > 0:\n", "        z = torch.randn_like(x)\n", "    else:\n", "        z = 0  # no noise at the final step\n", "    return mean + beta_t.sqrt() * z  # Eqn 3 above\n", "\n", "@torch.no_grad()\n", "def sample(model, n_samples=1000, return_trajectory=False):\n", "    # Start from pure Gaussian noise and denoise from t = T - 1 down to t = 0\n", "    x = torch.randn(n_samples, 2)\n", "    trajectory = [x.clone()]\n", "    for t in reversed(range(T)):\n", "        x = p_sample(model, x, t)\n", "        if return_trajectory:\n", "            trajectory.append(x.clone())\n", "    if return_trajectory:\n", "        return x, trajectory\n", "    return x\n", "\n", "samples, trajectory = sample(model, return_trajectory=True)\n", "\n", "plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)\n", "plt.title(\"Generated Samples\")\n", "plt.axis(\"equal\")\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Visualizing the Reverse Process" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "sc = ax.scatter([], [], alpha=0.5)\n", "ax.set_xlim(-4, 4)\n", "ax.set_ylim(-4, 4)\n", "ax.set_title(\"Reverse Process\")\n", "ax.set_aspect(\"equal\")\n", "\n", "def animate(i):\n", "    # trajectory[0] is pure noise; trajectory[-1] holds the final samples\n", "    x = trajectory[i]\n", "    sc.set_offsets(x.numpy())\n", "    ax.set_title(f\"Step {i}/{T}\")\n", "    return sc,\n", "\n", "# trajectory has T + 1 entries, so use its length to include the final denoised frame\n", "ani = animation.FuncAnimation(fig, animate, frames=len(trajectory), interval=60, blit=True)\n", "plt.close()\n", "HTML(ani.to_jshtml())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(trajectory)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "alignn-2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython",
"version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 2 }