Crystal Diffusion From Scratch: MatterGen-light, taught step by step

Source repo for Colab bootstrap and helper downloads: https://gitlab.com/cam-ml/tutorials/-/tree/main/notebooks/05-generative

This notebook is the hands-on capstone for the generative workshop. It follows the same style as the introduction notebook: inspect the data first, predict what a model should do, run a compact implementation, and then critique the outputs scientifically.

Aims

You will build a compact, MatterGen-like workflow that:

  1. curates a crystal dataset directly from Materials Project,

  2. turns crystals into a ChemGraph-like batch format,

  3. corrupts coordinates, lattice, and atom types with separate forward processes,

  4. trains an unconditional base model,

  5. adds density + band-gap conditioning with an adapter-style fine-tune,

  6. samples crystals and inspects diffusion trajectories.

Learning outcomes

By the end you should be able to:

  1. explain what parts of a crystal need separate corruption processes,

  2. read and manipulate a flattened crystal-graph batch representation,

  3. train a small unconditional crystal denoiser,

  4. add conditional signals without rebuilding the whole model,

  5. interpret unconditional samples, conditioned samples, and diffusion-path plots as complementary diagnostics.

How to use this notebook

  1. Pause before each heavy section and predict what the next diagnostic should reveal.

  2. Use quick mode for the first full pass, then rerun only the sections you want to explore.

  3. Keep one running note with three columns: representation, corruption, diagnostic.

  4. Treat every generated structure as a hypothesis that still needs screening, not as a guaranteed discovery.

Runtime expectations

  • quick mode is the best first pass: it uses a small binary-crystal dataset, a smaller denoiser, and shorter training loops so the whole workflow stays notebook-scale.

  • full mode keeps the same teaching path but broadens the chemistry and training budget, so it is better once the notebook is already working on your machine.

  • On CPU, quick mode usually takes 10 to 20 minutes once dependencies are installed and the dataset is cached; full mode can easily take roughly twice as long.

  • A live Materials Project query adds a short data-fetch step; the bundled fallback stays local.

Key sources and papers

This notebook is a pedagogical MatterGen-light build rather than an official reproduction. The main sources behind the data/API path and the crystal-diffusion design choices are:

The notebook keeps the same broad decomposition as production crystal diffusion systems, with separate treatment of atom identity, periodic coordinates, and lattice geometry, but scales the model and dataset down so the full workflow stays readable and runnable in a teaching setting.

Task for you

  • Write down which crystal quantities are continuous and which are discrete before you reach the corruption section.

0) Install the small dependency set

This cell installs the notebook dependencies. After it runs once, you can usually ignore it.

It also installs py3Dmol, which is only used if you switch the structure viewer from the default static previews to the optional interactive mode later in the notebook.

# @title
# Colab install cell.
# PyTorch usually already exists in Colab, so we only install the extras we use.
!pip install uv
!uv pip -q install mp-api pymatgen ase pandas matplotlib tqdm pillow monty py3Dmol ipywidgets

1) Setup and connect to Materials Project

This section does three things:

  • imports the packages we need,

  • chooses CPU or GPU,

  • asks for your Materials Project API key.

Why start here? Because, whenever an API key is available, this notebook queries Materials Project directly rather than only reading a bundled archive, so the dataset step is part of the workflow.

Design choice: that keeps the notebook closer to a real research workflow, where the data query is part of the experiment.

If you do not provide an API key, the notebook falls back to a bundled small Materials Project mini-corpus stored in the repo. That means the notebook still teaches from real MP structures rather than a synthetic toy set.

If mp_api imports cleanly, the notebook uses it directly. If that import fails in a lightweight environment, the notebook falls back to direct REST requests against the same Materials Project API.
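The import-with-fallback behaviour described above can be sketched in isolation. This is a minimal illustration, not the notebook's helper code: `choose_mp_backend` and `build_summary_request` are hypothetical names, and the endpoint URL, the `X-API-KEY` header, and the `_fields` / `_limit` query parameters are assumptions about the public Materials Project REST API that should be checked against its documentation. Nothing is sent over the network here.

```python
from urllib.parse import urlencode

# Assumed summary endpoint of the Materials Project REST API (verify before use).
MP_SUMMARY_URL = "https://api.materialsproject.org/materials/summary/"


def choose_mp_backend() -> str:
    """Return "mp_api" when the client imports cleanly, otherwise "rest"."""
    try:
        import mp_api.client  # noqa: F401  (optional dependency)
        return "mp_api"
    except Exception:
        return "rest"


def build_summary_request(api_key: str, fields: list[str], limit: int = 10):
    """Build (url, headers) for a direct REST query without sending it."""
    # Parameter names below are assumptions, not taken from the notebook.
    query = urlencode({"_fields": ",".join(fields), "_limit": limit})
    headers = {"X-API-KEY": api_key, "accept": "application/json"}
    return f"{MP_SUMMARY_URL}?{query}", headers


backend = choose_mp_backend()
url, headers = build_summary_request("YOUR_KEY", ["material_id", "band_gap"], limit=5)
print("backend:", backend)
print("url:", url)
```

Either backend would target the same service, so the rest of a workflow only needs to know which one was chosen.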

Try later: after the notebook works once, rerun only the dataset section with a different chemistry or structure family.

Fastest key paths:

  • paste the key into the short form cell just below for this notebook session,

  • or set MP_API_KEY / PMG_MAPI_KEY in your environment,

  • or store a Colab secret named MP_API_KEY.

The code below is intentionally short and practical. Read it as:

  • imports,

  • seed everything for reproducibility,

  • choose device,

  • fetch API key from the quick form, your environment, or a Colab secret.

# @title Optional: add your Materials Project API key for this session
import os

MP_API_KEY_INPUT = ""  # @param {type:"string"}


def _set_session_mp_key(raw_key: str) -> None:
    global _SESSION_MP_API_KEY
    key = raw_key.strip()
    previous_key = globals().get("_SESSION_MP_API_KEY", "")

    if key:
        os.environ["MP_API_KEY"] = key
        os.environ["PMG_MAPI_KEY"] = key
        _SESSION_MP_API_KEY = key
        return

    if previous_key:
        if os.environ.get("MP_API_KEY") == previous_key:
            os.environ.pop("MP_API_KEY", None)
        if os.environ.get("PMG_MAPI_KEY") == previous_key:
            os.environ.pop("PMG_MAPI_KEY", None)
    _SESSION_MP_API_KEY = ""


def _describe_session_mp_key() -> str:
    key = globals().get("_SESSION_MP_API_KEY", "")
    if key:
        return "Stored a Materials Project API key for this notebook session: " + "*" * 8 + key[-4:]
    return (
        "Leave this blank to use MP_API_KEY / PMG_MAPI_KEY from your environment "
        "or a Colab secret named MP_API_KEY."
    )


existing_key = (
    globals().get("_SESSION_MP_API_KEY")
    or os.environ.get("MP_API_KEY")
    or os.environ.get("PMG_MAPI_KEY")
    or MP_API_KEY_INPUT.strip()
)

try:
    import google.colab  # type: ignore
    _MP_KEY_IN_COLAB = True
except Exception:
    _MP_KEY_IN_COLAB = False

if _MP_KEY_IN_COLAB:
    _set_session_mp_key(MP_API_KEY_INPUT)
    print(_describe_session_mp_key())
else:
    try:
        import ipywidgets as widgets
        from IPython.display import display
    except Exception:
        widgets = None

    if widgets is None:
        _set_session_mp_key(existing_key)
        print(_describe_session_mp_key())
        print("Install `ipywidgets` to get an in-notebook password field in Jupyter.")
    else:
        try:
            _format_widget_pre = format_widget_pre
        except NameError:
            import html as _html

            def _format_widget_pre(text: str) -> str:
                return f"<pre style='white-space:pre-wrap; margin:0'>{_html.escape(text)}</pre>"

        try:
            _bind_widget_state = bind_widget_state
        except NameError:
            def _bind_widget_state(controls, apply_fn):
                state_holder = {"has_rendered": False, "last": None}

                def refresh(change=None):
                    state = {name: control.value for name, control in controls.items()}
                    state_key = tuple((name, repr(value)) for name, value in state.items())
                    if state_holder["has_rendered"] and state_holder["last"] == state_key:
                        return
                    state_holder["has_rendered"] = True
                    state_holder["last"] = state_key
                    apply_fn(**state)

                refresh()
                for control in controls.values():
                    control.observe(refresh, names="value")
                return refresh

        mp_key_widget = widgets.Password(
            value=existing_key,
            description="MP key:",
            placeholder="Paste your Materials Project API key",
            layout=widgets.Layout(width="520px"),
            style={"description_width": "70px"},
        )
        mp_key_help = widgets.HTML(
            "<small>Jupyter users can paste a key here for this notebook session only.</small>"
        )
        mp_key_status = widgets.HTML()
        display(widgets.VBox([mp_key_widget, mp_key_help, mp_key_status]))

        def _refresh_mp_key(raw_key):
            _set_session_mp_key(raw_key)
            mp_key_status.value = _format_widget_pre(_describe_session_mp_key())

        _bind_widget_state({"raw_key": mp_key_widget}, _refresh_mp_key)
# @title
import os
import io
import html
import math
import copy
import json
import hashlib
import time
import random
import re
import itertools
import warnings
from getpass import getpass
from pathlib import Path
from dataclasses import dataclass
import sys
import types

# Some lightweight environments ship without the optional bz2 extension.
# A few imports in the crystal stack transitively touch `bz2`, so we provide
# a small compatibility shim when the compiled module is unavailable.
def _make_unavailable(module_name):
    def _unavailable(*args, **kwargs):
        raise RuntimeError(f"{module_name} compression is unavailable in this Python build")
    return _unavailable


try:
    import bz2  # type: ignore
except Exception:
    _bz2_unavailable = _make_unavailable("bz2")
    bz2 = types.ModuleType("bz2")
    bz2.BZ2File = _bz2_unavailable
    bz2.compress = _bz2_unavailable
    bz2.decompress = _bz2_unavailable
    bz2.open = _bz2_unavailable
    sys.modules["bz2"] = bz2

try:
    import lzma  # type: ignore
except Exception:
    # Built here so the lzma shim still works when bz2 imported normally.
    _lzma_unavailable = _make_unavailable("lzma")
    lzma = types.ModuleType("lzma")
    lzma.LZMAFile = _lzma_unavailable
    lzma.compress = _lzma_unavailable
    lzma.decompress = _lzma_unavailable
    sys.modules["lzma"] = lzma


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import requests
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from pymatgen.core import Structure, Lattice, Element, Composition
from pymatgen.io.cif import CifWriter
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
try:
    from mp_api.client import MPRester
except Exception:
    MPRester = None
from monty.serialization import dumpfn, loadfn

from ase.visualize.plot import plot_atoms

try:
    import py3Dmol  # type: ignore
except Exception:
    py3Dmol = None

try:
    import ipywidgets as widgets
except Exception:
    widgets = None

try:
    import google.colab  # type: ignore
    IN_COLAB = True
except Exception:
    IN_COLAB = False

from IPython.display import display, HTML, Markdown, Image as IPythonImage

warnings.filterwarnings("ignore")

SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

def get_mp_api_key() -> str:
    key = os.environ.get("PMG_MAPI_KEY") or os.environ.get("MP_API_KEY")
    if key:
        os.environ["PMG_MAPI_KEY"] = key
        return key

    # Optional Colab secret named MP_API_KEY.
    try:
        from google.colab import userdata  # type: ignore
        key = userdata.get("MP_API_KEY")
    except Exception:
        key = None

    if not key:
        print(
            "No Materials Project API key found; using the bundled small Materials Project fallback corpus. "
            "Fill the form cell above or set MP_API_KEY to switch to live Materials Project data."
        )
        return ""

    os.environ["PMG_MAPI_KEY"] = key
    return key

MP_API_KEY = get_mp_api_key()
if MP_API_KEY:
    print("Using Materials Project key ending with:", "*" * 8 + MP_API_KEY[-4:])
else:
    print("No Materials Project API key detected; the bundled Materials Project fallback corpus will be used.")


def format_widget_pre(text: str) -> str:
    return f"<pre style='white-space:pre-wrap; margin:0'>{html.escape(text)}</pre>"


def bind_widget_state(controls, apply_fn):
    state_holder = {"has_rendered": False, "last": None}

    def refresh(change=None):
        state = {name: control.value for name, control in controls.items()}
        state_key = tuple((name, repr(value)) for name, value in state.items())
        if state_holder["has_rendered"] and state_holder["last"] == state_key:
            return
        state_holder["has_rendered"] = True
        state_holder["last"] = state_key
        apply_fn(**state)

    refresh()
    for control in controls.values():
        control.observe(refresh, names="value")
    return refresh

Choose quick or full notebook mode

Use quick for a first pass through the entire notebook. It focuses on a small binary-crystal corpus so the chemistry, plots, and generated samples stay easy to read. Use full once you want a stronger demo with broader chemistry and a larger denoiser.

The mode controls the default dataset size, chemistry scope, batch size, network width, and training budgets. You can still override any of those values later if you want to explore manually.
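The preset-plus-override idea can be sketched on its own as follows. This is an illustrative pattern rather than the notebook's implementation: `resolve_config` is a hypothetical helper, and the dictionary is a trimmed copy of a few preset values.

```python
PRESETS = {
    "quick": {"chemistry_scope": "2_elements", "max_structures": 1000, "hidden_dim": 64},
    "full": {"chemistry_scope": "2_to_4_elements", "max_structures": 4000, "hidden_dim": 128},
}


def resolve_config(mode: str, **overrides) -> dict:
    """Copy a preset and apply manual overrides, rejecting unknown keys."""
    if mode not in PRESETS:
        raise ValueError(f"Unknown demo mode: {mode}")
    config = dict(PRESETS[mode])  # copy so the preset itself is never mutated
    unknown = set(overrides) - set(config)
    if unknown:
        raise KeyError(f"Unknown override(s): {sorted(unknown)}")
    config.update(overrides)
    return config


config = resolve_config("quick", max_structures=200)
print(config)
```

Copying the preset before updating it means a manual override never leaks into the next run that selects the same mode.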

# @title Choose quick or full notebook mode
DEMO_MODE = "quick"  # @param ["quick", "full"]

DEMO_PRESETS = {
    "quick": {
        "chemistry_scope": "2_elements",
        "max_structures": 1000,
        "max_atoms": 5,
        "batch_size": 16,
        "hidden_dim": 64,
        "num_blocks": 4,
        "base_epochs": 75,
        "base_patience": 25,
        "adapter_epochs": 75,
        "adapter_patience": 25,
    },
    "full": {
        "chemistry_scope": "2_to_4_elements",
        "max_structures": 4000,
        "max_atoms": 20,
        "batch_size": 64,
        "hidden_dim": 128,
        "num_blocks": 6,
        "base_epochs": 75,
        "base_patience": 25,
        "adapter_epochs": 75,
        "adapter_patience": 25,
    },
}


def describe_demo_mode() -> str:
    lines = [f"Notebook mode: {DEMO_MODE}"]
    for key, value in DEMO_CONFIG.items():
        lines.append(f"  {key}: {value}")
    lines.append(f"  force_mp_refresh: {FORCE_MP_REFRESH}")
    return "\n".join(lines)


def apply_demo_mode(mode: str, announce: bool = True) -> str:
    global DEMO_MODE
    global DEMO_CONFIG, CHEMISTRY_SCOPE, MAX_STRUCTURES, MAX_ATOMS, FORCE_MP_REFRESH
    global BATCH_SIZE, HIDDEN_DIM, NUM_BLOCKS, BASE_MAX_EPOCHS, BASE_PATIENCE
    global ADAPTER_MAX_EPOCHS, ADAPTER_PATIENCE

    if mode not in DEMO_PRESETS:
        raise ValueError(f"Unknown demo mode: {mode}")

    DEMO_MODE = mode
    DEMO_CONFIG = DEMO_PRESETS[DEMO_MODE].copy()
    CHEMISTRY_SCOPE = DEMO_CONFIG["chemistry_scope"]
    MAX_STRUCTURES = DEMO_CONFIG["max_structures"]
    MAX_ATOMS = DEMO_CONFIG["max_atoms"]
    FORCE_MP_REFRESH = False
    BATCH_SIZE = DEMO_CONFIG["batch_size"]
    HIDDEN_DIM = DEMO_CONFIG["hidden_dim"]
    NUM_BLOCKS = DEMO_CONFIG["num_blocks"]
    BASE_MAX_EPOCHS = DEMO_CONFIG["base_epochs"]
    BASE_PATIENCE = DEMO_CONFIG["base_patience"]
    ADAPTER_MAX_EPOCHS = DEMO_CONFIG["adapter_epochs"]
    ADAPTER_PATIENCE = DEMO_CONFIG["adapter_patience"]

    summary = describe_demo_mode()
    if announce:
        print(summary)
    return summary


if IN_COLAB or widgets is None:
    print(apply_demo_mode(DEMO_MODE, announce=False))
    if not IN_COLAB and widgets is None:
        print("Install `ipywidgets` to use a live mode selector in Jupyter.")
else:
    demo_mode_widget = widgets.Dropdown(
        options=[("quick", "quick"), ("full", "full")],
        value=DEMO_MODE,
        description="Mode:",
        style={"description_width": "70px"},
        layout=widgets.Layout(width="260px"),
    )
    demo_mode_help = widgets.HTML(
        "<small>Switch modes in Jupyter without editing the code cell manually.</small>"
    )
    demo_mode_status = widgets.HTML()
    display(widgets.VBox([demo_mode_widget, demo_mode_help, demo_mode_status]))

    def _refresh_demo_mode(mode):
        demo_mode_status.value = format_widget_pre(apply_demo_mode(mode, announce=False))

    bind_widget_state({"mode": demo_mode_widget}, _refresh_demo_mode)

2) Build a small crystal corpus from Materials Project

Dataset creation stays deliberately simple.

The idea

We fetch a broad, stable, non-elemental crystal set from Materials Project and keep only a small number of structures so the rest of the notebook stays fast enough for Colab.

If a live API key is available, this section queries Materials Project directly and caches the curated result. It uses mp_api when available and otherwise falls back to direct REST requests to the same endpoint. If no key is available, it loads a bundled small real-MP fallback dataset from the repository and adapts it to the current filters.

You only choose:

  1. how many distinct elements each crystal may contain,

  2. how many structures you want in the teaching corpus,

  3. the maximum number of atoms per crystal,

  4. whether to use cached results if they already exist.

That is enough to get a realistic miniature dataset without turning the notebook into a database-engineering exercise.
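To make those choices concrete, here is a small sketch of how such filters could act on a list of record dictionaries. It assumes records carry `num_elements` and `num_atoms` fields; `curate` is a hypothetical stand-in, not the notebook's `curate_general_summary_docs` helper, and the toy records are illustrative only.

```python
import random


def curate(records, min_elements=2, max_elements=4, max_atoms=16, max_structures=96, seed=7):
    """Keep records inside the element-count and atom-count limits, then subsample."""
    kept = [
        r for r in records
        if min_elements <= r["num_elements"] <= max_elements and r["num_atoms"] <= max_atoms
    ]
    rng = random.Random(seed)  # local RNG so the subsample is reproducible
    rng.shuffle(kept)
    return kept[:max_structures]


toy_records = [
    {"formula": "NaCl", "num_elements": 2, "num_atoms": 2},
    {"formula": "Si", "num_elements": 1, "num_atoms": 2},        # elemental, filtered out
    {"formula": "SrTiO3", "num_elements": 3, "num_atoms": 5},
    {"formula": "Y3Al5O12", "num_elements": 3, "num_atoms": 80},  # too many atoms
]

binaries = curate(toy_records, max_elements=2)
broad = curate(toy_records, max_elements=4)
print([r["formula"] for r in binaries])
print(sorted(r["formula"] for r in broad))
```

Widening `max_elements` admits more chemistry, while `max_atoms` and `max_structures` keep the corpus small enough for a notebook-scale run.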

The hidden implementation cell below locates the tutorial repo (cloning it in Colab if needed) and imports the Materials Project helper functions from gen_helpers.

You do not need to read it on your first pass.

For now, just choose the chemistry scope, choose the dataset size, and fetch the corpus.

#@title Materials Project dataset helpers (implementation details) { display-mode: "form" }
from pathlib import Path
import os
import subprocess
import sys

DAY5_SOURCE_REPO_URL = "https://gitlab.com/cam-ml/tutorials.git"
DAY5_SOURCE_REPO_BRANCH = "main"
DAY5_COLAB_CLONE_CANDIDATES = [
    Path("/content/tutorials"),
    Path("/content/cam_ml_tutorials"),
    Path("/content/camml-tutorials"),
]


def _running_in_colab():
    try:
        import google.colab  # type: ignore
        return True
    except Exception:
        return False


def _unique_paths(paths):
    unique = []
    seen = set()
    for path in paths:
        path = Path(path)
        key = str(path)
        if key not in seen:
            seen.add(key)
            unique.append(path)
    return unique


def _iter_day5_search_roots():
    cwd = Path.cwd().resolve()
    roots = [cwd, *cwd.parents]
    for clone_dir in DAY5_COLAB_CLONE_CANDIDATES:
        roots.extend([clone_dir, clone_dir / "notebooks" / "05-generative"])
    return _unique_paths(roots)


def _register_day5_notebook_root(notebook_root: Path):
    notebook_root = notebook_root.resolve()
    if str(notebook_root) not in sys.path:
        sys.path.insert(0, str(notebook_root))
    try:
        os.chdir(notebook_root)
    except OSError:
        pass
    return notebook_root


def ensure_day5_helpers_on_path():
    for candidate in _iter_day5_search_roots():
        for notebook_root in (candidate, candidate / "notebooks" / "05-generative"):
            helper_dir = notebook_root / "gen_helpers"
            if helper_dir.exists():
                return _register_day5_notebook_root(notebook_root)

    if _running_in_colab():
        for clone_dir in DAY5_COLAB_CLONE_CANDIDATES:
            notebook_root = clone_dir / "notebooks" / "05-generative"
            if notebook_root.exists():
                return _register_day5_notebook_root(notebook_root)

        for clone_dir in DAY5_COLAB_CLONE_CANDIDATES:
            if clone_dir.exists():
                continue
            clone_dir.parent.mkdir(parents=True, exist_ok=True)
            print(
                "Cloning the Day 5 tutorial repo from "
                f"{DAY5_SOURCE_REPO_URL} into {clone_dir} so notebook helper modules are available..."
            )
            subprocess.run(
                [
                    "git",
                    "clone",
                    "--depth",
                    "1",
                    "--branch",
                    DAY5_SOURCE_REPO_BRANCH,
                    DAY5_SOURCE_REPO_URL,
                    str(clone_dir),
                ],
                check=True,
            )
            notebook_root = clone_dir / "notebooks" / "05-generative"
            if notebook_root.exists():
                return _register_day5_notebook_root(notebook_root)

        raise FileNotFoundError(
            "Could not find or clone notebooks/05-generative inside /content for this Colab session."
        )

    raise FileNotFoundError(
        "Could not locate notebooks/05-generative/gen_helpers. If you are in Colab, rerun this cell so the repo can be cloned automatically."
    )

GEN_HELPERS_ROOT = ensure_day5_helpers_on_path()

from gen_helpers.crystal_dataset_helpers import *

Choose a small teaching dataset

Keep this part light. The goal is not to perfectly curate materials data; it is to get a clean, varied mini-corpus that makes the later diffusion model easy to study.

We use a single broad general query, and the main lever is how many distinct elements a crystal is allowed to have.

#@title Choose the teaching dataset
# Most users only need the quick/full mode cell above.
# Edit the four variables below only if you want to override that preset manually.

CHEMISTRY_SCOPE = globals().get("CHEMISTRY_SCOPE", "2_to_4_elements")
MAX_STRUCTURES = int(globals().get("MAX_STRUCTURES", 96))
MAX_ATOMS = int(globals().get("MAX_ATOMS", 16))
FORCE_MP_REFRESH = bool(globals().get("FORCE_MP_REFRESH", False))

DATASET_CFG = make_general_dataset_config(
    chemistry_scope=CHEMISTRY_SCOPE,
    max_structures=MAX_STRUCTURES,
    max_atoms=MAX_ATOMS,
    force_refresh=FORCE_MP_REFRESH,
)

MP_CACHE_PATH = dataset_cache_path_from_config(DATASET_CFG)

print("Preset mode:", globals().get("DEMO_MODE", "manual"))
print("Dataset config:")
for k, v in DATASET_CFG.items():
    if "query_" in k:
        continue
    print(f"  {k}: {v}")
print("cache path:", MP_CACHE_PATH)

A good way to explore is:

  • start with DEMO_MODE = "quick", whose preset uses the binary (2_elements) corpus,

  • switch to DEMO_MODE = "full" once the notebook is working and you want the broader 2-to-4-element corpus and a stronger denoiser,

  • if you specifically want chemically simpler live-MP runs in full mode, try binaries by setting CHEMISTRY_SCOPE = "2_elements",

  • then compare how the dataset histograms and generated structures change as you broaden the chemistry again.

After each change, look at the dataset histograms and ask: did the property range broaden or narrow?

Fetch, cache, and summarize

This cell does four practical things:

  1. builds one broad MP query,

  2. downloads a manageable pool of summary docs,

  3. curates them into one small corpus,

  4. caches the result locally so reruns are fast.
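The cache-then-fetch logic in these steps can be sketched with the standard library alone. This is a simplified illustration under stated assumptions: `load_or_build` and `fake_fetch` are hypothetical names, and the notebook itself uses `loadfn` / `dumpfn` from monty and a richer payload.

```python
import gzip
import json
import tempfile
from pathlib import Path


def load_or_build(cache_path: Path, build_fn, force_refresh: bool = False):
    """Return (records, used_cache); rebuild and rewrite the cache when needed."""
    if cache_path.exists() and not force_refresh:
        with gzip.open(cache_path, "rt", encoding="utf-8") as handle:
            return json.load(handle)["records"], True
    records = build_fn()
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with gzip.open(cache_path, "wt", encoding="utf-8") as handle:
        json.dump({"records": records}, handle)
    return records, False


def fake_fetch():
    # Stand-in for the live query; returns toy records only.
    return [{"formula": "NaCl", "num_atoms": 2}]


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "cache" / "toy.json.gz"
    first, first_cached = load_or_build(path, fake_fetch)
    second, second_cached = load_or_build(path, fake_fetch)
    third, third_cached = load_or_build(path, fake_fetch, force_refresh=True)

print(first_cached, second_cached, third_cached)
```

The second call is served from disk, and the `force_refresh` flag plays the same role as FORCE_MP_REFRESH: it bypasses an existing cache and rewrites it.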

# @title
if "Path" not in globals():
    from pathlib import Path

if "DEFAULT_MP_CACHE_DIR" not in globals():
    DEFAULT_MP_CACHE_DIR = Path("mp_curated_cache")
    DEFAULT_MP_CACHE_DIR.mkdir(parents=True, exist_ok=True)

if "DATASET_CFG" not in globals():
    if "make_general_dataset_config" not in globals():
        raise NameError(
            "DATASET_CFG is not defined because the Materials Project helper cell has not been run yet. "
            "Please run the hidden 'Materials Project dataset helpers' cell first."
        )
    DATASET_CFG = make_general_dataset_config()

if "MP_CACHE_PATH" not in globals() or MP_CACHE_PATH is None:
    if "dataset_cache_path_from_config" in globals():
        MP_CACHE_PATH = dataset_cache_path_from_config(DATASET_CFG)
    else:
        MP_CACHE_PATH = DEFAULT_MP_CACHE_DIR / "mp_general_dataset_fallback.json.gz"

FORCE_MP_REFRESH = bool(globals().get("FORCE_MP_REFRESH", DATASET_CFG.get("force_refresh", False)))
MAX_ATOMS = int(DATASET_CFG.get("max_atoms", 16))
MAX_STRUCTURES = int(DATASET_CFG.get("max_structures", 96))

print("Using dataset cache:", MP_CACHE_PATH)
print("Bundled fallback dataset:", BUNDLED_MP_FALLBACK_PATH)

records = None
skipped = {}
mp_query_used = {}
cache_payload = None
cache_source = None
used_cache = False
mp_probe_result = None
mp_probe_error = None

if MP_API_KEY:
    try:
        mp_probe_result = probe_materials_project_connection(MP_API_KEY)
        print(f"Live Materials Project probe succeeded ({mp_probe_result} document(s) returned).")
    except Exception as exc:
        mp_probe_error = exc
        print("Live Materials Project probe failed; the notebook will use cached or bundled data if needed.")
        print(f"Reason: {exc.__class__.__name__}: {exc}")
else:
    print("No API key provided; using cached or bundled Materials Project data only.")

if Path(MP_CACHE_PATH).exists() and not FORCE_MP_REFRESH:
    cache_payload = loadfn(MP_CACHE_PATH)
    if isinstance(cache_payload, dict) and "records" in cache_payload:
        candidate_records = cache_payload["records"]
        candidate_skipped = cache_payload.get("skipped", {})
        candidate_query = cache_payload.get("mp_query_used", {})
    else:
        candidate_records = cache_payload
        candidate_skipped = {}
        candidate_query = {}

    candidate_records = ensure_record_feature_fields(list(candidate_records))
    cache_source = cache_payload_source_label(cache_payload, candidate_records)

    if cache_payload_is_real_mp(cache_payload, candidate_records):
        print(f"Loading cached curated records from {MP_CACHE_PATH} [{cache_source}]")
        records = candidate_records
        skipped = candidate_skipped
        mp_query_used = candidate_query
        used_cache = True
    elif MP_API_KEY:
        print(
            f"Existing cache at {MP_CACHE_PATH} came from {cache_source}; regenerating from live Materials Project because an API key is available."
        )
    else:
        print(
            f"Existing cache at {MP_CACHE_PATH} came from {cache_source}; replacing it with the bundled real Materials Project fallback dataset."
        )

if not used_cache:
    live_fetch_error = None
    if MP_API_KEY:
        try:
            docs, mp_query_used = fetch_general_summary_docs(api_key=MP_API_KEY, cfg=DATASET_CFG)
            print("Fetched raw summary docs:", len(docs))
            records, skipped = curate_general_summary_docs(docs, cfg=DATASET_CFG, seed=SEED)
            mp_query_used = dict(mp_query_used)
            mp_query_used["source"] = "materials_project_live_query"
        except Exception as exc:
            live_fetch_error = exc
            print("Materials Project fetch failed; using the bundled real MP fallback dataset instead.")
            print(f"Reason: {exc.__class__.__name__}: {exc}")

    if records is None:
        try:
            records, skipped, mp_query_used = load_bundled_fallback_records(DATASET_CFG)
            if not MP_API_KEY:
                print("Loaded the bundled real Materials Project fallback dataset.")
            if live_fetch_error is not None:
                mp_query_used = dict(mp_query_used)
                mp_query_used["fallback_reason"] = f"{live_fetch_error.__class__.__name__}: {live_fetch_error}"
            if mp_probe_error is not None:
                mp_query_used = dict(mp_query_used)
                mp_query_used["probe_error"] = f"{mp_probe_error.__class__.__name__}: {mp_probe_error}"
        except Exception as fallback_exc:
            print("Bundled fallback unavailable; building a synthetic last-resort corpus instead.")
            print(f"Reason: {fallback_exc.__class__.__name__}: {fallback_exc}")
            records, skipped, mp_query_used = build_synthetic_records(DATASET_CFG, seed=SEED)
            mp_query_used = dict(mp_query_used)
            mp_query_used["fallback_reason"] = f"{fallback_exc.__class__.__name__}: {fallback_exc}"

    dumpfn(
        {
            "records": records,
            "skipped": skipped,
            "mp_query_used": mp_query_used,
            "dataset_cfg": canonicalize_config_for_cache(DATASET_CFG) if "canonicalize_config_for_cache" in globals() else DATASET_CFG,
        },
        MP_CACHE_PATH,
    )
    print(f"Cached {len(records)} curated records to {MP_CACHE_PATH}")

records = ensure_record_feature_fields(list(records))
print(f"loaded {len(records)} structures")
print("dataset source:", cache_payload_source_label({"mp_query_used": mp_query_used}, records))
if skipped:
    print("skipped counts:", skipped)

if len(records) < 32:
    print(
        "Warning: only",
        len(records),
        "structures were curated. The notebook can still run, but training and sampling may be weaker. "
        "Widen CHEMISTRY_SCOPE or raise MAX_STRUCTURES if you want a larger corpus.",
    )

query_rows = [{"query_key": key, "value": value} for key, value in (mp_query_used or {}).items() if key not in {"fields"}]
if mp_probe_result is not None:
    query_rows.append({"query_key": "mp_probe_result", "value": mp_probe_result})
if query_rows:
    display(pd.DataFrame(query_rows))

summary = pd.DataFrame(
    {
        "name": [r["name"] for r in records],
        "formula": [r["formula"] for r in records],
        "anonymous_formula": [r.get("anonymous_formula", "") for r in records],
        "chemsys": [r["chemsys"] for r in records],
        "num_elements": [r.get("num_elements", np.nan) for r in records],
        "num_atoms": [r["num_atoms"] for r in records],
        "density": [r["density"] for r in records],
        "band_gap": [r.get("band_gap", np.nan) for r in records],
        "energy_above_hull": [r.get("energy_above_hull", np.nan) for r in records],
    }
)
summary.head()

The next cell is a quick sanity check.

What to look for:

  • do you have enough structures to train a tiny demo model?

  • do the formulas and atom counts look reasonable?

  • do the density and band-gap histograms have a decent spread?

A broad histogram is useful here because later we will pick quantile-based low/high targets from these natural training-set distributions.
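As a rough sketch of what quantile-based targets might look like, the snippet below picks a low and a high value from a property column. The 10% and 90% levels are assumptions chosen for illustration (the notebook's actual quantiles are not shown in this section), and `low_high_targets` is a hypothetical helper.

```python
import statistics


def low_high_targets(values, low_q=0.1, high_q=0.9):
    """Pick low/high conditioning targets as empirical quantiles, skipping missing values."""
    clean = sorted(v for v in values if v is not None and v == v)  # v == v drops NaN
    if len(clean) < 2:
        raise ValueError("Need at least two finite values to estimate quantiles")
    # n=100 yields 99 cut points; index p - 1 holds the p-th percentile.
    cuts = statistics.quantiles(clean, n=100, method="inclusive")
    return cuts[round(low_q * 100) - 1], cuts[round(high_q * 100) - 1]


densities = [2.1, 3.4, 5.0, 7.8, float("nan"), 4.2]  # toy values in g/cm³
low, high = low_high_targets(densities)
print(f"low target: {low:.2f}, high target: {high:.2f}")
```

If the histogram is narrow, the two targets land close together and the conditioned samples have little room to differ, which is why a decent spread matters here.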

# @title
display(summary.describe(include="all").transpose())

fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))

axes[0].hist(summary["density"].dropna(), bins=24, alpha=0.7)
axes[0].set_title("Curated density distribution")
axes[0].set_xlabel("density (g/cm³)")

axes[1].hist(summary["band_gap"].dropna(), bins=24, alpha=0.7)
axes[1].set_title("Curated band-gap distribution")
axes[1].set_xlabel("band gap (eV)")

plt.tight_layout()
plt.show()

Property-space map

The one-dimensional histograms are useful, but the joint view matters too.

The scatter plot below shows how density and band gap vary together across the curated corpus, with the atom count used as a visual cue.

num_atoms_arr = np.array([int(r["num_atoms"]) for r in records], dtype=np.int32)
fig, axes = plt.subplots(1, 3, figsize=(13.5, 3.8))

sc = axes[0].scatter(
    summary["density"],
    summary["band_gap"],
    c=num_atoms_arr,
    cmap="viridis",
    s=40,
    alpha=0.85,
    edgecolor="white",
    linewidth=0.3,
)
axes[0].set_xlabel("density (g/cm³)")
axes[0].set_ylabel("band gap (eV)")
axes[0].set_title("Density vs band gap")
plt.colorbar(sc, ax=axes[0], label="number of atoms")

axes[1].hist(num_atoms_arr, bins=range(int(num_atoms_arr.min()), int(num_atoms_arr.max()) + 2), alpha=0.8, color="tab:orange")
axes[1].set_xlabel("number of atoms")
axes[1].set_ylabel("count")
axes[1].set_title("Crystal size distribution")

spacegroup_numbers = np.array([int(r.get("spacegroup_number", 0)) for r in records], dtype=np.int32)
spacegroup_counts = pd.Series(spacegroup_numbers).value_counts().sort_index()
axes[2].bar(spacegroup_counts.index.astype(int), spacegroup_counts.values, width=1.0)
axes[2].set_xlabel("space-group number")
axes[2].set_ylabel("count")
axes[2].set_title("Space-group spread")

plt.tight_layout()
plt.show()

Try this: compare the binary scope with a broader one and see whether the density-band-gap cloud becomes tighter or looser.

Question: if the cloud collapses too much, what tradeoff does the diffusion model lose later on?

Answer

A narrower cloud makes the task easier to fit but gives the generator less variety to learn from, so conditional generation can become less interesting and less robust.

Task for you

  • Pick one chemistry filter that you think will make the problem easier and one that will make it more interesting.

  • Before moving on, decide whether your current corpus is broad enough to support both unconditional generation and a meaningful conditioned demo.

Pause and inspect the dataset

Before training anything, ask:

  • Does the chemistry scope look like what you intended?

  • Are the atom counts small enough for Colab?

  • Do the density and band-gap histograms cover enough range to make conditioning interesting?

Try this: switch from 2_to_4_elements to 2_elements and compare the histograms.
Question: does narrowing the chemistry make generation easier, or does it also make the learning signal less diverse?

Answer

Narrowing the chemistry often makes the regression targets and composition statistics easier to fit, but it also weakens the diversity of the learning signal. The model then sees a smaller slice of crystal space, so generation can become cleaner yet less interesting and less transferable.

3) Build a MatterGen-like batch representation

MatterGen does not treat each crystal as a padded dense tensor. Instead, it flattens all atoms across the batch and keeps a vector that says which crystal each atom belongs to.

We copy that idea.

Why this matters

A batch of crystals becomes:

  • a single list of atoms,

  • atom-wise tensors like frac0 and atom_tokens0,

  • crystal-wise tensors like lattice0, scalar conditions, composition conditions, and space-group labels,

  • a batch_idx vector telling us which atoms belong to which crystal.

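This is the same batching trick that many graph neural network libraries use: concatenate the nodes of every graph and carry an index vector that maps each node back to its graph.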

Mathematical picture

If crystal $b$ has $n_b$ atoms, the full batch has

$$N_{\text{atoms}} = \sum_{b=1}^{B} n_b$$

atom nodes.
The model works on those $N_{\text{atoms}}$ nodes, while still knowing which crystal each node came from.
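As a minimal sketch of this bookkeeping, the block below builds the index vector for three crystals and uses it to pool atom values back to crystal level. It is written in plain Python rather than torch so each step is visible; the atom counts and the per-atom feature are made up for illustration.

```python
# Hypothetical atom counts for a batch of B = 3 crystals.
num_atoms = [2, 3, 1]

# One entry per atom: the index of the crystal that atom belongs to.
batch_idx = [b for b, n in enumerate(num_atoms) for _ in range(n)]

# A made-up scalar feature for each of the N_atoms nodes.
atom_feature = [1.0, 3.0, 2.0, 2.0, 2.0, 5.0]

# Pool atom-wise values to crystal-wise means using only batch_idx.
sums = [0.0] * len(num_atoms)
for b, value in zip(batch_idx, atom_feature):
    sums[b] += value
crystal_mean = [s / n for s, n in zip(sums, num_atoms)]

print(batch_idx)     # [0, 0, 1, 1, 1, 2]
print(crystal_mean)  # [2.0, 2.0, 5.0]
```

The notebook builds the same vector with `torch.repeat_interleave` inside `collate_crystals`.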

The conditioning split we use

We treat the conditioning information in three pieces:

$$\mathbf{c}_{\text{scalar}} = [\rho, E_g], \qquad \mathbf{c}_{\text{comp}} \in \mathbb{R}^{K}, \qquad c_{\text{sg}} \in \{1,\dots,230\},$$

where

  • $\rho$ is density,

  • $E_g$ is band gap,

  • $\mathbf{c}_{\text{comp}}$ is a composition-fraction vector over the element vocabulary in our tiny dataset,

  • $c_{\text{sg}}$ is the space-group number.

This is a nice teaching compromise: scalar properties stay simple, while composition and symmetry get their own conditioning channels.
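To make the three pieces concrete, here is a small sketch for a single crystal. The vocabulary, the corpus statistics, and the property values are all invented for illustration, and the `toy_` names are hypothetical so they do not collide with the notebook's own variables.

```python
from collections import Counter

# Hypothetical 5-atom cell of SrTiO3, given as atomic numbers (O=8, Ti=22, Sr=38).
atomic_numbers = [38, 22, 8, 8, 8]

# Assumed element vocabulary (sorted atomic numbers seen in the corpus), so K = 4.
toy_vocabulary = [8, 22, 38, 56]
toy_z_to_token = {z: i for i, z in enumerate(toy_vocabulary)}

# Composition-fraction vector: entry k is the fraction of atoms carrying token k.
counts = Counter(atomic_numbers)
c_comp = [0.0] * len(toy_vocabulary)
for z, count in counts.items():
    c_comp[toy_z_to_token[z]] = count / len(atomic_numbers)

# Scalar conditions are z-scored with corpus statistics (all numbers here are made up).
density, band_gap = 5.1, 1.8
toy_mean, toy_std = (4.0, 1.2), (1.5, 0.9)
c_scalar = [(density - toy_mean[0]) / toy_std[0], (band_gap - toy_mean[1]) / toy_std[1]]

# The space group stays an integer label that an embedding table can look up.
c_sg = 221

print(c_comp)  # [0.6, 0.2, 0.2, 0.0]
```

The composition vector always sums to one, while the scalar channel is unitless after normalization, which is one reason the two are encoded separately.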

The first cell below builds the token vocabulary for atom types, the normalization statistics, and the train/validation split. The second cell keeps the batch-building logic inside the notebook itself, so you can see exactly how a list of crystals becomes the flattened batch dictionary used by the graph denoiser.

# @title
all_atomic_numbers = sorted({int(z) for r in records for z in r["atomic_numbers"]})
z_to_token = {z: i for i, z in enumerate(all_atomic_numbers)}
token_to_z = {i: z for z, i in z_to_token.items()}
NUM_ATOM_CLASSES = len(all_atomic_numbers)
MASK_TOKEN = NUM_ATOM_CLASSES  # absorbing-mask diffusion state

print("number of atom classes:", NUM_ATOM_CLASSES)
print("MASK token id:", MASK_TOKEN)
print("first 20 atomic numbers in the tiny corpus:", all_atomic_numbers[:20])

lat_all = np.stack([r["lattice0"] for r in records], axis=0)
lat_mean = lat_all.mean(axis=0).astype(np.float32)
lat_std = (lat_all.std(axis=0) + 1e-6).astype(np.float32)
set_lattice_feature_stats(lat_mean, lat_std)

CONDITION_NAMES = ["density", "band_gap"]
CONDITION_INDEX = {name: i for i, name in enumerate(CONDITION_NAMES)}
MAX_SPACEGROUP_NUMBER = 230

def infer_spacegroup_info(structure: Structure):
    try:
        sga = SpacegroupAnalyzer(structure, symprec=0.1)
        return int(sga.get_space_group_number()), str(sga.get_space_group_symbol())
    except Exception:
        return 0, "unknown"

def composition_fraction_vector_from_atomic_numbers(atomic_numbers):
    vec = np.zeros(NUM_ATOM_CLASSES, dtype=np.float32)
    atomic_numbers = np.asarray(atomic_numbers, dtype=np.int64)
    unique_z, counts = np.unique(atomic_numbers, return_counts=True)
    total = max(int(counts.sum()), 1)
    for z, count in zip(unique_z.tolist(), counts.tolist()):
        if int(z) in z_to_token:
            vec[z_to_token[int(z)]] = float(count) / float(total)
    return vec

cond_all = np.stack(
    [np.array([r["density"], r["band_gap"]], dtype=np.float32) for r in records],
    axis=0,
)
cond_mean = cond_all.mean(axis=0).astype(np.float32)
cond_std = (cond_all.std(axis=0) + 1e-6).astype(np.float32)

composition_all = []
spacegroup_numbers = []

for r in records:
    r["atom_tokens0"] = np.array([z_to_token[int(z)] for z in r["atomic_numbers"]], dtype=np.int64)
    r["lattice0_norm"] = ((r["lattice0"] - lat_mean) / lat_std).astype(np.float32)

    r["conditions"] = np.array([r["density"], r["band_gap"]], dtype=np.float32)
    r["conditions_norm"] = ((r["conditions"] - cond_mean) / cond_std).astype(np.float32)

    comp_vec = composition_fraction_vector_from_atomic_numbers(r["atomic_numbers"])
    r["composition_cond"] = comp_vec.astype(np.float32)
    composition_all.append(comp_vec)

    sg_num, sg_symbol = infer_spacegroup_info(r["structure"])
    r["spacegroup_number"] = int(np.clip(sg_num, 0, MAX_SPACEGROUP_NUMBER))
    r["spacegroup_symbol"] = sg_symbol
    spacegroup_numbers.append(r["spacegroup_number"])

composition_mean = np.stack(composition_all, axis=0).mean(axis=0).astype(np.float32)
composition_mean = composition_mean / np.clip(composition_mean.sum(), 1e-8, None)

spacegroup_mode = int(pd.Series(spacegroup_numbers).mode().iloc[0]) if len(spacegroup_numbers) else 1
SPACEGROUP_NUM_TO_SYMBOL = {}
for r in records:
    SPACEGROUP_NUM_TO_SYMBOL[int(r["spacegroup_number"])] = str(r["spacegroup_symbol"])
SPACEGROUP_SYMBOL_TO_NUM = {v: k for k, v in SPACEGROUP_NUM_TO_SYMBOL.items()}

rng = np.random.default_rng(SEED)
perm = rng.permutation(len(records))
val_size = min(len(records) - 1, max(8, int(len(records) * VAL_FRACTION))) if len(records) > 1 else 0
val_idx = set(perm[:val_size].tolist())

train_data = [records[i] for i in range(len(records)) if i not in val_idx]
val_data = [records[i] for i in range(len(records)) if i in val_idx]

print("train size:", len(train_data), "| val size:", len(val_data))
print("condition means:", {k: float(v) for k, v in zip(CONDITION_NAMES, cond_mean)})
print("condition stds:", {k: float(v) for k, v in zip(CONDITION_NAMES, cond_std)})
print("default space-group condition:", spacegroup_mode, SPACEGROUP_NUM_TO_SYMBOL.get(spacegroup_mode, "unknown"))

top_comp = np.argsort(-composition_mean)[: min(8, len(composition_mean))]
print(
    "mean composition prior:",
    {Element.from_Z(token_to_z[int(i)]).symbol: float(composition_mean[int(i)]) for i in top_comp if composition_mean[int(i)] > 0}
)

Task for you

  • Relate batch_idx to what you saw in graph-networks.ipynb: which tensors are atom-wise and which are crystal-wise?

  • Print one mini-batch and annotate the shapes by hand before you trust the model code.

  • If you replaced this flattened graph-style representation with naive padding, what would become slower or harder to interpret?

# @title
CRYSTAL_BATCH_TENSOR_KEYS = (
    "frac0",
    "atom_tokens0",
    "lattice0",
    "continuous_conditions",
    "composition_conditions",
    "spacegroup_conditions",
    "num_atoms",
    "batch_idx",
)


def prepare_crystal_item(record):
    return {
        "name": record["name"],
        "formula": record["formula"],
        "num_atoms": record["num_atoms"],
        "frac0": record["frac0"].copy(),
        "atom_tokens0": record["atom_tokens0"].copy(),
        "lattice0": record["lattice0_norm"].copy(),
        "continuous_conditions": record["conditions_norm"].copy(),
        "composition_conditions": record["composition_cond"].copy(),
        "spacegroup_conditions": int(record["spacegroup_number"]),
    }


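def collate_crystals(records):
    # Atom-wise arrays are concatenated along axis 0 and crystal-wise arrays are stacked,
    # so row b of every crystal-wise tensor describes the atoms with batch_idx == b.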
    items = [prepare_crystal_item(record) for record in records]
    num_atoms = torch.tensor([x["num_atoms"] for x in items], dtype=torch.long)
    frac0 = torch.tensor(np.concatenate([x["frac0"] for x in items], axis=0), dtype=torch.float32)
    atom_tokens0 = torch.tensor(np.concatenate([x["atom_tokens0"] for x in items], axis=0), dtype=torch.long)
    lattice0 = torch.tensor(np.stack([x["lattice0"] for x in items], axis=0), dtype=torch.float32)
    continuous_conditions = torch.tensor(
        np.stack([x["continuous_conditions"] for x in items], axis=0),
        dtype=torch.float32,
    )
    composition_conditions = torch.tensor(
        np.stack([x["composition_conditions"] for x in items], axis=0),
        dtype=torch.float32,
    )
    spacegroup_conditions = torch.tensor(
        [x["spacegroup_conditions"] for x in items],
        dtype=torch.long,
    )
    batch_idx = torch.repeat_interleave(torch.arange(len(items), dtype=torch.long), num_atoms)
    return {
        "frac0": frac0,
        "atom_tokens0": atom_tokens0,
        "lattice0": lattice0,
        "continuous_conditions": continuous_conditions,
        "composition_conditions": composition_conditions,
        "spacegroup_conditions": spacegroup_conditions,
        "num_atoms": num_atoms,
        "batch_idx": batch_idx,
        "names": [x["name"] for x in items],
        "formulas": [x["formula"] for x in items],
    }


def move_crystal_batch_to_device(batch, device):
    moved = {key: batch[key].to(device) for key in CRYSTAL_BATCH_TENSOR_KEYS}
    moved["names"] = batch["names"]
    moved["formulas"] = batch["formulas"]
    return moved


BATCH_SIZE = int(globals().get("BATCH_SIZE", 12))
train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_crystals,
)
val_loader = DataLoader(
    val_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_crystals,
)

batch_example = next(iter(train_loader))
print(
    batch_example["frac0"].shape,
    batch_example["atom_tokens0"].shape,
    batch_example["lattice0"].shape,
    batch_example["continuous_conditions"].shape,
    batch_example["composition_conditions"].shape,
    batch_example["spacegroup_conditions"].shape,
)

Try this: print a small batch and verify that atoms from the same crystal share the same batch_idx.
Question: why is this representation more natural than padding every crystal up to the maximum number of atoms?

Answer

The flattened representation keeps every atom as a real node rather than a padded placeholder, so computation scales with the true number of atoms instead of the maximum size in the batch. It also matches graph-message-passing libraries naturally, because batch_idx tells the model which crystal each atom belongs to without wasting memory on padding.
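A quick back-of-the-envelope count shows how much padding can cost. The atom counts below are hypothetical; the pair count assumes all-pairs edges inside each crystal, as `build_periodic_edges` does later in this notebook.

```python
# Hypothetical atom counts for one mini-batch of four crystals.
atoms_per_crystal = [2, 4, 5, 20]
largest = max(atoms_per_crystal)

# Node storage: one row per real atom versus every crystal padded to the largest one.
flattened_slots = sum(atoms_per_crystal)
padded_slots = len(atoms_per_crystal) * largest
wasted_fraction = 1.0 - flattened_slots / padded_slots

# Directed all-pairs edges grow quadratically, so padding hurts even more there.
flattened_pairs = sum(n * (n - 1) for n in atoms_per_crystal)
padded_pairs = len(atoms_per_crystal) * largest * (largest - 1)

print(flattened_slots, padded_slots)  # 31 80
print(flattened_pairs, padded_pairs)  # 414 1520
```

A padded layout would also need an explicit mask in every loss and every aggregation, whereas here `batch_idx` carries that information implicitly.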

4) Visualize a few real crystals with structure viewers

Before building the model, it is useful to reconnect the tensors to actual structures.

The default path uses static ASE previews because they render reliably in Colab, JupyterLab, and exported notebooks. If you want to rotate the crystals interactively, you can switch the viewer mode to py3Dmol in the small form cell below.

Try this: compare one dense structure and one more open structure. Can you already guess which one might have the higher density just from the packing?

Use the small form cell below to choose between:

  • static: the safest option for Colab, JupyterLab, and exported notebooks,

  • py3Dmol: an interactive 3D viewer with rotation and zoom.

If py3Dmol is unavailable for any reason, the notebook falls back to the static previews automatically.

# @title Viewer settings
STRUCTURE_VIEWER_MODE = "static"  # @param ["static", "py3Dmol"]
PY3DMOL_WIDTH = 360  # @param {type:"integer"}
PY3DMOL_HEIGHT = 300  # @param {type:"integer"}


def describe_viewer_settings() -> str:
    lines = []
    if STRUCTURE_VIEWER_MODE == "py3Dmol":
        if py3Dmol is None:
            lines.append("py3Dmol is not available in this session, so the notebook will fall back to static ASE previews.")
        else:
            lines.append("Interactive py3Dmol previews are enabled.")
    else:
        lines.append("Static ASE previews are enabled.")
    lines.append(f"Preview size: {PY3DMOL_WIDTH} x {PY3DMOL_HEIGHT}")
    return "\n".join(lines)


def apply_viewer_settings(mode: str, width: int, height: int, announce: bool = True) -> str:
    global STRUCTURE_VIEWER_MODE, PY3DMOL_WIDTH, PY3DMOL_HEIGHT

    STRUCTURE_VIEWER_MODE = mode
    PY3DMOL_WIDTH = int(width)
    PY3DMOL_HEIGHT = int(height)

    summary = describe_viewer_settings()
    if announce:
        print(summary)
    return summary


if IN_COLAB or widgets is None:
    print(apply_viewer_settings(STRUCTURE_VIEWER_MODE, PY3DMOL_WIDTH, PY3DMOL_HEIGHT, announce=False))
    if not IN_COLAB and widgets is None:
        print("Install `ipywidgets` to adjust the viewer controls from Jupyter.")
else:
    viewer_mode_widget = widgets.Dropdown(
        options=[("static", "static"), ("py3Dmol", "py3Dmol")],
        value=STRUCTURE_VIEWER_MODE,
        description="Viewer:",
        style={"description_width": "70px"},
        layout=widgets.Layout(width="260px"),
    )
    viewer_width_widget = widgets.BoundedIntText(
        value=PY3DMOL_WIDTH,
        min=180,
        max=1200,
        step=20,
        description="Width:",
        style={"description_width": "70px"},
        layout=widgets.Layout(width="200px"),
    )
    viewer_height_widget = widgets.BoundedIntText(
        value=PY3DMOL_HEIGHT,
        min=180,
        max=900,
        step=20,
        description="Height:",
        style={"description_width": "70px"},
        layout=widgets.Layout(width="200px"),
    )
    viewer_help = widgets.HTML(
        "<small>These controls work in Jupyter; Colab still uses the form values above.</small>"
    )
    viewer_status = widgets.HTML()
    display(
        widgets.VBox(
            [
                widgets.HBox([viewer_mode_widget, viewer_width_widget, viewer_height_widget]),
                viewer_help,
                viewer_status,
            ]
        )
    )

    def _refresh_viewer_settings(mode, width, height):
        viewer_status.value = format_widget_pre(apply_viewer_settings(mode, width, height, announce=False))

    bind_widget_state(
        {
            "mode": viewer_mode_widget,
            "width": viewer_width_widget,
            "height": viewer_height_widget,
        },
        _refresh_viewer_settings,
    )

# @title
from gen_helpers.crystal_display import (
    save_structures_as_cifs,
    show_structures as _show_structures,
    structure_to_ase_atoms,
)


def show_structures(
    structures,
    labels=None,
    columns=2,
    rotation="20x,30y,0z",
    dpi=180,
    viewer_mode=None,
    py3dmol_width=None,
    py3dmol_height=None,
):
    return _show_structures(
        structures,
        labels=labels,
        columns=columns,
        rotation=rotation,
        dpi=dpi,
        viewer_mode=STRUCTURE_VIEWER_MODE if viewer_mode is None else viewer_mode,
        py3dmol_width=PY3DMOL_WIDTH if py3dmol_width is None else py3dmol_width,
        py3dmol_height=PY3DMOL_HEIGHT if py3dmol_height is None else py3dmol_height,
        py3dmol_module=py3Dmol,
    )


example_ids = list(range(min(4, len(records))))
real_structures = [records[i]["structure"] for i in example_ids]
real_labels = [
    f'{records[i]["formula"]} | density={records[i]["density"]:.2f} | gap={records[i]["band_gap"]:.2f} eV'
    for i in example_ids
]
display(show_structures(real_structures, real_labels))

5) Define the three forward corruption processes

This is the core conceptual step.

We do not use one generic noise process for everything.
Instead, we use a separate corruption process for each part of the crystal:

  1. fractional coordinates,

  2. lattice parameters,

  3. atom identities.

5.1 Fractional coordinates: wrapped continuous diffusion

Fractional coordinates live on a torus, not on ordinary Euclidean space.
So we corrupt them continuously and then wrap them back into $[-0.5, 0.5)$:

$$\mathbf{x}_t = \operatorname{wrap}\!\left( \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol\epsilon \right)$$
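The wrap operation is short enough to write out. The sketch below is a scalar stand-in for the notebook's `wrap_centered_torch` helper, whose definition is not shown in this section, so the exact formula here is an assumption chosen to match the $[-0.5, 0.5)$ convention.

```python
def wrap_centered(x: float) -> float:
    """Map any real number to its periodic image in [-0.5, 0.5)."""
    return (x + 0.5) % 1.0 - 0.5

# Two atoms close to opposite faces of the cell are neighbours through the boundary.
x_a, x_b = 0.45, -0.45
naive_difference = x_b - x_a                    # about -0.9
periodic_difference = wrap_centered(x_b - x_a)  # about +0.1

# A noisy coordinate that left the cell is mapped back inside it.
wrapped = wrap_centered(0.7)                    # about -0.3
```

The same function serves two purposes: it keeps noised coordinates inside the cell, and it turns a raw coordinate difference into the shortest periodic displacement.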

5.2 Lattice: continuous diffusion in a 6D descriptor

We represent the lattice with:

$$(\log a, \log b, \log c, \cos \alpha, \cos \beta, \cos \gamma)$$

and diffuse that continuously.

Question: Why this parameterization?

Answer
Because each of the six features is easy to normalize, and decoding the log-lengths with an exponential always gives positive cell lengths.
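A small round trip illustrates the point about positivity. The function names below are hypothetical (the notebook's own decoder is not shown in this section), and clamping the cosines before `acos` is an assumption added so that noisy values stay decodable.

```python
import math

def lattice_to_features(a, b, c, alpha_deg, beta_deg, gamma_deg):
    """Encode cell lengths and angles (degrees) as the 6D descriptor."""
    angles = (alpha_deg, beta_deg, gamma_deg)
    return [math.log(a), math.log(b), math.log(c)] + [math.cos(math.radians(x)) for x in angles]

def features_to_lengths_angles(features):
    """Decode the descriptor; exp() makes every length positive by construction."""
    lengths = [math.exp(v) for v in features[:3]]
    # Assumption: clamp noisy cosines into [-1, 1] so acos stays defined.
    cosines = [max(-1.0, min(1.0, v)) for v in features[3:]]
    angles = [math.degrees(math.acos(v)) for v in cosines]
    return lengths, angles

features = lattice_to_features(4.0, 5.0, 6.0, 90.0, 90.0, 120.0)
lengths, angles = features_to_lengths_angles(features)

# Even a heavily perturbed log-length still decodes to a positive length.
noisy = [features[0] - 3.0] + features[1:]
noisy_lengths, _ = features_to_lengths_angles(noisy)
```

Adding Gaussian noise to raw lengths could push them below zero, whereas noise on $\log a$ only rescales $a$.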

5.3 Atom types: absorbing-mask diffusion

For atom identities, we use a simple discrete process with a special MASK token:

$$q(a_t = a_0 \mid a_0) = \bar\alpha_t, \qquad q(a_t = \text{MASK} \mid a_0) = 1 - \bar\alpha_t.$$

This is a pedagogical stand-in for MatterGen’s discrete corruption logic.
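The two probabilities above amount to one independent coin flip per atom. Here is a minimal plain-Python sketch with made-up atoms and hand-picked keep probabilities; the string `"MASK"` stands in for the extra token id the notebook uses.

```python
import random

MASK = "MASK"  # stand-in for the extra token id used by the notebook

def mask_atoms(atoms, keep_prob, rng):
    """Keep each atom with probability keep_prob, otherwise replace it with MASK."""
    return [a if rng.random() < keep_prob else MASK for a in atoms]

rng = random.Random(0)
atoms = ["Sr", "Ti", "O", "O", "O"] * 200  # 1000 made-up atoms

for keep_prob in (1.0, 0.5, 0.0):
    noisy = mask_atoms(atoms, keep_prob, rng)
    masked_fraction = noisy.count(MASK) / len(noisy)
    print(f"keep_prob={keep_prob:.1f} -> masked fraction {masked_fraction:.2f}")
```

Unlike Gaussian noise, a masked atom carries no partial information about its original identity, so the denoiser has to infer the element from the surrounding context.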

Try this: pick one real crystal and inspect the same state at an early, middle, and late timestep. Which part looks hardest to noise gracefully: positions, lattice, or atom identity?

Read the code below by matching it to the three subsections above. The main things to notice are the schedule definitions and the different corruption rules for continuous versus discrete variables.


T = 1000  # number of diffusion steps

# Continuous schedules, indexed by 1..T, with alpha_bar[0] = 1 for convenience.
cont_betas = torch.linspace(1e-4, 2e-2, T, dtype=torch.float32)
cont_alphas = 1.0 - cont_betas
cont_alpha_bars = torch.cat([torch.ones(1, dtype=torch.float32), torch.cumprod(cont_alphas, dim=0)], dim=0)

# Discrete absorbing-mask schedule. The earlier linear-beta version masked almost
# everything far too quickly, so we use a gentler keep-probability curve that only
# becomes heavily masked late in the trajectory while still ending near an absorbing state.
DISC_MASK_FINAL_KEEP = 0.05
disc_time = torch.linspace(0.0, 1.0, T + 1, dtype=torch.float32)
disc_alpha_bars = torch.exp(torch.log(torch.tensor(DISC_MASK_FINAL_KEEP, dtype=torch.float32)) * (disc_time**2))
disc_alphas = disc_alpha_bars[1:] / torch.clamp(disc_alpha_bars[:-1], min=1e-8)
disc_betas = 1.0 - disc_alphas

def extract_schedule(arr: torch.Tensor, t: torch.Tensor, batch_idx: torch.Tensor | None, target: torch.Tensor) -> torch.Tensor:
    # arr indexed by timestep in 0..T
    out = arr.to(target.device)[t]
    if batch_idx is not None:
        out = out[batch_idx]
    while out.ndim < target.ndim:
        out = out.unsqueeze(-1)
    return out

def q_sample_wrapped_coords(x0: torch.Tensor, t_graph: torch.Tensor, batch_idx: torch.Tensor):
    sqrt_ab = extract_schedule(cont_alpha_bars.sqrt(), t_graph, batch_idx, x0)
    sqrt_one_minus_ab = extract_schedule((1.0 - cont_alpha_bars).sqrt(), t_graph, batch_idx, x0)
    eps = torch.randn_like(x0)
    x_unwrapped = sqrt_ab * x0 + sqrt_one_minus_ab * eps
    x_t = wrap_centered_torch(x_unwrapped)
    # wrapped epsilon target: nearest periodic residual
    eps_target = wrap_centered_torch(x_t - sqrt_ab * x0) / torch.clamp(sqrt_one_minus_ab, min=1e-6)
    return x_t, eps_target

def q_sample_lattice(x0: torch.Tensor, t_graph: torch.Tensor):
    sqrt_ab = extract_schedule(cont_alpha_bars.sqrt(), t_graph, None, x0)
    sqrt_one_minus_ab = extract_schedule((1.0 - cont_alpha_bars).sqrt(), t_graph, None, x0)
    eps = torch.randn_like(x0)
    x_t = sqrt_ab * x0 + sqrt_one_minus_ab * eps
    return x_t, eps

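def q_sample_atom_types(x0: torch.Tensor, t_graph: torch.Tensor, batch_idx: torch.Tensor):
    # x0 is 1D (one token per atom), so the gathered schedule already has shape [N_atoms].
    # No squeeze here: squeezing would collapse a single-atom batch to a 0-d tensor.
    keep_prob = extract_schedule(disc_alpha_bars, t_graph, batch_idx, x0.float())
    keep = (torch.rand_like(keep_prob) < keep_prob)
    x_t = torch.where(keep, x0, torch.full_like(x0, MASK_TOKEN))
    return x_t, keep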

def x0_to_structure(frac_centered: np.ndarray, lattice_norm: np.ndarray, atom_tokens: np.ndarray) -> Structure:
    frac = ((frac_centered + 0.5) % 1.0).astype(np.float64)
    lattice_features = lattice_norm * lat_std + lat_mean
    lattice = features_to_lattice(lattice_features)
    species = [Element.from_Z(int(token_to_z[int(tok)])).symbol for tok in atom_tokens]
    return Structure(lattice=lattice, species=species, coords=frac, coords_are_cartesian=False)

Try this: increase or decrease T.
Question: what do you expect to happen if the number of diffusion steps is too small? Too large?

Answer

If T is too small with this fixed beta range, the cumulative product $\bar\alpha_T$ stays well above zero, so the forward process never reaches a genuinely high-noise regime. The reverse model then learns only a shallow denoising problem, and samples that start from pure noise can look biased or brittle. If T is too large, each timestep is visited less often during training and sampling needs more network evaluations, which slows everything down and can make the reverse chain harder to stabilize in a small notebook setting.
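One way to check the "too small" case numerically is to compute the final cumulative product for a few step counts. This sketch assumes the same linear beta range as the continuous schedule in this section (1e-4 to 2e-2, endpoints included) and uses plain Python; `final_alpha_bar` is a hypothetical helper name.

```python
import math

def final_alpha_bar(num_steps, beta_start=1e-4, beta_end=2e-2):
    """Product of (1 - beta_t) over a linear beta schedule with num_steps entries."""
    log_alpha_bar = 0.0
    for i in range(num_steps):
        frac = i / (num_steps - 1) if num_steps > 1 else 0.0
        beta = beta_start + (beta_end - beta_start) * frac
        log_alpha_bar += math.log1p(-beta)
    return math.exp(log_alpha_bar)

for num_steps in (10, 100, 1000):
    ab = final_alpha_bar(num_steps)
    # sqrt(ab) scales the clean signal, sqrt(1 - ab) scales the noise at the last step.
    print(num_steps, round(math.sqrt(ab), 4), round(math.sqrt(1.0 - ab), 4))
```

If the signal scale at the last step is still far from zero, the "fully noised" state retains a visible imprint of the training crystal, and starting the reverse chain from pure noise no longer matches what the model saw in training.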

Visual intuition: corrupting one real crystal in three different ways

The forward process is easier to understand when you look at one structure and perturb it channel by channel.

The next plot compares coordinate noise, lattice noise, and atom-type masking on the same training crystal.

Task for you

  • Compare this denoiser to the earlier MLP notebook: which pieces are familiar, and which pieces only appear because the data are crystals rather than fixed-length vectors?

  • Compare it to graph-networks.ipynb: where does message passing enter, and why is it more natural than a plain MLP here?

  • Before training, predict which output head will be hardest to learn first: coordinates, lattice, or atom types.

# Use a representative crystal from the training split.
example_record = train_data[0]

@torch.no_grad()
def forward_corruption_triptych(record, timestep: int = 250):
    frac0 = torch.tensor(record["frac0"], dtype=torch.float32, device=device)
    lattice0 = torch.tensor(record["lattice0_norm"][None], dtype=torch.float32, device=device)
    atom0 = torch.tensor(record["atom_tokens0"], dtype=torch.long, device=device)
    batch_idx = torch.zeros(len(atom0), dtype=torch.long, device=device)
    t_graph = torch.tensor([int(timestep)], dtype=torch.long, device=device)

    frac_t, _ = q_sample_wrapped_coords(frac0, t_graph, batch_idx)
    lattice_t, _ = q_sample_lattice(lattice0, t_graph)
    atom_t, _ = q_sample_atom_types(atom0, t_graph, batch_idx)

    clean = x0_to_structure(record["frac0"], record["lattice0_norm"], record["atom_tokens0"])
    coord_only = x0_to_structure(frac_t.detach().cpu().numpy(), record["lattice0_norm"], record["atom_tokens0"])
    lattice_only = x0_to_structure(record["frac0"], lattice_t[0].detach().cpu().numpy(), record["atom_tokens0"])
    atom_tokens = atom_t.detach().cpu().numpy().copy()
    frac_np = frac_t.detach().cpu().numpy()
    lattice_np = lattice_t[0].detach().cpu().numpy()
    valid = atom_tokens != MASK_TOKEN
    if valid.sum() == 0:
        valid = np.zeros_like(atom_tokens, dtype=bool)
        valid[0] = True
        atom_tokens[0] = 0
    masked = x0_to_structure(frac_np[valid], lattice_np, atom_tokens[valid])
    return clean, coord_only, lattice_only, masked

clean, coord_only, lattice_only, masked = forward_corruption_triptych(example_record, timestep=250)
labels = [
    f"clean | {clean.composition.reduced_formula}",
    "coordinate noise",
    "lattice noise",
    "coordinate + lattice + atom masking",
]
display(show_structures([clean, coord_only, lattice_only, masked], labels=labels, columns=2))

Question: which corruption do you expect to be hardest to undo, and why?

Answer

Atom masking is often hardest because it removes discrete identity information, while coordinates and lattice still leave geometric clues behind.

6) Build a compact MatterGen-like denoiser

The real MatterGen denoiser is stronger than ours, but the interface is the same: the network sees a noised crystal and predicts

  • coordinate noise,

  • lattice noise,

  • atom logits.

Inputs to the denoiser

For each atom, the model sees:

  • the current noisy fractional coordinate,

  • the current noisy atom token,

  • the current noisy lattice (broadcast from crystal level to atom level),

  • the timestep embedding,

  • the number of atoms in the crystal,

  • optionally a condition embedding.

Output heads

The model predicts three things:

$$\hat{\epsilon}_{\text{coord}}, \qquad \hat{\epsilon}_{\text{lattice}}, \qquad \hat{\ell}_{\text{atom}}$$

where $\hat{\ell}_{\text{atom}}$ are logits over atom classes.
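To see what "logits over atom classes" means in practice, the sketch below converts one atom's logits into probabilities and a predicted element. The vocabulary and the logit values are invented; the only detail taken from the notebook is that the atom head has one output per real atom class and none for MASK.

```python
import math

# Assumed toy vocabulary: token id -> atomic number (O, Ti, Sr).
toy_token_to_z = {0: 8, 1: 22, 2: 38}

def softmax(logits):
    """Turn raw scores into probabilities that sum to one."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits for one masked atom: one score per real atom class.
# There is no MASK entry, so the head can only ever predict a real element.
atom_logits = [2.0, 0.5, -1.0]
probs = softmax(atom_logits)
predicted_token = max(range(len(probs)), key=probs.__getitem__)
predicted_z = toy_token_to_z[predicted_token]
```

The input embedding, by contrast, needs one extra row for MASK, because the network must be able to read a masked atom even though it never predicts one.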

How we encode the conditions

We now use three small condition encoders:

  1. a scalar-property encoder for density and band gap,

  2. a composition encoder for the element-fraction vector,

  3. a space-group embedding.

Those are fused into one graph-level condition vector before message passing.

Why message passing?

Atoms are not independent. Their local environment matters.
So we build a graph with periodic edges and update atom features by message passing.
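A one-dimensional toy version shows the mechanics. The notebook's blocks use learned MLPs for messages and updates; here the message function is a hand-written stand-in (sender feature down-weighted by periodic distance), and all coordinates and features are made up.

```python
def wrap_centered(x):
    """Periodic image of x in [-0.5, 0.5)."""
    return (x + 0.5) % 1.0 - 0.5

# Three atoms on a 1D periodic cell (fractional coordinates, made-up features).
frac = [-0.45, 0.0, 0.45]
h = [1.0, 2.0, 4.0]

# Directed edges between every ordered pair of distinct atoms.
edges = [(i, j) for i in range(3) for j in range(3) if i != j]

# Toy message: the sender's feature, down-weighted by periodic distance.
agg = [0.0, 0.0, 0.0]
deg = [0, 0, 0]
for src, dst in edges:
    dist = abs(wrap_centered(frac[dst] - frac[src]))
    agg[dst] += h[src] / (1.0 + dist)
    deg[dst] += 1

# Mean aggregation followed by a residual update, as in a message-passing block.
h_new = [h_i + a / max(d, 1) for h_i, a, d in zip(h, agg, deg)]
```

Note that the first and last atoms are only 0.1 apart through the cell boundary, so they influence each other more strongly than their raw coordinates suggest.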

A useful mental model

This denoiser fills the role that a U-Net fills in image diffusion, but without the down- and up-sampling shape: it is a flat stack of interaction blocks that repeatedly mixes local geometry, noisy atom identity, lattice context, time information, and optional conditioning.

Try this: halve HIDDEN_DIM and rerun a short training job. What gets worse first: the losses, the samples, or both?

The next cell is intentionally collapsed because it is implementation-heavy. On first pass, you can skip it and just remember:

  1. embed atoms, coordinates, lattice, time, and atom-count information,

  2. build periodic edges,

  3. apply several interaction blocks,

  4. decode three heads.

Open it later if you want the exact message-passing details.
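If you only want to look at one piece before opening the cell, the time embedding from step 1 is the most self-contained. The sketch below is a plain-Python rendering of the sinusoidal features for a single timestep; it leaves out the small MLP that the notebook applies afterwards, and `sinusoidal_embedding` is a hypothetical name.

```python
import math

def sinusoidal_embedding(t_scaled: float, dim: int):
    """Encode a scalar timestep as sin/cos features at geometrically spaced frequencies."""
    half = dim // 2
    freqs = [math.exp(-math.log(10000.0) * k / max(half - 1, 1)) for k in range(half)]
    emb = [math.sin(t_scaled * f) for f in freqs] + [math.cos(t_scaled * f) for f in freqs]
    if len(emb) < dim:  # odd dim: pad with one zero
        emb.append(0.0)
    return emb

early = sinusoidal_embedding(0.0, 8)  # sin terms are 0, cos terms are 1
late = sinusoidal_embedding(1.0, 8)
```

Different frequencies let the network distinguish nearby timesteps (fast components) as well as early from late stages of the trajectory (slow components).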

# @title
def build_sinusoidal_time_embedding(dim: int):
    module = nn.ModuleDict(
        {
            "mlp": nn.Sequential(
                nn.Linear(dim, dim),
                nn.SiLU(),
                nn.Linear(dim, dim),
            )
        }
    )
    module.dim = dim
    return module


def apply_sinusoidal_time_embedding(time_embed, t_scaled: torch.Tensor) -> torch.Tensor:
    half = time_embed.dim // 2
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half, device=t_scaled.device, dtype=torch.float32) / max(half - 1, 1)
    )
    args = t_scaled.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if emb.shape[-1] < time_embed.dim:
        emb = F.pad(emb, (0, 1))
    return time_embed["mlp"](emb)


def build_interaction_block(hidden_dim: int, use_adapter: bool = False):
    block = nn.ModuleDict(
        {
            "edge_mlp": nn.Sequential(
                nn.Linear(hidden_dim * 2 + 16, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ),
            "node_mlp": nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ),
        }
    )
    block.use_adapter = use_adapter
    if use_adapter:
        block["adapter"] = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
        )
        nn.init.zeros_(block["adapter"][-1].weight)
    return block


def apply_interaction_block(block, h, src, dst, edge_feat, cond_per_atom=None, use_uncond_per_atom=None):
    m = block["edge_mlp"](torch.cat([h[src], h[dst], edge_feat], dim=-1))
    agg = torch.zeros_like(h)
    agg.index_add_(0, dst, m)
    deg = torch.zeros(h.shape[0], 1, device=h.device)
    deg.index_add_(0, dst, torch.ones(dst.shape[0], 1, device=h.device))
    agg = agg / deg.clamp(min=1.0)
    h = h + block["node_mlp"](torch.cat([h, agg], dim=-1))
    if block.use_adapter and cond_per_atom is not None and use_uncond_per_atom is not None:
        adapt = block["adapter"](torch.cat([h, cond_per_atom], dim=-1))
        h = h + (~use_uncond_per_atom).float() * adapt
    return h


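def build_periodic_edges(frac_t: torch.Tensor, lattice_t: torch.Tensor, num_atoms: torch.Tensor):
    # Fully connect the atoms inside each crystal (no cross-crystal edges) and describe
    # every directed edge by its wrapped fractional displacement and Cartesian image.
    # `lattice_t` is expected to hold one 3x3 lattice matrix per crystal.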
    src_all, dst_all, edge_all = [], [], []
    offset = 0

    for g, n in enumerate(num_atoms.tolist()):
        if n <= 1:
            offset += n
            continue

        frac_g = frac_t[offset : offset + n]
        lattice_g = lattice_t[g]

        idx = torch.arange(n, device=frac_t.device)
        ii, jj = torch.meshgrid(idx, idx, indexing="ij")
        mask = ii != jj
        ii = ii[mask]
        jj = jj[mask]

        dfrac = wrap_centered_torch(frac_g[jj] - frac_g[ii])
        dcart = torch.matmul(dfrac, lattice_g)
        dcart = torch.clamp(dcart, min=-GEOM_MAX_DIST, max=GEOM_MAX_DIST)
        dist = torch.norm(dcart, dim=-1, keepdim=True).clamp(max=GEOM_MAX_DIST)

        edge_feat = torch.cat(
            [
                dfrac,
                dcart,
                dist,
                (dist**2).clamp(max=GEOM_MAX_DIST**2),
                torch.sin(dist),
                torch.cos(dist),
                torch.exp(-dist),
            ],
            dim=-1,
        )
        if edge_feat.shape[-1] < 16:
            edge_feat = F.pad(edge_feat, (0, 16 - edge_feat.shape[-1]))
        edge_feat = edge_feat[:, :16]

        src_all.append(ii + offset)
        dst_all.append(jj + offset)
        edge_all.append(edge_feat)
        offset += n

    if len(src_all) == 0:
        src = torch.zeros(0, dtype=torch.long, device=frac_t.device)
        dst = torch.zeros(0, dtype=torch.long, device=frac_t.device)
        edge_feat = torch.zeros(0, 16, dtype=torch.float32, device=frac_t.device)
    else:
        src = torch.cat(src_all, dim=0)
        dst = torch.cat(dst_all, dim=0)
        edge_feat = torch.cat(edge_all, dim=0)

    return src, dst, edge_feat


def build_mini_mattergen_denoiser(
    num_atom_classes: int,
    hidden_dim: int = 128,
    num_blocks: int = 4,
    max_atoms: int = MAX_ATOMS,
    conditional: bool = False,
    cond_dim: int = len(CONDITION_NAMES),
):
    model = nn.ModuleDict(
        {
            "atom_embed": nn.Embedding(num_atom_classes + 1, hidden_dim),
            "coord_proj": nn.Linear(3, hidden_dim),
            "lattice_proj": nn.Linear(6, hidden_dim),
            "time_embed": build_sinusoidal_time_embedding(hidden_dim),
            "num_atoms_embed": nn.Embedding(max_atoms + 1, hidden_dim),
            "blocks": nn.ModuleList(
                [build_interaction_block(hidden_dim, use_adapter=conditional) for _ in range(num_blocks)]
            ),
            "coord_head": nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, 3),
            ),
            "atom_head": nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, num_atom_classes),
            ),
            "lattice_head": nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, 6),
            ),
        }
    )
    if conditional:
        model["cond_scalar_embed"] = nn.Sequential(
            nn.Linear(cond_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        model["cond_comp_embed"] = nn.Sequential(
            nn.Linear(num_atom_classes, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        model["cond_spacegroup_embed"] = nn.Embedding(MAX_SPACEGROUP_NUMBER + 1, hidden_dim)
        model["cond_fuse"] = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        model.register_parameter("uncond_embedding", nn.Parameter(torch.zeros(1, hidden_dim)))

    model.num_atom_classes = num_atom_classes
    model.conditional = conditional
    model.hidden_dim = hidden_dim
    model.cond_dim = cond_dim
    return model


def run_mini_mattergen_denoiser(
    model,
    frac_t: torch.Tensor,
    lattice_t: torch.Tensor,
    atom_t: torch.Tensor,
    num_atoms: torch.Tensor,
    batch_idx: torch.Tensor,
    t_graph: torch.Tensor,
    continuous_cond: torch.Tensor | None = None,
    composition_cond: torch.Tensor | None = None,
    spacegroup_cond: torch.Tensor | None = None,
    use_uncond_embedding: torch.Tensor | None = None,
):
    B = num_atoms.shape[0]
    t_scaled = t_graph.float() / float(T)

    lattice_features_t = denormalize_lattice_features_torch(lattice_t)
    lattice_matrix_t = lattice_matrix_from_features_torch(lattice_features_t)

    graph_ctx = (
        apply_sinusoidal_time_embedding(model["time_embed"], t_scaled)
        + model["lattice_proj"](lattice_t)
        + model["num_atoms_embed"](num_atoms)
    )

    cond_per_graph = None
    cond_per_atom = None
    use_uncond_per_atom = None

    if model.conditional:
        if continuous_cond is None or composition_cond is None or spacegroup_cond is None:
            cond_per_graph = model.uncond_embedding.expand(B, -1)
            use_uncond_embedding = torch.ones(B, 1, dtype=torch.bool, device=frac_t.device)
        else:
            if use_uncond_embedding is None:
                use_uncond_embedding = torch.zeros(B, 1, dtype=torch.bool, device=frac_t.device)
            scalar_embed = model["cond_scalar_embed"](continuous_cond)
            comp_embed = model["cond_comp_embed"](composition_cond)
            sg_embed = model["cond_spacegroup_embed"](
                torch.clamp(spacegroup_cond.long(), min=0, max=MAX_SPACEGROUP_NUMBER)
            )
            cond_embed = model["cond_fuse"](torch.cat([scalar_embed, comp_embed, sg_embed], dim=-1))
            uncond_embed = model.uncond_embedding.expand(B, -1)
            cond_per_graph = torch.where(use_uncond_embedding, uncond_embed, cond_embed)

        graph_ctx = graph_ctx + cond_per_graph
        cond_per_atom = cond_per_graph[batch_idx]
        use_uncond_per_atom = use_uncond_embedding[batch_idx]

    h = model["atom_embed"](atom_t) + model["coord_proj"](frac_t) + graph_ctx[batch_idx]

    src, dst, edge_feat = build_periodic_edges(frac_t, lattice_matrix_t, num_atoms)
    for block in model["blocks"]:
        h = apply_interaction_block(
            block,
            h,
            src,
            dst,
            edge_feat,
            cond_per_atom=cond_per_atom,
            use_uncond_per_atom=use_uncond_per_atom,
        )

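    # Mean-pool atom features per crystal; match h's dtype and guard against empty graphs.
    pooled = torch.zeros(B, model.hidden_dim, device=h.device, dtype=h.dtype)
    pooled.index_add_(0, batch_idx, h)
    pooled = pooled / num_atoms.clamp(min=1).unsqueeze(-1).to(h.dtype)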

    pred_coord_eps = model["coord_head"](h)
    pred_atom_logits = model["atom_head"](h)
    pred_lattice_eps = model["lattice_head"](torch.cat([pooled, graph_ctx], dim=-1))
    return pred_coord_eps, pred_lattice_eps, pred_atom_logits

This short visible cell sets the model size and instantiates the unconditional base model. These are the easiest architectural knobs to play with.

HIDDEN_DIM = int(globals().get("HIDDEN_DIM", 256))
NUM_BLOCKS = int(globals().get("NUM_BLOCKS", 8))

base_model = build_mini_mattergen_denoiser(
    NUM_ATOM_CLASSES,
    hidden_dim=HIDDEN_DIM,
    num_blocks=NUM_BLOCKS,
    conditional=False,
).to(device)

num_params = sum(p.numel() for p in base_model.parameters())
print(f"base model parameters: {num_params:,}")
print(f"Using HIDDEN_DIM={HIDDEN_DIM}, NUM_BLOCKS={NUM_BLOCKS}")

Try this: change HIDDEN_DIM or NUM_BLOCKS.
Question: which do you think would matter more in this smaller notebook implementation: wider layers or more message-passing steps?

Answer

Here, wider layers usually matter first because they control how much geometric and chemical information can be stored at each atom. More message-passing steps help too, but if the hidden state is too small, extra propagation just moves around a weak signal.
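
To make the width-versus-depth trade-off concrete, here is a rough parameter count. It assumes a hypothetical interaction block made of two `Linear(hidden_dim, hidden_dim)` layers; the real `build_interaction_block` may contain more, so read the output as scaling behaviour rather than exact counts. The helper names are introduced here for illustration only.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus biases of one dense layer."""
    return d_in * d_out + d_out


def toy_block_stack_params(hidden_dim: int, num_blocks: int) -> int:
    """Parameter count for a stack of hypothetical two-layer blocks."""
    per_block = 2 * linear_params(hidden_dim, hidden_dim)
    return num_blocks * per_block


base = toy_block_stack_params(256, 8)
wider = toy_block_stack_params(512, 8)
deeper = toy_block_stack_params(256, 16)

print(f"base   : {base:,}")
print(f"wider  : {wider:,} ({wider / base:.2f}x)")
print(f"deeper : {deeper:,} ({deeper / base:.2f}x)")
```

Under this assumption, doubling the width roughly quadruples the parameter count, while doubling the depth doubles it. That is one reason to change only one knob per run when comparing curves.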

7) Define the multi-part training objective

At a random timestep $t$, we corrupt a clean crystal and train the denoiser to undo the corruption.

Loss

We use a weighted sum of three terms:

$$\mathcal{L} = \lambda_{\text{coord}} \, \mathcal{L}_{\text{coord}} + \lambda_{\text{lat}} \, \mathcal{L}_{\text{lat}} + \lambda_{\text{atom}} \, \mathcal{L}_{\text{atom}}.$$

For coordinates and lattice we use a robust Smooth L1 loss:

$$\mathcal{L}_{\text{coord}} = \operatorname{SmoothL1}\!\left(\hat{\epsilon}_{\text{coord}}, \epsilon_{\text{coord}}\right), \qquad \mathcal{L}_{\text{lat}} = \operatorname{SmoothL1}\!\left(\hat{\epsilon}_{\text{lat}}, \epsilon_{\text{lat}}\right),$$

and for atoms we use cross-entropy on the clean atom identity:

$$\mathcal{L}_{\text{atom}} = \operatorname{CE}\!\left(\hat{\ell}_{\text{atom}}, a_0\right).$$
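
As a reading aid for the atom term, the sketch below computes cross-entropy from raw logits for a single atom in plain Python. The function name and the toy logits are illustrative assumptions; the notebook itself relies on `F.cross_entropy`.

```python
import math


def cross_entropy_from_logits(logits: list[float], target: int) -> float:
    """Negative log-probability of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_norm - logits[target]


# Toy logits over four atom classes; the clean identity is class 2.
uniform = cross_entropy_from_logits([0.0, 0.0, 0.0, 0.0], target=2)
confident = cross_entropy_from_logits([0.0, 0.0, 8.0, 0.0], target=2)
wrong = cross_entropy_from_logits([8.0, 0.0, 0.0, 0.0], target=2)
print(uniform, confident, wrong)
```

Uniform logits over K classes give a loss of log K. That makes log K a useful baseline when reading `atom_loss` early in training, when the predictions are still close to uniform.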

Why Smooth L1 here?

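Small toy crystal datasets can produce noisy batches with occasional large residuals. Smooth L1 is quadratic for residuals below a threshold (`ROBUST_BETA` in the settings cell) and linear beyond it, so outliers contribute bounded gradients and the continuous losses are less brittle than with plain MSE.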
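
The following plain-Python sketch compares the elementwise Smooth L1 penalty with a squared error. It uses the piecewise form documented for `F.smooth_l1_loss`. The residual values are made up for illustration, and `beta=0.25` mirrors the notebook's `ROBUST_BETA`.

```python
def smooth_l1(residual: float, beta: float = 0.25) -> float:
    """Elementwise Smooth L1: quadratic below beta, linear above."""
    r = abs(residual)
    if r < beta:
        return 0.5 * r * r / beta
    return r - 0.5 * beta


def squared_error(residual: float) -> float:
    return residual * residual


for r in (0.1, 0.25, 1.0, 4.0):
    print(f"residual={r:4.2f}  smooth_l1={smooth_l1(r):.4f}  mse={squared_error(r):.4f}")
```

For large residuals the Smooth L1 penalty grows linearly while the squared error grows quadratically, so a single badly corrupted sample has less influence on the batch gradient.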

Try this: double ATOM_LOSS_WEIGHT for one run. Do you get cleaner compositions at the cost of worse geometry?

The short cell below contains the main user-facing training knobs. The long helper cell after it handles evaluation, EMA-smoothed logging, early stopping, and checkpointing.

P_UNCOND = 0.2
COORD_LOSS_WEIGHT = 1.0
LATTICE_LOSS_WEIGHT = 1.0
ATOM_LOSS_WEIGHT = 1.5

ROBUST_BETA = 0.25
GRAD_CLIP_NORM = 0.5
CHECKPOINT_DIR = Path("mattergen_checkpoints")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
# @title
def compute_training_losses(model: nn.Module, batch, training: bool = True):
    batch = move_crystal_batch_to_device(batch, device)
    B = batch["num_atoms"].shape[0]
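    # One diffusion timestep per crystal, shared by all of its atoms.
    t_graph = torch.randint(1, T + 1, (B,), device=device)

    # Corrupt each modality with its own forward process.
    frac_t, coord_eps_target = q_sample_wrapped_coords(batch["frac0"], t_graph, batch["batch_idx"])
    lattice_t, lattice_eps_target = q_sample_lattice(batch["lattice0"], t_graph)
    # keep_mask is not used below: the atom loss is taken over all atoms, masked or not.
    atom_t, keep_mask = q_sample_atom_types(batch["atom_tokens0"], t_graph, batch["batch_idx"])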

    continuous_cond = None
    composition_cond = None
    spacegroup_cond = None
    use_uncond = None

    if getattr(model, "conditional", False):
        continuous_cond = batch["continuous_conditions"]
        composition_cond = batch["composition_conditions"]
        spacegroup_cond = batch["spacegroup_conditions"]
        if training:
            use_uncond = (torch.rand(B, 1, device=device) < P_UNCOND)
        else:
            use_uncond = torch.zeros(B, 1, device=device, dtype=torch.bool)

    pred_coord_eps, pred_lattice_eps, pred_atom_logits = run_mini_mattergen_denoiser(
        model,
        frac_t=frac_t,
        lattice_t=lattice_t,
        atom_t=atom_t,
        num_atoms=batch["num_atoms"],
        batch_idx=batch["batch_idx"],
        t_graph=t_graph,
        continuous_cond=continuous_cond,
        composition_cond=composition_cond,
        spacegroup_cond=spacegroup_cond,
        use_uncond_embedding=use_uncond,
    )

    coord_loss = F.smooth_l1_loss(pred_coord_eps, coord_eps_target, beta=ROBUST_BETA)
    lattice_loss = F.smooth_l1_loss(pred_lattice_eps, lattice_eps_target, beta=ROBUST_BETA)
    atom_loss = F.cross_entropy(pred_atom_logits, batch["atom_tokens0"])

    total = (
        COORD_LOSS_WEIGHT * coord_loss
        + LATTICE_LOSS_WEIGHT * lattice_loss
        + ATOM_LOSS_WEIGHT * atom_loss
    )
    metrics = {
        "loss": float(total.detach().cpu()),
        "coord_loss": float(coord_loss.detach().cpu()),
        "lattice_loss": float(lattice_loss.detach().cpu()),
        "atom_loss": float(atom_loss.detach().cpu()),
        "masked_atom_fraction": float((atom_t == MASK_TOKEN).float().mean().detach().cpu()),
    }
    return total, metrics

@torch.no_grad()
def evaluate_model(model: nn.Module, loader: DataLoader):
    model.eval()
    logs = []
    for batch in loader:
        _, metrics = compute_training_losses(model, batch, training=False)
        logs.append(metrics)
    return pd.DataFrame(logs).mean().to_dict()

def save_training_checkpoint(
    path: Path,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    epoch: int,
    history_rows,
    run_name: str,
    best_val_loss: float,
    best_epoch: int,
):
    # Snapshot the legacy NumPy RNG state once instead of re-querying it per field.
    np_state = np.random.get_state()
    payload = {
        "run_name": run_name,
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict() if scheduler is not None else None,
        "history": history_rows,
        "best_val_loss": best_val_loss,
        "best_epoch": best_epoch,
        "rng_state_torch": torch.get_rng_state(),
        "rng_state_numpy": {
            "bit_generator": np_state[0],
            "state": np_state[1].tolist(),
            "pos": int(np_state[2]),
            "has_gauss": int(np_state[3]),
            "cached_gaussian": float(np_state[4]),
        },
        "rng_state_python": random.getstate(),
    }
    torch.save(payload, path)

def restore_rng_states_from_checkpoint(ckpt: dict):
    if ckpt.get("rng_state_torch") is not None:
        torch.set_rng_state(ckpt["rng_state_torch"])

    np_state = ckpt.get("rng_state_numpy")
    if isinstance(np_state, dict) and "state" in np_state:
        np.random.set_state((
            np_state["bit_generator"],
            np.array(np_state["state"], dtype=np.uint32),
            int(np_state["pos"]),
            int(np_state["has_gauss"]),
            float(np_state["cached_gaussian"]),
        ))

    py_state = ckpt.get("rng_state_python")
    if py_state is not None:
        random.setstate(py_state)

def load_model_checkpoint(
    model: nn.Module,
    path: str | Path,
    optimizer=None,
    scheduler=None,
    map_location=None,
    restore_rng_state: bool = False,
):
    # weights_only=False runs the full unpickler, so only load checkpoints you created or trust.
    ckpt = torch.load(path, map_location=map_location or device, weights_only=False)
    model.load_state_dict(ckpt["model_state"])
    if optimizer is not None and ckpt.get("optimizer_state") is not None:
        optimizer.load_state_dict(ckpt["optimizer_state"])
    if scheduler is not None and ckpt.get("scheduler_state") is not None:
        scheduler.load_state_dict(ckpt["scheduler_state"])
    if restore_rng_state:
        restore_rng_states_from_checkpoint(ckpt)
    return ckpt

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int,
    lr: float,
    run_name: str,
    weight_decay: float = 1e-4,
    warmup_frac: float = 0.1,
    ema_decay: float = 0.95,
    patience: int | None = 6,
    min_delta: float = 1e-3,
    checkpoint_dir: str | Path = CHECKPOINT_DIR,
    restore_best_at_end: bool = True,
):
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    total_steps = max(1, epochs * len(train_loader))
    warmup_steps = max(1, int(total_steps * warmup_frac))

    def lr_lambda(step: int):
        if step < warmup_steps:
            return float(step + 1) / float(warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    history = []
    global_step = 0
    ema_loss = None
    best_val_loss = float("inf")
    best_epoch = 0
    epochs_since_improvement = 0
    stopped_early = False

    last_ckpt_path = checkpoint_dir / f"{run_name}_last.pt"
    best_ckpt_path = checkpoint_dir / f"{run_name}_best.pt"

    for epoch in range(1, epochs + 1):
        model.train()
        train_logs = []
        epoch_ema_values = []

        pbar = tqdm(train_loader, desc=f"epoch {epoch}/{epochs}")
        for batch in pbar:
            optimizer.zero_grad(set_to_none=True)
            loss, metrics = compute_training_losses(model, batch, training=True)

            if not torch.isfinite(loss):
                print(f"Skipping non-finite batch at epoch {epoch}")
                continue

            loss.backward()
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRAD_CLIP_NORM)
            optimizer.step()
            scheduler.step()
            global_step += 1

            batch_loss = float(metrics["loss"])
            ema_loss = batch_loss if ema_loss is None else (ema_decay * ema_loss + (1.0 - ema_decay) * batch_loss)
            epoch_ema_values.append(float(ema_loss))

            metrics["ema_loss"] = float(ema_loss)
            metrics["grad_norm"] = float(grad_norm.detach().cpu()) if torch.is_tensor(grad_norm) else float(grad_norm)
            metrics["lr"] = float(optimizer.param_groups[0]["lr"])
            train_logs.append(metrics)

            pbar.set_postfix(
                loss=f'{metrics["loss"]:.3f}',
                ema=f'{metrics["ema_loss"]:.3f}',
                atom=f'{metrics["atom_loss"]:.3f}',
                lr=f'{metrics["lr"]:.1e}'
            )

        if len(train_logs) == 0:
            raise RuntimeError("All training batches were skipped. Try lowering the learning rate.")

        train_mean = pd.DataFrame(train_logs).mean().to_dict()
        val_mean = evaluate_model(model, val_loader)

        row = {
            "epoch": epoch,
            "train_loss": train_mean["loss"],
            "train_loss_ema": float(np.mean(epoch_ema_values)),
            "train_coord_loss": train_mean["coord_loss"],
            "train_lattice_loss": train_mean["lattice_loss"],
            "train_atom_loss": train_mean["atom_loss"],
            "train_masked_atom_fraction": train_mean["masked_atom_fraction"],
            "val_loss": val_mean["loss"],
            "val_coord_loss": val_mean["coord_loss"],
            "val_lattice_loss": val_mean["lattice_loss"],
            "val_atom_loss": val_mean["atom_loss"],
            "val_masked_atom_fraction": val_mean["masked_atom_fraction"],
            "lr": float(optimizer.param_groups[0]["lr"]),
        }
        history.append(row)

        # Update best-so-far tracking first so both checkpoints record current values.
        improved = row["val_loss"] < (best_val_loss - min_delta)
        if improved:
            best_val_loss = row["val_loss"]
            best_epoch = epoch
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1

        # Always refresh the "last" checkpoint; also write "best" when validation improved.
        checkpoint_paths = [last_ckpt_path] + ([best_ckpt_path] if improved else [])
        for ckpt_path in checkpoint_paths:
            save_training_checkpoint(
                ckpt_path,
                model,
                optimizer,
                scheduler,
                epoch,
                history,
                run_name,
                best_val_loss,
                best_epoch,
            )

        print(
            f"epoch {epoch:03d} | "
            f"train={row['train_loss']:.3f} (ema={row['train_loss_ema']:.3f}) | "
            f"val={row['val_loss']:.3f} | "
            f"coord={row['val_coord_loss']:.3f} | "
            f"lattice={row['val_lattice_loss']:.3f} | "
            f"atom={row['val_atom_loss']:.3f}"
        )

        if patience is not None and epochs_since_improvement >= patience:
            print(f"Stopping early after {patience} epochs without a validation improvement larger than {min_delta}.")
            stopped_early = True
            break

    history_df = pd.DataFrame(history)
    if restore_best_at_end and best_ckpt_path.exists():
        load_model_checkpoint(model, best_ckpt_path, map_location=device)

    run_info = {
        "run_name": run_name,
        "best_val_loss": float(best_val_loss),
        "best_epoch": int(best_epoch),
        "stopped_early": bool(stopped_early),
        "best_checkpoint": best_ckpt_path,
        "last_checkpoint": last_ckpt_path,
    }
    return history_df, run_info
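
The learning-rate schedule inside `train_model` is easy to miss, so here it is as a standalone sketch. It mirrors the `lr_lambda` closure above: linear warmup over the first `warmup_frac` of steps, then cosine decay to 10% of the base rate. The name `lr_multiplier` and the 1000-step budget are assumptions made for illustration.

```python
import math


def lr_multiplier(step: int, total_steps: int, warmup_frac: float = 0.1) -> float:
    """Linear warmup to 1.0, then cosine decay to a floor of 0.1."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress))


total = 1000  # hypothetical step budget: epochs * batches per epoch
for step in (0, 50, 99, 100, 550, 1000):
    print(step, round(lr_multiplier(step, total), 4))
```

The multiplier scales the `lr` argument passed to `train_model`. Because `total_steps` is computed from the epoch cap, a run that stops early never reaches the low end of the cosine curve.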

8) Train the unconditional base model

This is the first real training phase.
We train without any property conditioning so the model first learns a generic crystal distribution.

What to watch

  • train loss tells you whether optimization is working,

  • EMA-smoothed train loss is easier to read than the raw batch loss,

  • validation loss tells you whether the model is improving on held-out structures,

  • early stopping prevents wasting time after the curve plateaus.
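
The two bookkeeping ideas in this list, EMA smoothing and patience-based early stopping, can be sketched in a few lines of plain Python. The update rules follow the ones used in `train_model`; the function names and the toy loss values are assumptions for illustration, not notebook outputs.

```python
def ema_series(values, decay=0.95):
    """Exponential moving average, seeded with the first value."""
    smoothed, ema = [], None
    for v in values:
        ema = v if ema is None else decay * ema + (1.0 - decay) * v
        smoothed.append(ema)
    return smoothed


def early_stop_epoch(val_losses, patience=3, min_delta=1e-3):
    """Return (stop_epoch, best_epoch), 1-indexed; stop_epoch is None if never triggered."""
    best, best_epoch, stale = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best, best_epoch, stale = loss, epoch, 0
        else:
            stale += 1
        if stale >= patience:
            return epoch, best_epoch
    return None, best_epoch


print(ema_series([1.0, 0.0, 0.0], decay=0.5))
print(early_stop_epoch([1.0, 0.8, 0.7, 0.7005, 0.71, 0.70]))
```

An improvement smaller than `min_delta` does not reset the patience counter, so a validation curve that creeps down very slowly can still trigger early stopping.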

Good first-run defaults

The quick preset keeps the denoiser, batch size, and epoch budget small enough for a first pass through the full notebook.
The full preset strengthens the model and training budget without changing the overall workflow.
If you manually widen the dataset further, it usually makes sense to raise the epoch cap again.

The cell below launches the full training loop. Read the arguments as “how long to train”, “learning rate”, and “when to stop if validation stops improving.”

Task for you

  • Read the loss curve exactly as you would in the supervised-learning notebooks: where do you see learning, plateauing, or overfitting risk?

  • Change only one knob at a time, such as learning rate or model width, and record what moves first.

  • If the model generates poor structures but the training loss looks healthy, ask whether the issue is data coverage, corruption design, or sampling rather than optimization alone.

BASE_MAX_EPOCHS = int(globals().get("BASE_MAX_EPOCHS", 200))
BASE_PATIENCE = int(globals().get("BASE_PATIENCE", 50))

base_history, base_run = train_model(
    base_model,
    train_loader,
    val_loader,
    epochs=BASE_MAX_EPOCHS,
    lr=1e-4,
    run_name="base_model",
    patience=BASE_PATIENCE,
    ema_decay=0.95,
)
print(base_run)
base_history.tail()

This plot is your first health check. Ideally, the validation loss should go down early and then level off.

plt.figure(figsize=(8, 4))
plt.plot(base_history["epoch"], base_history["train_loss"], alpha=0.35, label="train")
plt.plot(base_history["epoch"], base_history["train_loss_ema"], label="train (EMA)")
plt.plot(base_history["epoch"], base_history["val_loss"], label="val")
plt.axvline(base_run["best_epoch"], linestyle="--", alpha=0.6, label=f'best epoch={base_run["best_epoch"]}')
plt.xlabel("epoch")
plt.ylabel("loss")
plt.title("Unconditional base model learning curve")
plt.legend()
plt.show()

Reflection

If the validation loss barely moves, check the dataset and model size.
If training is unstable, reduce the learning rate or broaden the dataset slightly.

Try this: lower the learning rate to 5e-5 and compare the curve.
Question: why might a tiny, highly filtered dataset make the diffusion problem harder rather than easier?

Answer

A tiny, highly filtered dataset can actually make the problem harder because the model sees too few structural variations to learn robust denoising rules. Instead of learning a broad crystal prior, it can overfit a narrow set of motifs and become fragile when sampling starts from heavily corrupted states.

9) Reverse sampling and classifier-free guidance

Now that the base model is trained, we need the reverse diffusion loop that turns pure noise and masked atom types back into a candidate crystal.
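
For the continuous parts (coordinates and lattice), a single reverse step can be sketched on one scalar. This is a minimal illustration of the standard DDPM update with noise scale $\sqrt{\beta_t}$. It leaves out the coordinate wrapping and the schedule lookups that the notebook's sampler handles, and the helper names are hypothetical.

```python
import math
import random


def reverse_mean(x_t: float, pred_eps: float, beta_t: float, alpha_bar_t: float) -> float:
    """Posterior mean of x_{t-1} given x_t and a noise prediction (DDPM form)."""
    alpha_t = 1.0 - beta_t
    return (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * pred_eps) / math.sqrt(alpha_t)


def reverse_step(x_t, pred_eps, beta_t, alpha_bar_t, t, rng):
    """One ancestral step; no fresh noise is added on the final step (t == 1)."""
    mean = reverse_mean(x_t, pred_eps, beta_t, alpha_bar_t)
    noise = rng.gauss(0.0, 1.0) if t > 1 else 0.0
    return mean + math.sqrt(beta_t) * noise


# Sanity check at t = 1, where alpha_bar_1 = alpha_1 = 1 - beta_1.
rng = random.Random(0)
x0, eps, beta_1 = 0.3, -1.2, 0.02  # toy values
alpha_bar_1 = 1.0 - beta_1
x1 = math.sqrt(alpha_bar_1) * x0 + math.sqrt(1.0 - alpha_bar_1) * eps
recovered = reverse_step(x1, eps, beta_1, alpha_bar_1, t=1, rng=rng)
print(x0, recovered)
```

If the denoiser predicted the true noise exactly, the last step would recover the clean value. In practice the prediction is imperfect, which is why sample quality depends on the whole trajectory and not only on the final step.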

Two uses of the same sampler

  • For the unconditional base model, each reverse step is just a single denoiser call.

  • For the conditional adapter that we train later, we reuse the same loop and add classifier-free guidance.
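
Atom types follow a different reverse rule: a masked token is revealed with a probability set by the absorbing schedule. The sketch below assumes a hypothetical linear schedule indexed from 0 to T, with `alpha_bar[0] = 1` (nothing masked) and `alpha_bar[T] = 0` (everything masked). The notebook's own discrete schedule may differ.

```python
def reveal_probability(alpha_bar_t: float, alpha_bar_prev: float) -> float:
    """Chance that a still-masked token is revealed when stepping t -> t-1."""
    denom = max(1.0 - alpha_bar_t, 1e-6)
    return min(max((alpha_bar_prev - alpha_bar_t) / denom, 0.0), 1.0)


T_toy = 10
# Hypothetical linear schedule: alpha_bar[0] = 1 (clean), alpha_bar[T] = 0 (all masked).
alpha_bar = [1.0 - t / T_toy for t in range(T_toy + 1)]

still_masked = 1.0
for t in range(T_toy, 0, -1):
    p = reveal_probability(alpha_bar[t], alpha_bar[t - 1])
    still_masked *= 1.0 - p
    print(f"t={t:2d}  reveal_prob={p:.3f}  P(masked after step)={still_masked:.3f}")
```

With this schedule the reveal probability rises as the step index falls and reaches 1 at the final step, so no atom is left masked when sampling ends.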

Guidance rule for the conditional case

Once the conditional model exists, we run the denoiser twice at each step:

  • once without the condition,

  • once with the condition.

Then we combine the predictions as

$$\hat{y}_{\text{guided}} = \hat{y}_{\text{uncond}} + s \left(\hat{y}_{\text{cond}} - \hat{y}_{\text{uncond}}\right),$$

where $s$ is the guidance scale.

Interpretation

  • $s = 1$ means “just use the conditional prediction,”

  • larger $s$ pushes the sample harder toward the requested target package,

  • too large a value can produce worse or less stable structures.

Mini-exercise: in the next cell, implement that one-line guidance rule as a helper function cfg_mix(...) before you look at the full sampler.

Answer
def cfg_mix(pred_uncond, pred_cond, guidance_scale: float):
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
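
A quick numeric check of the rule on toy scalars shows what each guidance scale does. The two prediction values are made up for illustration and are not model outputs; the same arithmetic applies elementwise to tensors.

```python
def cfg_mix(pred_uncond, pred_cond, guidance_scale: float):
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)


pred_uncond, pred_cond = 0.2, 0.5  # toy scalar predictions, not model outputs
for s in (0.0, 1.0, 2.0, 4.0):
    print(f"s={s:.1f} -> guided={cfg_mix(pred_uncond, pred_cond, s):.2f}")
```

A scale of 0 returns the unconditional prediction and a scale of 1 returns the conditional one. Anything above 1 extrapolates past the conditional prediction, which is where instability can start.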

The rest of the sampling implementation is a bit long, so it stays collapsed below.

Try this later: after training the adapter, sample with guidance scales 0.0, 1.0, 2.0, and 4.0. At what point does stronger guidance stop helping?

# @title
def cfg_mix(pred_uncond, pred_cond, guidance_scale: float):
    # TODO: implement classifier-free guidance mixing.
    raise NotImplementedError

@torch.no_grad()
def guided_predictions(
    model,
    frac_t: torch.Tensor,
    lattice_t: torch.Tensor,
    atom_t: torch.Tensor,
    num_atoms: torch.Tensor,
    batch_idx: torch.Tensor,
    t_graph: torch.Tensor,
    continuous_cond: torch.Tensor | None,
    composition_cond: torch.Tensor | None,
    spacegroup_cond: torch.Tensor | None,
    guidance_scale: float,
):
    if not model.conditional or continuous_cond is None or composition_cond is None or spacegroup_cond is None:
        return run_mini_mattergen_denoiser(
            model,
            frac_t=frac_t,
            lattice_t=lattice_t,
            atom_t=atom_t,
            num_atoms=num_atoms,
            batch_idx=batch_idx,
            t_graph=t_graph,
            continuous_cond=None,
            composition_cond=None,
            spacegroup_cond=None,
            use_uncond_embedding=None,
        )

    B = num_atoms.shape[0]
    uncond_mask = torch.ones(B, 1, dtype=torch.bool, device=frac_t.device)
    cond_mask = torch.zeros(B, 1, dtype=torch.bool, device=frac_t.device)

    pred_coord_u, pred_lattice_u, pred_atom_u = run_mini_mattergen_denoiser(
        model,
        frac_t=frac_t,
        lattice_t=lattice_t,
        atom_t=atom_t,
        num_atoms=num_atoms,
        batch_idx=batch_idx,
        t_graph=t_graph,
        continuous_cond=continuous_cond,
        composition_cond=composition_cond,
        spacegroup_cond=spacegroup_cond,
        use_uncond_embedding=uncond_mask,
    )
    pred_coord_c, pred_lattice_c, pred_atom_c = run_mini_mattergen_denoiser(
        model,
        frac_t=frac_t,
        lattice_t=lattice_t,
        atom_t=atom_t,
        num_atoms=num_atoms,
        batch_idx=batch_idx,
        t_graph=t_graph,
        continuous_cond=continuous_cond,
        composition_cond=composition_cond,
        spacegroup_cond=spacegroup_cond,
        use_uncond_embedding=cond_mask,
    )

    pred_coord = cfg_mix(pred_coord_u, pred_coord_c, guidance_scale)
    pred_lattice = cfg_mix(pred_lattice_u, pred_lattice_c, guidance_scale)
    pred_atom = cfg_mix(pred_atom_u, pred_atom_c, guidance_scale)
    return pred_coord, pred_lattice, pred_atom

def ddpm_reverse_step_continuous(x_t, pred_eps, t_graph, batch_idx=None, wrap=False):
    beta_t = extract_schedule(torch.cat([torch.zeros(1), cont_betas]).to(x_t.device), t_graph, batch_idx, x_t)
    alpha_t = extract_schedule(torch.cat([torch.ones(1), cont_alphas]).to(x_t.device), t_graph, batch_idx, x_t)
    alpha_bar_t = extract_schedule(cont_alpha_bars.to(x_t.device), t_graph, batch_idx, x_t)

    noise = torch.randn_like(x_t)
    mask_t = (t_graph if batch_idx is None else t_graph[batch_idx]).view(-1, *([1] * (x_t.ndim - 1)))
    noise = torch.where(mask_t > 1, noise, torch.zeros_like(noise))
    mean = (x_t - (beta_t / torch.clamp(torch.sqrt(1.0 - alpha_bar_t), min=1e-6)) * pred_eps) / torch.clamp(torch.sqrt(alpha_t), min=1e-6)
    x_prev = mean + torch.sqrt(beta_t) * noise

    if wrap:
        x_prev = wrap_centered_torch(x_prev)
    return x_prev

def discrete_reverse_step_absorbing(atom_t, pred_atom_logits, t_graph, batch_idx):
    pred_x0 = torch.distributions.Categorical(logits=pred_atom_logits).sample()

    t_per_atom = t_graph[batch_idx]
    ab_t = disc_alpha_bars.to(atom_t.device)[t_per_atom]
    ab_prev = disc_alpha_bars.to(atom_t.device)[torch.clamp(t_per_atom - 1, min=0)]

    denom = torch.clamp(1.0 - ab_t, min=1e-6)
    reveal_prob = torch.where(
        atom_t == MASK_TOKEN,
        (ab_prev - ab_t) / denom,
        torch.zeros_like(ab_t),
    ).clamp(0.0, 1.0)

    reveal = torch.rand_like(reveal_prob) < reveal_prob
    atom_prev = atom_t.clone()
    atom_prev[reveal] = pred_x0[reveal]
    return atom_prev

def sample_num_atoms(batch_size: int):
    choices = np.array([r["num_atoms"] for r in train_data], dtype=np.int64)
    sampled = np.random.choice(choices, size=batch_size, replace=True)
    return torch.tensor(sampled, dtype=torch.long, device=device)

def normalize_composition_vector(vec: np.ndarray) -> np.ndarray:
    vec = np.asarray(vec, dtype=np.float32).reshape(-1)
    vec = np.clip(vec, 0.0, None)
    total = float(vec.sum())
    if total <= 1e-8:
        return composition_mean.copy()
    return (vec / total).astype(np.float32)

def composition_vector_from_formula(formula: str) -> np.ndarray:
    comp = Composition(formula)
    vec = np.zeros(NUM_ATOM_CLASSES, dtype=np.float32)
    total = 0.0
    for el, amt in comp.get_el_amt_dict().items():
        z = Element(el).Z
        if int(z) in z_to_token:
            vec[z_to_token[int(z)]] = float(amt)
            total += float(amt)
    if total <= 1e-8:
        raise ValueError(
            f"Formula {formula!r} does not overlap with the current tiny dataset vocabulary. "
            "Choose a formula using elements seen in the curated dataset."
        )
    return normalize_composition_vector(vec)

def composition_vector_to_formula_hint(vec: np.ndarray, top_k: int = 6) -> str:
    vec = np.asarray(vec, dtype=np.float32)
    idx = np.argsort(-vec)[:top_k]
    parts = []
    for i in idx:
        if vec[int(i)] <= 1e-4:
            continue
        sym = Element.from_Z(token_to_z[int(i)]).symbol
        parts.append(f"{sym}:{float(vec[int(i)]):.2f}")
    return ", ".join(parts) if parts else "mean composition"

def resolve_spacegroup_number(target_spacegroup) -> int:
    if target_spacegroup is None:
        return int(spacegroup_mode)
    if isinstance(target_spacegroup, str):
        text = target_spacegroup.strip()
        if text.isdigit():
            return int(np.clip(int(text), 0, MAX_SPACEGROUP_NUMBER))
        if text in SPACEGROUP_SYMBOL_TO_NUM:
            return int(SPACEGROUP_SYMBOL_TO_NUM[text])
        raise ValueError(
            f"Unknown space-group symbol {target_spacegroup!r} for this tiny dataset. "
            "Use an integer 1..230 or a symbol already present in the curated records."
        )
    return int(np.clip(int(target_spacegroup), 0, MAX_SPACEGROUP_NUMBER))

def build_condition_tensors(
    batch_size: int,
    target_density: float | None = None,
    target_band_gap: float | None = None,
    target_formula: str | None = None,
    target_composition: np.ndarray | list | tuple | None = None,
    target_spacegroup: int | str | None = None,
    target_record: dict | None = None,
    target_conditions: dict | list | tuple | np.ndarray | None = None,
):
    if (
        target_conditions is None
        and target_density is None
        and target_band_gap is None
        and target_formula is None
        and target_composition is None
        and target_spacegroup is None
        and target_record is None
    ):
        return None

    continuous_values = cond_mean.astype(np.float32).copy()
    composition_values = composition_mean.astype(np.float32).copy()
    spacegroup_value = int(spacegroup_mode)

    if target_record is not None:
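        # Fall back to the current default for each property, looked up by name rather than by position.
        continuous_values[CONDITION_INDEX["density"]] = float(
            target_record.get("density", continuous_values[CONDITION_INDEX["density"]])
        )
        continuous_values[CONDITION_INDEX["band_gap"]] = float(
            target_record.get("band_gap", continuous_values[CONDITION_INDEX["band_gap"]])
        )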
        if "composition_cond" in target_record:
            composition_values = normalize_composition_vector(target_record["composition_cond"])
        else:
            composition_values = composition_fraction_vector_from_atomic_numbers(target_record["atomic_numbers"])
        spacegroup_value = int(target_record.get("spacegroup_number", spacegroup_value))

    if isinstance(target_conditions, dict):
        for name, value in target_conditions.items():
            if name in CONDITION_INDEX and value is not None:
                continuous_values[CONDITION_INDEX[name]] = float(value)
    elif target_conditions is not None:
        arr = np.asarray(target_conditions, dtype=np.float32).reshape(-1)
        if len(arr) != len(CONDITION_NAMES):
            raise ValueError(f"target_conditions must have length {len(CONDITION_NAMES)}")
        continuous_values[:] = arr

    if target_density is not None:
        continuous_values[CONDITION_INDEX["density"]] = float(target_density)
    if target_band_gap is not None:
        continuous_values[CONDITION_INDEX["band_gap"]] = float(target_band_gap)

    if target_formula is not None:
        composition_values = composition_vector_from_formula(target_formula)
    if target_composition is not None:
        composition_values = normalize_composition_vector(np.asarray(target_composition, dtype=np.float32))

    if target_spacegroup is not None:
        spacegroup_value = resolve_spacegroup_number(target_spacegroup)

    continuous_norm = (continuous_values - cond_mean) / cond_std

    return {
        "continuous": torch.tensor(
            np.repeat(continuous_norm[None, :], batch_size, axis=0),
            dtype=torch.float32,
            device=device,
        ),
        "composition": torch.tensor(
            np.repeat(composition_values[None, :], batch_size, axis=0),
            dtype=torch.float32,
            device=device,
        ),
        "spacegroup": torch.full(
            (batch_size,),
            int(np.clip(spacegroup_value, 0, MAX_SPACEGROUP_NUMBER)),
            dtype=torch.long,
            device=device,
        ),
    }

def pretty_condition_target(cond_tensors: dict | None):
    if cond_tensors is None:
        return {"mode": "unconditional"}

    scalar_vals = cond_tensors["continuous"][0].detach().cpu().numpy() * cond_std + cond_mean
    comp_vals = cond_tensors["composition"][0].detach().cpu().numpy()
    sg_num = int(cond_tensors["spacegroup"][0].detach().cpu().item())
    return {
        "density": float(scalar_vals[CONDITION_INDEX["density"]]),
        "band_gap": float(scalar_vals[CONDITION_INDEX["band_gap"]]),
        "composition_hint": composition_vector_to_formula_hint(comp_vals),
        "spacegroup_number": sg_num,
        "spacegroup_symbol": SPACEGROUP_NUM_TO_SYMBOL.get(sg_num, "unknown"),
    }

@torch.no_grad()
def sample_crystals(
    model,
    batch_size: int = 8,
    target_density: float | None = None,
    target_band_gap: float | None = None,
    target_formula: str | None = None,
    target_composition: np.ndarray | list | tuple | None = None,
    target_spacegroup: int | str | None = None,
    target_record: dict | None = None,
    target_conditions: dict | list | tuple | np.ndarray | None = None,
    guidance_scale: float = 2.0,
    num_atoms: torch.Tensor | None = None,
    record: bool = False,
):
    model.eval()

    if num_atoms is None:
        num_atoms = sample_num_atoms(batch_size)
    else:
        num_atoms = num_atoms.to(device)

    batch_idx = torch.repeat_interleave(torch.arange(batch_size, device=device), num_atoms)
    n_total = int(num_atoms.sum().item())

    frac_t = wrap_centered_torch(torch.randn(n_total, 3, device=device))
    lattice_t = torch.randn(batch_size, 6, device=device)
    atom_t = torch.full((n_total,), MASK_TOKEN, dtype=torch.long, device=device)

    cond_tensors = build_condition_tensors(
        batch_size=batch_size,
        target_density=target_density,
        target_band_gap=target_band_gap,
        target_formula=target_formula,
        target_composition=target_composition,
        target_spacegroup=target_spacegroup,
        target_record=target_record,
        target_conditions=target_conditions,
    )

    records = []
    checkpoints = sorted(set(np.linspace(T, 1, 10).round().astype(int).tolist()), reverse=True)
    for step in range(T, 0, -1):
        t_graph = torch.full((batch_size,), step, dtype=torch.long, device=device)

        pred_coord_eps, pred_lattice_eps, pred_atom_logits = guided_predictions(
            model=model,
            frac_t=frac_t,
            lattice_t=lattice_t,
            atom_t=atom_t,
            num_atoms=num_atoms,
            batch_idx=batch_idx,
            t_graph=t_graph,
            continuous_cond=None if cond_tensors is None else cond_tensors["continuous"],
            composition_cond=None if cond_tensors is None else cond_tensors["composition"],
            spacegroup_cond=None if cond_tensors is None else cond_tensors["spacegroup"],
            guidance_scale=guidance_scale,
        )

        frac_t = ddpm_reverse_step_continuous(frac_t, pred_coord_eps, t_graph, batch_idx=batch_idx, wrap=True)
        lattice_t = ddpm_reverse_step_continuous(lattice_t, pred_lattice_eps, t_graph, batch_idx=None, wrap=False)
        lattice_t = clamp_lattice_features_norm_torch(lattice_t)
        atom_t = discrete_reverse_step_absorbing(atom_t, pred_atom_logits, t_graph, batch_idx)

        if record and (step in checkpoints or step == 1):
            records.append(
                {
                    "step": step,
                    "frac_t": frac_t.detach().cpu().clone(),
                    "lattice_t": lattice_t.detach().cpu().clone(),
                    "atom_t": atom_t.detach().cpu().clone(),
                    "num_atoms": num_atoms.detach().cpu().clone(),
                    "batch_idx": batch_idx.detach().cpu().clone(),
                    "target_conditions": pretty_condition_target(cond_tensors),
                }
            )

    return frac_t.cpu(), lattice_t.cpu(), atom_t.cpu(), num_atoms.cpu(), records

def decode_sampled_structures(frac_t, lattice_t, atom_t, num_atoms):
    structures = []
    offset = 0
    for i, n in enumerate(num_atoms.tolist()):
        frac_i = frac_t[offset : offset + n].numpy()
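        atom_i = atom_t[offset : offset + n].numpy().copy()  # copy: the fallback below must not write into atom_t
        offset += n

        valid = atom_i != MASK_TOKEN
        if valid.sum() == 0:
            # Every site is still masked: keep one atom of class 0 so the decode does not fail.
            valid[0] = True
            atom_i[0] = 0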

        structure = x0_to_structure(frac_i[valid], lattice_t[i].numpy(), atom_i[valid])
        structures.append(structure)
    return structures

10) Unconditional generation

Let’s sample from the base model before we teach the network any target properties at all.

This is the cleanest test of whether the unconditional prior has learned a plausible crystal distribution.

Lightweight validity checks

Every generated-structure summary in this notebook includes a small geometric sanity screen:

  • finite positive cell volume,

  • density in a broad but reasonable range,

  • minimum pair distance above 0.6 Å,

  • lattice lengths and angles inside a broad physical window.

These are fast demonstrative checks, not a substitute for relaxation, symmetry cleanup, oxidation-state analysis, or DFT validation. They are still useful because they catch the most obvious broken decodes before you inspect galleries or save CIF files.
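To make the screen above concrete, here is a minimal standard-library sketch of three of the four checks (the density check is left out to keep it short). It is not the notebook's `validity_report` or `summarize_structures`: the function names are hypothetical, only the 0.6 Å threshold comes from the text, and the length and angle windows are assumed values chosen for illustration.

```python
import itertools
import math

MIN_PAIR_DISTANCE = 0.6  # angstrom; the threshold quoted in the text above


def lattice_vectors(a, b, c, alpha, beta, gamma):
    """Cartesian lattice vectors from lengths (angstrom) and angles (degrees)."""
    al, be, ga = (math.radians(x) for x in (alpha, beta, gamma))
    cx = c * math.cos(be)
    cy = c * (math.cos(al) - math.cos(be) * math.cos(ga)) / math.sin(ga)
    cz_sq = c * c - cx * cx - cy * cy
    if cz_sq <= 0.0:
        raise ValueError("angles do not describe a three-dimensional cell")
    return [
        (a, 0.0, 0.0),
        (b * math.cos(ga), b * math.sin(ga), 0.0),
        (cx, cy, math.sqrt(cz_sq)),
    ]


def cell_volume(vectors):
    """The three rows form a lower-triangular matrix, so the volume is the diagonal product."""
    return vectors[0][0] * vectors[1][1] * vectors[2][2]


def min_pair_distance(frac_coords, vectors):
    """Smallest distance between distinct sites, scanning the 27 nearest periodic images."""
    best = math.inf
    shifts = list(itertools.product((-1, 0, 1), repeat=3))
    for fi, fj in itertools.combinations(frac_coords, 2):
        # Wrap the fractional difference to the nearest image first.
        base = [fj[k] - fi[k] - round(fj[k] - fi[k]) for k in range(3)]
        for shift in shifts:
            d = [base[k] + shift[k] for k in range(3)]
            cart = [sum(d[k] * vectors[k][axis] for k in range(3)) for axis in range(3)]
            best = min(best, math.sqrt(sum(x * x for x in cart)))
    return best


def lightweight_checks(lengths, angles, frac_coords):
    checks = {
        "lengths_ok": all(1.0 <= x <= 50.0 for x in lengths),  # assumed window
        "angles_ok": all(30.0 <= x <= 150.0 for x in angles),  # assumed window
    }
    try:
        vectors = lattice_vectors(*lengths, *angles)
        volume = cell_volume(vectors)
        checks["volume_ok"] = math.isfinite(volume) and volume > 0.0
        checks["distance_ok"] = min_pair_distance(frac_coords, vectors) > MIN_PAIR_DISTANCE
    except (ValueError, ZeroDivisionError):
        checks["volume_ok"] = False
        checks["distance_ok"] = False
    checks["lightweight_valid"] = all(checks.values())
    return checks


good = lightweight_checks((4.0, 4.0, 4.0), (90.0, 90.0, 90.0), [(0, 0, 0), (0.5, 0.5, 0.5)])
bad = lightweight_checks((4.0, 4.0, 4.0), (90.0, 90.0, 90.0), [(0, 0, 0), (0.95, 0, 0)])
print(good)
print(bad)
```

The second example fails only the distance check: the two sites are 0.05 apart in fractional units across the periodic boundary, which is 0.2 Å in a 4 Å cell. Scanning only the 27 nearest images is an approximation that can miss the true minimum in strongly skewed cells.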

The first cell generates and summarizes samples. The second one shows a few of them as static ASE previews.

frac_u, lattice_u, atom_u, num_u, record_u = sample_crystals(
    base_model,
    batch_size=6,
    guidance_scale=1.0,
    record=True,
)
uncond_structures = decode_sampled_structures(frac_u, lattice_u, atom_u, num_u)
uncond_summary = summarize_structures(uncond_structures)
uncond_validity = validity_report(uncond_summary, "unconditional")
display(uncond_summary)
display(uncond_validity)
if (~uncond_summary["lightweight_valid"]).any():
    display(uncond_summary.loc[~uncond_summary["lightweight_valid"], ["sample_id", "formula", "min_pair_distance", "failure_reason"]])
display(
    show_structures(
        uncond_structures[:4],
        [f'{s.composition.reduced_formula} | density={safe_structure_density(s):.2f}' for s in uncond_structures[:4]],
    )
)

What to look for

  • Are the compositions and densities at least vaguely reasonable?

  • Do different samples look visibly different?

  • Are the unit cells non-degenerate?

  • Do the samples look like they came from one broad crystal family rather than collapsing to a single motif?

If not, the usual fixes are:

  • a slightly larger or cleaner dataset,

  • more training,

  • a better-matched distribution over num_atoms,

  • or a slightly larger denoiser.

Once the unconditional prior looks usable, we can ask the next question: can we steer it without destroying plausibility?

11) Add a conditional adapter for density, band gap, composition, and space group

Now that we have seen what the base model can generate on its own, we can move closer to the MatterGen workflow:

  1. start from the unconditional base model,

  2. copy its weights into a conditional model,

  3. fine-tune only the condition-specific pieces and output heads.

Conditioning package

We use four targets, but they enter through three encoders:

  • scalar properties:

    $\mathbf{c}_{\text{scalar}} = [\text{density}, \text{band gap}]$
  • composition fractions over the dataset vocabulary:

    $\mathbf{c}_{\text{comp}} \in \mathbb{R}^{K}, \qquad \sum_i c_i = 1$
  • a discrete space-group label:

    $c_{\text{sg}} \in \{1, \dots, 230\}$

Classifier-free dropout during training

Sometimes we hide the whole condition package during training.
That lets us use classifier-free guidance at sampling time later.

If $p_{\text{uncond}}$ is the dropout probability, the model learns both:

  • a conditional denoiser,

  • and an unconditional fallback denoiser.
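As a minimal sketch of the dropout idea above, the function below hides each sample's whole condition package with probability `p_uncond`. The function name, the per-sample (rather than per-batch) dropping, the value 0.1, and the use of `None` as the "no condition" marker are all assumptions for illustration. The notebook's training loop is not shown in this section and may handle this differently, for example by swapping in a learned unconditional embedding.

```python
import random


def drop_condition_packages(conditions, p_uncond, rng):
    """Hide each sample's whole condition package with probability p_uncond.

    A dropped package is returned as None, standing in for whatever
    "no condition" input the real model uses.
    """
    if not 0.0 <= p_uncond <= 1.0:
        raise ValueError("p_uncond must be a probability")
    return [None if rng.random() < p_uncond else cond for cond in conditions]


rng = random.Random(0)
batch = [
    {"density": 3.2, "band_gap": 1.1, "spacegroup": 225},
    {"density": 5.9, "band_gap": 0.0, "spacegroup": 62},
] * 500  # 1000 illustrative condition packages

dropped = drop_condition_packages(batch, p_uncond=0.1, rng=rng)
hidden_fraction = sum(cond is None for cond in dropped) / len(dropped)
print(f"hidden fraction: {hidden_fraction:.3f}")
```

Because the same network sees both kinds of batch entries, it can later be queried with and without conditions at sampling time, which is what classifier-free guidance needs.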

Important interpretation note

Composition conditioning here is soft compared with a full implementation: it encourages the generated atom distribution toward a requested composition vector, but it does not guarantee an exact formula.

The code below builds the conditional model, matches its width and depth to the base model so weight transfer is meaningful, and then freezes most parameters so the fine-tune behaves like a lightweight adapter stage.

cond_model = build_mini_mattergen_denoiser(
    NUM_ATOM_CLASSES,
    hidden_dim=HIDDEN_DIM,
    num_blocks=NUM_BLOCKS,
    conditional=True,
).to(device)

# Load base weights wherever the parameter shapes match.
base_state = base_model.state_dict()
cond_state = cond_model.state_dict()
for k, v in base_state.items():
    if k in cond_state and cond_state[k].shape == v.shape:
        cond_state[k] = v
cond_model.load_state_dict(cond_state)

# Optional: freeze everything except the conditional pieces and the output heads.
for name, param in cond_model.named_parameters():
    trainable = (
        ("cond_scalar_embed" in name)
        or ("cond_comp_embed" in name)
        or ("cond_spacegroup_embed" in name)
        or ("cond_fuse" in name)
        or ("uncond_embedding" in name)
        or ("adapter" in name)
        or ("coord_head" in name)
        or ("atom_head" in name)
        or ("lattice_head" in name)
    )
    param.requires_grad = trainable

num_trainable = sum(p.numel() for p in cond_model.parameters() if p.requires_grad)
num_total = sum(p.numel() for p in cond_model.parameters())
print(f"trainable parameters: {num_trainable:,} / {num_total:,}")

Task for you

  • Change only one target at a time and write down which diagnostic actually moves.

  • Separate conditioning from screening in your notes: which quantities are requested before sampling, and which are only measured afterward?

Then we train the adapter. This phase is usually faster than training from scratch because most of the geometric representation has already been learned.

ADAPTER_MAX_EPOCHS = int(globals().get("ADAPTER_MAX_EPOCHS", 64))
ADAPTER_PATIENCE = int(globals().get("ADAPTER_PATIENCE", 12))

adapter_history, adapter_run = train_model(
    cond_model,
    train_loader,
    val_loader,
    epochs=ADAPTER_MAX_EPOCHS,
    lr=1e-4,
    run_name="scalar_composition_spacegroup_adapter",
    patience=ADAPTER_PATIENCE,
    ema_decay=0.95,
)
print(adapter_run)
adapter_history.tail()
plt.figure(figsize=(8, 4))
plt.plot(adapter_history["epoch"], adapter_history["train_loss"], alpha=0.35, label="train")
plt.plot(adapter_history["epoch"], adapter_history["train_loss_ema"], label="train (EMA)")
plt.plot(adapter_history["epoch"], adapter_history["val_loss"], label="val")
plt.axvline(adapter_run["best_epoch"], linestyle="--", alpha=0.6, label=f'best epoch={adapter_run["best_epoch"]}')
plt.xlabel("epoch")
plt.ylabel("loss")
plt.title("Conditional adapter learning curve")
plt.legend()
plt.show()
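
The learning-curve cell above plots a smoothed training loss and marks a best epoch. Here is a minimal sketch of how such quantities could be computed, assuming an exponential moving average of the per-epoch loss and a patience counted in epochs without validation improvement. Both function names are hypothetical, and `train_model` itself is not shown here, so its `ema_decay` and `patience` arguments may behave differently.

```python
import math


def ema_smooth(values, decay=0.95):
    """Exponential moving average, seeded with the first value."""
    smoothed, running = [], None
    for value in values:
        running = value if running is None else decay * running + (1.0 - decay) * value
        smoothed.append(running)
    return smoothed


def early_stopping_epochs(val_losses, patience):
    """Return (best_epoch, stop_epoch), both 1-indexed.

    Training stops once `patience` epochs have passed without a new best
    validation loss, or when the loss history runs out.
    """
    best_loss, best_epoch = math.inf, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_losses)


val_curve = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]  # made-up numbers
print(ema_smooth(val_curve, decay=0.5))
print(early_stopping_epochs(val_curve, patience=3))
```

A larger decay gives a smoother but more delayed curve, which is worth remembering when you compare the EMA line against the raw training loss.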

Try this: unfreeze the whole network instead of only the adapter-related pieces.
Try this too: keep scalar conditioning on, but ablate either the composition encoder or the space-group embedding.
Question: which targets seem easiest for the model to learn, and which ones most strongly affect sample plausibility?

Answer

Scalar targets such as density are often the easiest to learn because they are smooth and low-dimensional. Composition and space group usually affect plausibility more strongly: they constrain chemistry and symmetry more directly, so mistakes there can make a structure look unrealistic even when the scalar targets are roughly correct.

12) Conditional generation with scalar, composition, and space-group targets

Now we ask the adapter model to steer the unconditional prior toward different crystal-property regimes.

Important interpretation note

  • For density, we can directly summarize the generated structures after decoding.

  • For band gap, this notebook uses the MP band-gap labels during training but does not recompute the generated band gap afterward.

  • For composition, the target is a soft composition vector. The model is encouraged toward that stoichiometric mix, but exact formulas are not guaranteed.

  • For space group, we can run a symmetry analyzer on decoded structures, but for noisy/generated outputs this should be read as a rough diagnostic rather than a guaranteed exact match.

So the conditional examples are best interpreted as targeted generation demos, not as strict constrained generation.
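
Density is the one scalar target you can check directly, because it follows from the decoded cell and its contents alone: total mass divided by cell volume. A minimal sketch of that arithmetic, using a tiny hand-written mass table and hypothetical function names (the notebook's own `safe_structure_density` is not shown here and presumably relies on pymatgen instead):

```python
import math

# 1 amu = 1.66053907e-24 g and 1 cubic angstrom = 1e-24 cm^3,
# so 1 amu per cubic angstrom is 1.66053907 g/cm^3.
AMU_PER_A3_TO_G_PER_CM3 = 1.66053907

# Standard atomic weights for a handful of elements (illustrative subset only).
ATOMIC_MASS = {"Na": 22.98977, "Cl": 35.45, "O": 15.999, "Ti": 47.867, "Sr": 87.62}


def cell_volume(a, b, c, alpha, beta, gamma):
    """Unit-cell volume in cubic angstrom from lengths (angstrom) and angles (degrees)."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    return a * b * c * math.sqrt(1.0 - ca * ca - cb * cb - cg * cg + 2.0 * ca * cb * cg)


def density_g_per_cm3(site_elements, lattice):
    """Density from the list of element symbols in the cell and six lattice parameters."""
    mass_amu = sum(ATOMIC_MASS[element] for element in site_elements)
    return mass_amu * AMU_PER_A3_TO_G_PER_CM3 / cell_volume(*lattice)


# Rock-salt NaCl: 4 Na + 4 Cl in a cubic cell with a = 5.64 angstrom.
nacl_sites = ["Na"] * 4 + ["Cl"] * 4
nacl_density = density_g_per_cm3(nacl_sites, (5.64, 5.64, 5.64, 90.0, 90.0, 90.0))
print(f"NaCl density: {nacl_density:.2f} g/cm^3")
```

This is why density can be compared against its target after decoding, while band gap cannot: no comparably cheap formula maps a decoded structure to a band gap.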

Try this: change only one target at a time. For example, fix composition and space group but sweep density. That makes it easier to see which condition the model is actually responding to.

First we pick low/high targets from the training-set property distributions.

train_density_real = np.array([r["density"] for r in train_data], dtype=np.float32)
train_gap_real = np.array([r["band_gap"] for r in train_data], dtype=np.float32)

low_density_target = float(np.quantile(train_density_real, 0.2))
high_density_target = float(np.quantile(train_density_real, 0.8))
low_gap_target = float(np.quantile(train_gap_real, 0.2))
high_gap_target = float(np.quantile(train_gap_real, 0.8))

print("low target density:", low_density_target)
print("high target density:", high_density_target)
print("low target band gap:", low_gap_target)
print("high target band gap:", high_gap_target)

Then we sample four groups of structures: low-density, high-density, low-band-gap, and high-band-gap. After that, we will add a composition + space-group example anchored to one training crystal.

frac_low_d, lattice_low_d, atom_low_d, num_low_d, _ = sample_crystals(
    cond_model,
    batch_size=6,
    target_density=low_density_target,
    guidance_scale=2.0,
)
low_density_structures = decode_sampled_structures(frac_low_d, lattice_low_d, atom_low_d, num_low_d)
low_density_summary = summarize_structures(low_density_structures)

frac_high_d, lattice_high_d, atom_high_d, num_high_d, _ = sample_crystals(
    cond_model,
    batch_size=6,
    target_density=high_density_target,
    guidance_scale=2.0,
)
high_density_structures = decode_sampled_structures(frac_high_d, lattice_high_d, atom_high_d, num_high_d)
high_density_summary = summarize_structures(high_density_structures)

frac_low_g, lattice_low_g, atom_low_g, num_low_g, _ = sample_crystals(
    cond_model,
    batch_size=6,
    target_band_gap=low_gap_target,
    guidance_scale=2.0,
)
low_gap_structures = decode_sampled_structures(frac_low_g, lattice_low_g, atom_low_g, num_low_g)
low_gap_summary = summarize_structures(low_gap_structures)

frac_high_g, lattice_high_g, atom_high_g, num_high_g, _ = sample_crystals(
    cond_model,
    batch_size=6,
    target_band_gap=high_gap_target,
    guidance_scale=2.0,
)
high_gap_structures = decode_sampled_structures(frac_high_g, lattice_high_g, atom_high_g, num_high_g)
high_gap_summary = summarize_structures(high_gap_structures)

conditional_validity = pd.concat([
    validity_report(low_density_summary, "low density target"),
    validity_report(high_density_summary, "high density target"),
    validity_report(low_gap_summary, "low band-gap target"),
    validity_report(high_gap_summary, "high band-gap target"),
], ignore_index=True)

This plot helps you compare the targets against the empirical training distribution, and (where possible) the generated samples.

fig, axes = plt.subplots(1, 2, figsize=(10, 3.8))

axes[0].hist(train_density_real, bins=24, alpha=0.35, label="train densities")
axes[0].axvline(low_density_target, linestyle="--", label="low density target")
axes[0].axvline(high_density_target, linestyle="--", label="high density target")
axes[0].scatter(low_density_summary["density"], np.full(len(low_density_summary), -0.02), marker="x", label="generated low density")
axes[0].scatter(high_density_summary["density"], np.full(len(high_density_summary), -0.05), marker="x", label="generated high density")
axes[0].set_xlabel("density (g/cm³)")
axes[0].set_title("Density-conditioned sampling")
axes[0].legend(fontsize=8)

axes[1].hist(train_gap_real, bins=24, alpha=0.35, label="train band gaps")
axes[1].axvline(low_gap_target, linestyle="--", label="low band-gap target")
axes[1].axvline(high_gap_target, linestyle="--", label="high band-gap target")
axes[1].set_xlabel("band gap (eV)")
axes[1].set_title("Band-gap targets used for conditioning")
axes[1].legend(fontsize=8)

plt.tight_layout()
plt.show()

print("Low-density conditioned samples")
display(low_density_summary)
print("High-density conditioned samples")
display(high_density_summary)

print("Low-band-gap conditioned samples")
display(low_gap_summary)
print("High-band-gap conditioned samples")
display(high_gap_summary)
print("Lightweight validity report across conditional sweeps")
display(conditional_validity)

Finally, inspect a few generated structures from each scalar-property regime.

print("Low-density conditioned examples")
display(
    show_structures(
        low_density_structures[:4],
        [f'{s.composition.reduced_formula} | density={safe_structure_density(s):.2f}' for s in low_density_structures[:4]],
    )
)

print("High-density conditioned examples")
display(
    show_structures(
        high_density_structures[:4],
        [f'{s.composition.reduced_formula} | density={safe_structure_density(s):.2f}' for s in high_density_structures[:4]],
    )
)

print("Low-band-gap conditioned examples")
display(
    show_structures(
        low_gap_structures[:4],
        [f'{s.composition.reduced_formula} | target gap={low_gap_target:.2f} eV' for s in low_gap_structures[:4]],
    )
)

print("High-band-gap conditioned examples")
display(
    show_structures(
        high_gap_structures[:4],
        [f'{s.composition.reduced_formula} | target gap={high_gap_target:.2f} eV' for s in high_gap_structures[:4]],
    )
)

Composition + space-group conditioning demo

A simple way to demonstrate the richer conditioning is to borrow a target composition and space group from one real crystal.

In the next cell we choose one anchor crystal from the training set and ask the model to sample new structures with:

  • the same target composition,

  • the same target space-group label,

  • the same number of atoms,

  • and roughly the same density.

This is not strict constrained generation, but it is a good teaching experiment because you can compare “what was requested” versus “what came out”.
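
An exact formula match is a harsh yes/no test for a soft target. As a gentler way to compare "what was requested" with "what came out", here is a minimal sketch that turns formulas into element fractions and measures how far apart they are. The parser and both function names are hypothetical, and the parser only handles flat formulas such as "SrTiO3" (no parentheses or charges). In the notebook itself pymatgen's `Composition` would be the natural tool.

```python
import re

_TOKEN = re.compile(r"([A-Z][a-z]?)(\d+(?:\.\d+)?)?")
_FORMULA = re.compile(r"(?:[A-Z][a-z]?(?:\d+(?:\.\d+)?)?)+")


def composition_fractions(formula):
    """Map a flat formula like 'SrTiO3' to atomic fractions that sum to 1."""
    if not _FORMULA.fullmatch(formula):
        raise ValueError(f"unsupported formula: {formula!r}")
    counts = {}
    for element, amount in _TOKEN.findall(formula):
        counts[element] = counts.get(element, 0.0) + (float(amount) if amount else 1.0)
    total = sum(counts.values())
    return {element: count / total for element, count in counts.items()}


def composition_distance(target, generated):
    """Total-variation distance: 0 for identical fractions, 1 for no shared elements."""
    p = composition_fractions(target)
    q = composition_fractions(generated)
    return 0.5 * sum(abs(p.get(e, 0.0) - q.get(e, 0.0)) for e in set(p) | set(q))


for candidate in ["Sr2Ti2O6", "BaTiO3", "SrTiO2", "NaCl"]:
    print(candidate, round(composition_distance("SrTiO3", candidate), 3))
```

A distance near zero with a failed exact match usually means the model found the right chemistry but missed the stoichiometry by an atom or two, which is a much milder failure than producing unrelated elements.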

Prediction before you run it: rank these targets from easiest to hardest for this notebook to satisfy: composition, rough density, exact space group.
Question: would you expect the exact space-group target to work reliably here, or would you expect many decoded samples to fall back to P1? Why?

anchor_record = train_data[0]
print("Anchor record")
print({
    "formula": anchor_record["formula"],
    "density": float(anchor_record["density"]),
    "band_gap": float(anchor_record["band_gap"]),
    "spacegroup_number": int(anchor_record["spacegroup_number"]),
    "spacegroup_symbol": anchor_record["spacegroup_symbol"],
    "num_atoms": int(anchor_record["num_atoms"]),
})

anchor_formula = Composition(anchor_record["formula"]).reduced_formula
anchor_num_atoms = torch.full((6,), int(anchor_record["num_atoms"]), dtype=torch.long)

frac_comp_sg, lattice_comp_sg, atom_comp_sg, num_comp_sg, _ = sample_crystals(
    cond_model,
    batch_size=6,
    target_formula=anchor_record["formula"],
    target_spacegroup=int(anchor_record["spacegroup_number"]),
    target_density=float(anchor_record["density"]),
    guidance_scale=3.0,
    num_atoms=anchor_num_atoms,
)

comp_sg_structures = decode_sampled_structures(frac_comp_sg, lattice_comp_sg, atom_comp_sg, num_comp_sg)
comp_sg_summary = summarize_structures(comp_sg_structures).assign(
    formula_match=lambda df: df["formula"].eq(anchor_formula),
    spacegroup_match=lambda df: df["spacegroup_number"].eq(int(anchor_record["spacegroup_number"])),
)
comp_sg_validity = validity_report(comp_sg_summary, "composition + space-group target")
display(comp_sg_summary)
display(comp_sg_validity)
print("Formula-match fraction:", float(comp_sg_summary["formula_match"].mean()))
print("Space-group-match fraction:", float(comp_sg_summary["spacegroup_match"].mean()))

Things to notice:

  • Do the generated formulas look related to the target composition?

  • Does the inferred space group sometimes match the requested one?

  • If many decoded structures still end up as P1, is that surprising, or is it what you should expect from a tiny label-conditioned teaching model?

  • What happens if you keep the composition fixed but change the target density or guidance scale?

A fun follow-up is to swap target_formula=anchor_record["formula"] for a hand-written formula such as "SrTiO3" or "BaTiO3", provided those elements appear in your tiny curated dataset.

Instructor note

In this notebook, exact space-group control is usually the hardest target. Composition and rough density can often move in the requested direction, but symmetry is much more fragile: the model only sees a soft space-group label, the dataset is tiny, the decoder is not enforcing crystallographic symmetry, and post hoc symmetry finding is sensitive to even small distortions. So a large fraction of decoded samples being labeled P1 is not especially surprising here.

display(
    show_structures(
        [anchor_record["structure"]] + comp_sg_structures[:3],
        labels=[
            f'target anchor | {anchor_record["formula"]} | sg={anchor_record["spacegroup_symbol"]}',
            *[
                f'{s.composition.reduced_formula} | sg={safe_spacegroup_info(s)[1]}'
                for s in comp_sg_structures[:3]
            ],
        ],
        columns=2,
    )
)

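Try this: rerun the samplers with a different guidance_scale, for example 1.0 instead of the 2.0 used in the scalar sweeps or the 3.0 used in the composition + space-group demo.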
Try this too: keep target_formula fixed and swap only the target space group, or vice versa.
Question: what tradeoff do you observe between stronger conditioning and structural plausibility?

Answer

As guidance increases, the model usually follows the requested condition more strongly, but diversity drops first and geometry can start to degrade if the guidance becomes too aggressive. The sweet spot is the smallest scale that visibly moves the samples toward the target without making them collapse or distort.
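
The tradeoff is easier to reason about with the guidance arithmetic written out. The sketch below uses the common classifier-free guidance convention, in which the guided prediction starts from the unconditional one and moves along the difference to the conditional one. This is an assumption: the internals of `guided_predictions` are not shown in this section and may combine the two predictions differently.

```python
def classifier_free_guidance(pred_uncond, pred_cond, guidance_scale):
    """Blend unconditional and conditional predictions element by element.

    scale 0 returns the unconditional prediction, scale 1 the conditional
    one, and scale > 1 extrapolates past the conditional prediction.
    """
    if len(pred_uncond) != len(pred_cond):
        raise ValueError("predictions must have the same length")
    return [u + guidance_scale * (c - u) for u, c in zip(pred_uncond, pred_cond)]


pred_uncond = [0.10, -0.20, 0.00]  # made-up noise predictions
pred_cond = [0.30, -0.10, 0.40]

for scale in (0.0, 1.0, 2.0, 3.0):
    print(scale, [round(x, 2) for x in classifier_free_guidance(pred_uncond, pred_cond, scale)])
```

Under this convention, scales above 1.0 push each reverse step into regions the model never saw during training, which is one plausible reason geometry starts to degrade when the scale gets large.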

13) Show the diffusion trajectory as strips and animated GIFs

We now render each trajectory in two complementary ways:

  • a single long strip of selected timesteps for side-by-side comparison,

  • an animated GIF with more intermediate frames so the motion is easier to follow.

Why keep both?

  • the strip is stable in notebooks and exported documents,

  • the strip makes it easier to ask “what changed between these two times?”,

  • the GIF preserves more of the gradual trajectory.

We will show two trajectories:

  1. a forward trajectory for a real crystal,

  2. a reverse trajectory starting from a maximally corrupted version of that crystal.

Try this: choose a different anchor crystal for the trajectory strip and GIF. Does the reverse process look smoother for some structures than for others?

The helper code that decodes intermediate states and draws the strip is tucked away below. Open it later if you want to understand exactly how the trajectory images are constructed.

# @title
from pathlib import Path
from PIL import Image as PILImage


def decode_state_to_structure(frac_centered: np.ndarray, lattice_norm: np.ndarray, atom_tokens: np.ndarray) -> Structure:
    atom_tokens = np.asarray(atom_tokens, dtype=np.int64)
    frac_centered = np.asarray(frac_centered, dtype=np.float32)
    valid = atom_tokens != MASK_TOKEN

    if valid.sum() == 0:
        valid = np.zeros_like(atom_tokens, dtype=bool)
        valid[0] = True
        atom_tokens = atom_tokens.copy()
        atom_tokens[0] = 0

    return x0_to_structure(frac_centered[valid], lattice_norm, atom_tokens[valid])


def select_evenly_spaced_indices(num_items: int, num_frames: int):
    if num_items <= 0:
        raise ValueError('Need at least one trajectory state to visualize.')
    count = min(int(num_frames), int(num_items))
    return np.unique(np.linspace(0, num_items - 1, num=count, dtype=int)).tolist()


def select_trajectory_snapshots(structures, titles, num_frames: int):
    indices = select_evenly_spaced_indices(len(structures), num_frames)
    return [structures[i] for i in indices], [titles[i] for i in indices]


def render_structure_panel(structure, title, rotation: str = '20x,30y,0z', dpi: int = 180):
    fig, ax = plt.subplots(1, 1, figsize=(3.2, 3.4), facecolor='white')
    atoms = structure_to_ase_atoms(structure)
    plot_atoms(atoms, ax, rotation=rotation, radii=0.35, show_unit_cell=2)
    ax.set_title(title, fontsize=10)
    ax.set_axis_off()
    plt.tight_layout()

    buffer = io.BytesIO()
    fig.savefig(buffer, format='png', dpi=dpi, facecolor='white')
    plt.close(fig)
    buffer.seek(0)

    frame = PILImage.open(buffer).convert('RGB')
    frame.load()
    buffer.close()
    return frame


def render_structure_strip(structures, titles, out_path: str, rotation: str = '20x,30y,0z'):
    out_path = str(Path(out_path))
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    n = len(structures)
    fig, axes = plt.subplots(1, n, figsize=(2.8 * n, 3.2), squeeze=False)
    axes = axes[0]

    for ax, structure, title in zip(axes, structures, titles):
        atoms = structure_to_ase_atoms(structure)
        plot_atoms(atoms, ax, rotation=rotation, radii=0.35, show_unit_cell=2)
        ax.set_title(title, fontsize=9)
        ax.set_axis_off()

    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    return out_path


def render_structure_gif(structures, titles, out_path: str, rotation: str = '20x,30y,0z', duration_ms: int = 220):
    out_path = str(Path(out_path))
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    frames = [render_structure_panel(structure, title, rotation=rotation) for structure, title in zip(structures, titles)]
    if not frames:
        raise ValueError('Need at least one frame to save a GIF.')

    frames[0].save(
        out_path,
        save_all=True,
        append_images=frames[1:],
        duration=int(duration_ms),
        loop=0,
        optimize=False,
    )
    return out_path


def show_image(path: str):
    return IPythonImage(filename=path)


@torch.no_grad()
def collect_forward_diffusion_states(record):
    frac0 = torch.tensor(record['frac0'], dtype=torch.float32, device=device)
    lattice0 = torch.tensor(record['lattice0_norm'][None], dtype=torch.float32, device=device)
    atom0 = torch.tensor(record['atom_tokens0'], dtype=torch.long, device=device)
    batch_idx = torch.zeros(len(atom0), dtype=torch.long, device=device)

    structures, titles = [], []
    for step in range(T + 1):
        t_graph = torch.tensor([int(step)], dtype=torch.long, device=device)

        if step == 0:
            frac_t = frac0
            lattice_t = lattice0[0]
            atom_t = atom0
        else:
            frac_t, _ = q_sample_wrapped_coords(frac0, t_graph, batch_idx)
            lattice_t, _ = q_sample_lattice(lattice0, t_graph)
            atom_t, _ = q_sample_atom_types(atom0, t_graph, batch_idx)
            lattice_t = lattice_t[0]

        structure_t = decode_state_to_structure(
            frac_t.detach().cpu().numpy(),
            lattice_t.detach().cpu().numpy(),
            atom_t.detach().cpu().numpy(),
        )
        masked_fraction = float((atom_t == MASK_TOKEN).float().mean().detach().cpu())
        structures.append(structure_t)
        titles.append(f't={step}\nmasked={masked_fraction:.0%}')

    return structures, titles


@torch.no_grad()
def make_forward_diffusion_visuals(
    record,
    strip_path='forward_diffusion_strip.png',
    gif_path='forward_diffusion.gif',
    strip_frames=10,
    gif_frames=24,
):
    structures, titles = collect_forward_diffusion_states(record)
    strip_structures, strip_titles = select_trajectory_snapshots(structures, titles, strip_frames)
    gif_structures, gif_titles = select_trajectory_snapshots(structures, titles, gif_frames)

    return (
        render_structure_strip(strip_structures, strip_titles, out_path=strip_path),
        render_structure_gif(gif_structures, gif_titles, out_path=gif_path),
    )


@torch.no_grad()
def reconstruct_example_with_record(
    model,
    record,
    guidance_scale=2.0,
    strip_path='reverse_diffusion_strip.png',
    gif_path='reverse_diffusion.gif',
    strip_frames=10,
    gif_frames=24,
):
    model.eval()

    num_atoms = torch.tensor([record['num_atoms']], dtype=torch.long, device=device)
    batch_idx = torch.zeros(record['num_atoms'], dtype=torch.long, device=device)

    frac0 = torch.tensor(record['frac0'], dtype=torch.float32, device=device)
    lattice0 = torch.tensor(record['lattice0_norm'][None], dtype=torch.float32, device=device)
    atom0 = torch.tensor(record['atom_tokens0'], dtype=torch.long, device=device)
    cond_tensors = build_condition_tensors(batch_size=1, target_record=record)

    t_init = torch.tensor([T], dtype=torch.long, device=device)
    frac_t, _ = q_sample_wrapped_coords(frac0, t_init, batch_idx)
    lattice_t, _ = q_sample_lattice(lattice0, t_init)
    atom_t, _ = q_sample_atom_types(atom0, t_init, batch_idx)
    lattice_t = lattice_t.clone()

    trajectory_structures, trajectory_titles = [], []
    start_structure = decode_state_to_structure(
        frac_t.detach().cpu().numpy(),
        lattice_t[0].detach().cpu().numpy(),
        atom_t.detach().cpu().numpy(),
    )
    trajectory_structures.append(start_structure)
    trajectory_titles.append(f'start\nt={T}')

    for step in range(T, 0, -1):
        t_graph = torch.tensor([step], dtype=torch.long, device=device)

        pred_coord_eps, pred_lattice_eps, pred_atom_logits = guided_predictions(
            model=model,
            frac_t=frac_t,
            lattice_t=lattice_t,
            atom_t=atom_t,
            num_atoms=num_atoms,
            batch_idx=batch_idx,
            t_graph=t_graph,
            continuous_cond=cond_tensors['continuous'],
            composition_cond=cond_tensors['composition'],
            spacegroup_cond=cond_tensors['spacegroup'],
            guidance_scale=guidance_scale,
        )

        frac_t = ddpm_reverse_step_continuous(frac_t, pred_coord_eps, t_graph, batch_idx=batch_idx, wrap=True)
        lattice_t = ddpm_reverse_step_continuous(lattice_t, pred_lattice_eps, t_graph, batch_idx=None, wrap=False)
        lattice_t = clamp_lattice_features_norm_torch(lattice_t)
        atom_t = discrete_reverse_step_absorbing(atom_t, pred_atom_logits, t_graph, batch_idx)

        structure_t = decode_state_to_structure(
            frac_t.detach().cpu().numpy(),
            lattice_t[0].detach().cpu().numpy(),
            atom_t.detach().cpu().numpy(),
        )
        trajectory_structures.append(structure_t)
        trajectory_titles.append(f't={step - 1}')

    strip_structures, strip_titles = select_trajectory_snapshots(trajectory_structures, trajectory_titles, strip_frames)
    gif_structures, gif_titles = select_trajectory_snapshots(trajectory_structures, trajectory_titles, gif_frames)

    return (
        render_structure_strip(strip_structures, strip_titles, out_path=strip_path),
        render_structure_gif(gif_structures, gif_titles, out_path=gif_path),
        trajectory_structures[-1],
    )


example_record = train_data[0]
print(
    'chosen example:',
    example_record['name'],
    '|',
    example_record['formula'],
    '| density =',
    example_record['density'],
    '| band gap =',
    example_record['band_gap'],
    '| space group =',
    example_record['spacegroup_symbol'],
)
display(
    show_structures(
        [example_record['structure']],
        [f"real example | {example_record['formula']} | density={example_record['density']:.2f} | gap={example_record['band_gap']:.2f} eV | sg={example_record['spacegroup_symbol']}"],
        columns=1,
    )
)

This first cell shows how a real example is gradually destroyed by the forward process.

forward_strip, forward_gif = make_forward_diffusion_visuals(
    example_record,
    strip_path='forward_diffusion_strip.png',
    gif_path='forward_diffusion.gif',
    strip_frames=10,
    gif_frames=24,
)
print('Forward strip:')
display(show_image(forward_strip))
print('Forward GIF (denser set of intermediate frames):')
display(show_image(forward_gif))

This second cell runs the learned reverse process from a maximally corrupted version of the same crystal.

Prediction before you run it: do you expect the reverse trajectory to recover the original space group exactly, or only a chemically plausible low-symmetry approximation?
Question: if the endpoint looks reasonable but is still assigned P1, does that mean the denoiser failed completely?

reverse_strip, reverse_gif, reverse_final_structure = reconstruct_example_with_record(
    cond_model,
    example_record,
    guidance_scale=2.0,
    strip_path='reverse_diffusion_strip.png',
    gif_path='reverse_diffusion.gif',
    strip_frames=10,
    gif_frames=24,
)
print('Reverse strip:')
display(show_image(reverse_strip))
print('Reverse GIF (denser set of intermediate frames):')
display(show_image(reverse_gif))

reverse_sg_num, reverse_sg_symbol = safe_spacegroup_info(reverse_final_structure)
display(
    show_structures(
        [example_record['structure'], reverse_final_structure],
        labels=[
            f'real example | {example_record["formula"]} | sg={example_record["spacegroup_symbol"]}',
            f'reverse endpoint | {reverse_final_structure.composition.reduced_formula} | sg={reverse_sg_symbol}',
        ],
        columns=2,
    )
)
print('Reverse endpoint density:', f'{safe_structure_density(reverse_final_structure):.2f}')
print('Reverse endpoint space group:', reverse_sg_num, reverse_sg_symbol)

What to notice in the strips

In the forward strip:

  • coordinates get noisier,

  • the lattice drifts,

  • atom identities become increasingly masked.

In the reverse strip:

  • the model gradually proposes a coherent lattice,

  • atom identities are revealed,

  • coordinates settle into a more structured arrangement,

  • but exact symmetry may still be fragile even when the chemistry looks plausible.
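
The forward-strip behavior listed above can be imitated in a few lines. This is a toy sketch under stated assumptions, not the notebook's `q_sample_wrapped_coords` or `q_sample_atom_types`: it assumes variance-preserving Gaussian noise followed by wrapping into [-0.5, 0.5), independent masking of each atom token, and a made-up linear schedule in which the signal level is 1 - t/T and the mask probability is t/T. `MASK` is a hypothetical stand-in for `MASK_TOKEN`.

```python
import math
import random

MASK = -1  # hypothetical stand-in for the notebook's MASK_TOKEN


def wrap_centered(x):
    """Map a fractional coordinate into the centered cell [-0.5, 0.5)."""
    return (x + 0.5) % 1.0 - 0.5


def corrupt_coords(frac, alpha_bar, rng):
    """Shrink the clean signal, add Gaussian noise, then wrap back into the cell."""
    scale, sigma = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [[wrap_centered(scale * x + sigma * rng.gauss(0.0, 1.0)) for x in site] for site in frac]


def corrupt_atom_types(tokens, mask_prob, rng):
    """Absorbing-state corruption: each token independently becomes MASK."""
    return [MASK if rng.random() < mask_prob else token for token in tokens]


T_STEPS = 10  # assumed number of diffusion steps for this toy example
rng = random.Random(0)
frac0 = [[0.0, 0.0, 0.0], [0.25, 0.25, 0.25], [-0.25, -0.25, -0.25]]
atoms0 = [3, 7, 7]

for t in (0, 5, 10):
    level = t / T_STEPS  # assumed linear schedule
    frac_t = corrupt_coords(frac0, alpha_bar=1.0 - level, rng=rng)
    atoms_t = corrupt_atom_types(atoms0, mask_prob=level, rng=rng)
    masked = sum(token == MASK for token in atoms_t) / len(atoms_t)
    print(f"t={t:2d} masked={masked:.0%} first site={[round(x, 2) for x in frac_t[0]]}")
```

At the final step every token is masked and the coordinates are pure wrapped noise, which matches how `sample_crystals` initializes `atom_t` and `frac_t` before running the reverse process.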

Question: which part seems hardest for the model to recover: atom identity, geometry, or exact crystallographic symmetry?

Answer

Among the three corrupted channels, atom identity is often the hardest to reconstruct, because masking removes the discrete chemical information entirely, whereas noisy coordinates and lattices still leave geometric hints behind. Exact space-group symmetry is usually harder still. It is not a channel of its own but a joint property of all three, and it survives only if coordinates, lattice, and atom types are all clean enough that the symmetry finder does not collapse the structure to P1.
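To see why a small residual error can be enough to lose a symmetry label, here is a deliberately tiny toy, not the symmetry finder the notebook uses. It checks a single operation (inversion, x -> -x mod 1) on one-dimensional fractional coordinates with a hypothetical tolerance `tol`. Real tools test many operations in three dimensions, but the tolerance logic is similar in spirit.

```python
def wrapped_distance(a, b):
    """Shortest distance between two fractional coordinates on a periodic axis."""
    d = abs(a - b) % 1.0
    return min(d, 1.0 - d)


def has_inversion_symmetry(frac_coords, tol):
    """True if x -> -x (mod 1) maps every site onto some site within tol."""
    return all(
        any(wrapped_distance((-x) % 1.0, y) <= tol for y in frac_coords)
        for x in frac_coords
    )


ideal = [0.25, 0.75]      # exact inversion partners: -0.25 mod 1 = 0.75
perturbed = [0.25, 0.78]  # one site displaced by 0.03

print(has_inversion_symmetry(ideal, tol=0.01))      # True
print(has_inversion_symmetry(perturbed, tol=0.01))  # False: the error exceeds the tolerance
print(has_inversion_symmetry(perturbed, tol=0.05))  # True: a looser tolerance accepts it
```

The same endpoint can therefore be labelled symmetric or not depending on the tolerance, which is one reason a P1 assignment alone does not mean the denoiser failed completely.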

14) Save generated structures and keep the useful artifacts

At the end of the notebook we save:

  • generated CIF files,

  • checkpoint files,

  • trajectory-strip images,

  • trajectory GIFs.

That makes it easy to inspect results outside the notebook or compare different runs.

The helper is short enough that it is worth reading: it just loops over structures and writes them as CIFs.
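The source of `save_structures_as_cifs` is not reproduced in this notebook, so the following is only a sketch of what such a loop might look like. It assumes each structure exposes a pymatgen-style `to(filename=..., fmt="cif")` method. The function name, the `prefix` argument, and the file-naming scheme are hypothetical.

```python
from pathlib import Path


def save_structures_as_cifs_sketch(structures, out_dir, prefix="structure"):
    """Write each structure to <out_dir>/<prefix>_NNN.cif and return the directory.

    Assumption: every item has a pymatgen-like .to(filename=..., fmt=...) method.
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    for i, structure in enumerate(structures):
        cif_file = out_path / f"{prefix}_{i:03d}.cif"
        structure.to(filename=str(cif_file), fmt="cif")
    return out_path
```

With pymatgen `Structure` objects this would be called in the same way as the real helper below, for example `save_structures_as_cifs_sketch(uncond_structures, 'my_cifs')`.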

from gen_helpers.crystal_display import save_structures_as_cifs

out_uncond = save_structures_as_cifs(uncond_structures, 'generated_unconditional_cifs')
out_low_density = save_structures_as_cifs(low_density_structures, 'generated_low_density_cifs')
out_high_density = save_structures_as_cifs(high_density_structures, 'generated_high_density_cifs')
out_low_gap = save_structures_as_cifs(low_gap_structures, 'generated_low_gap_cifs')
out_high_gap = save_structures_as_cifs(high_gap_structures, 'generated_high_gap_cifs')

print('saved unconditional CIFs to:', out_uncond)
print('saved low-density CIFs to:', out_low_density)
print('saved high-density CIFs to:', out_high_density)
print('saved low-gap CIFs to:', out_low_gap)
print('saved high-gap CIFs to:', out_high_gap)
print('forward trajectory strip:', forward_strip)
print('forward trajectory gif:', forward_gif)
print('reverse trajectory strip:', reverse_strip)
print('reverse trajectory gif:', reverse_gif)
print('cached MP dataset:', MP_CACHE_PATH)
print('base best checkpoint:', base_run['best_checkpoint'])
print('base last checkpoint:', base_run['last_checkpoint'])
print('adapter best checkpoint:', adapter_run['best_checkpoint'])
print('adapter last checkpoint:', adapter_run['last_checkpoint'])

15) What is faithful to MatterGen, and what is simplified?

Faithful ideas

This notebook really does implement the main design pattern:

  • a ChemGraph-like flattened representation,

  • separate corruption processes for coordinates / lattice / atoms,

  • a graph denoiser with three output heads,

  • unconditional base training first,

  • adapter-style conditional fine-tuning,

  • classifier-free guidance for sampling.

Simplifications

It is still intentionally smaller than the real system:

  • the denoiser is a compact message-passing network rather than full GemNet-T,

  • the discrete atom diffusion is a simplified absorbing-mask variant,

  • the training corpus is tiny compared with production-scale materials generation,

  • band-gap conditioning is demonstrated as a label-conditioned task, not a validated property-prediction pipeline,

  • composition conditioning is a soft fraction-vector target rather than exact formula control,

  • space-group conditioning is a lightweight symmetry label rather than a full equivariant crystallographic constraint, so many decoded samples still end up as P1.

Those simplifications are what make the notebook teachable.
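To make "soft fraction-vector target" concrete, here is a minimal sketch of how a composition could be turned into such a vector. The function name and the four-element vocabulary are illustrative assumptions. The notebook's own encoding may differ in vocabulary and normalization.

```python
def composition_fraction_vector(counts, vocabulary):
    """Map element counts to atomic fractions over a fixed element vocabulary.

    Elements outside the vocabulary still count toward the total but get no slot.
    """
    total = sum(counts.values())
    return [counts.get(element, 0) / total for element in vocabulary]


vocab = ["Li", "O", "Na", "Cl"]  # hypothetical, tiny vocabulary
li2o = composition_fraction_vector({"Li": 2, "O": 1}, vocab)
li4o2 = composition_fraction_vector({"Li": 4, "O": 2}, vocab)

print(li2o)
print(li2o == li4o2)  # True: the vector fixes ratios, not the number of atoms per cell
```

Because one formula unit and two formula units map to the same vector, the target steers stoichiometry without pinning down the exact cell contents, which is why it is described as soft.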

Suggested experiments

  1. Dataset scope: switch the chemistry filter and compare how the density / band-gap distributions change before retraining.

  2. Architecture ablation: halve HIDDEN_DIM or NUM_BLOCKS and inspect what degrades first: the losses or the generated structures.

  3. Guidance sweep: compare guidance_scale = 1.0, 2.0, and 3.0 while keeping the target fixed.

  4. Condition isolation: hold composition and space group fixed, then sweep only density or only band gap.

  5. Symmetry stress test: measure how often decoded samples are assigned P1 as you vary guidance_scale or target space group.
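For the symmetry stress test, a small tally function is enough. This is a sketch: `p1_fraction` is a hypothetical helper, and the list below is placeholder input chosen only to exercise the code, not output from any run of this notebook.

```python
def p1_fraction(spacegroup_numbers):
    """Fraction of analysable samples assigned space group 1 (P1).

    None entries (symmetry analysis failed) are excluded from the denominator.
    """
    valid = [n for n in spacegroup_numbers if n is not None]
    if not valid:
        return float("nan")
    return sum(1 for n in valid if n == 1) / len(valid)


# Placeholder labels only, NOT results from the notebook.
dummy_labels = [1, 1, 225, None]
print(f"P1 fraction: {p1_fraction(dummy_labels):.2f}")
```

In practice you would collect the first element returned by `safe_spacegroup_info` for every decoded structure at each guidance scale, then compare the resulting fractions.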

What to open next

  • mattergen-crystals.ipynb shows how the same ideas look in a pretrained production-style workflow.

  • chemeleon-crystals.ipynb shows the same generative family reorganized around two user-facing tasks: DNG and CSP.

Crystal Exercises

  1. Representation: Why is batch_idx more natural than padding every crystal to the same number of atoms?

Answer

batch_idx keeps the computation aligned with the true number of atoms in each crystal, avoids wasted padded nodes, and matches graph neural network tooling directly.
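A minimal pure-Python sketch of the idea, with made-up feature values: two crystals of different size are flattened into one list of atoms, and `batch_idx` records which crystal each atom belongs to. The helper name `per_crystal_mean` is illustrative and assumes every crystal has at least one atom.

```python
# Crystal 0 has 2 atoms, crystal 1 has 3 atoms: 5 rows in total, no padding.
atom_feature = [1.0, 3.0, 2.0, 4.0, 6.0]
batch_idx = [0, 0, 1, 1, 1]


def per_crystal_mean(values, batch_idx, num_crystals):
    """Average an atom-wise quantity into one value per crystal."""
    sums = [0.0] * num_crystals
    counts = [0] * num_crystals
    for value, crystal in zip(values, batch_idx):
        sums[crystal] += value
        counts[crystal] += 1
    return [s / c for s, c in zip(sums, counts)]


print(per_crystal_mean(atom_feature, batch_idx, num_crystals=2))  # [2.0, 4.0]
```

A padded layout would need 2 x 3 = 6 slots plus a mask for the same data. The gap grows quickly when one large crystal shares a batch with many small ones.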

  2. Corruption design: Why do we use different forward corruption processes for coordinates, lattice, and atom identity instead of one generic noise rule?

Answer

Those three channels live in different spaces: coordinates are periodic continuous variables, lattice parameters are constrained continuous variables, and atom identities are discrete. Using one corruption rule for all three would ignore those structural differences and make denoising less natural.
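The contrast can be sketched with three toy corruption functions, one per space. These are illustrative stand-ins rather than the notebook's actual noise schedules. In particular, the log-normal lattice noise and the reserved `MASK` index are assumptions made for this sketch.

```python
import math
import random

MASK = 0  # hypothetical index reserved for the absorbing "masked" atom type


def corrupt_frac_coords(coords, sigma, rng):
    """Periodic variable: add Gaussian noise, then wrap back into the unit cell."""
    return [[(x + rng.gauss(0.0, sigma)) % 1.0 for x in site] for site in coords]


def corrupt_lattice(lengths, sigma, rng):
    """Constrained variable: multiplicative noise keeps cell lengths positive."""
    return [length * math.exp(rng.gauss(0.0, sigma)) for length in lengths]


def corrupt_atom_types(types, mask_prob, rng):
    """Discrete variable: replace each type by MASK with probability mask_prob."""
    return [MASK if rng.random() < mask_prob else t for t in types]


rng = random.Random(0)
print(corrupt_frac_coords([[0.0, 0.5, 0.99]], sigma=0.1, rng=rng))
print(corrupt_lattice([4.0, 5.0, 6.0], sigma=0.2, rng=rng))
print(corrupt_atom_types([11, 17, 17], mask_prob=0.5, rng=rng))
```

Applying plain additive Gaussian noise to all three would push coordinates outside the cell, could make cell lengths negative, and has no meaning for integer atom labels.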

  3. Training strategy: What is the point of training the unconditional base model before the conditional adapter?

Answer

The base model learns a generic crystal prior first. The adapter then learns how to steer that prior toward requested properties without having to relearn all of crystal plausibility from scratch.
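One common way to get this behaviour is to add the condition through a zero-initialized layer, so that the fine-tuned model starts out exactly equal to the base model. The toy class below illustrates that pattern in pure Python. Whether the notebook's adapter uses this exact initialization is an assumption, and `ZeroInitAdapter` is a hypothetical name.

```python
def linear(weights, vec):
    """Plain matrix-vector product on nested lists."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weights]


class ZeroInitAdapter:
    """Adds a condition-dependent shift to hidden features; starts as the identity."""

    def __init__(self, cond_dim, hidden_dim):
        # All-zero weights: the adapter contributes nothing before fine-tuning.
        self.weights = [[0.0] * cond_dim for _ in range(hidden_dim)]

    def __call__(self, hidden, cond):
        if cond is None:  # no condition supplied: fall back to the base behaviour
            return list(hidden)
        shift = linear(self.weights, cond)
        return [h + s for h, s in zip(hidden, shift)]


adapter = ZeroInitAdapter(cond_dim=2, hidden_dim=3)
hidden = [0.5, -1.0, 2.0]
print(adapter(hidden, cond=[3.0, 1.5]) == hidden)  # True before any fine-tuning
```

Fine-tuning then only has to learn the shift, while the base weights that encode crystal plausibility can stay frozen or change slowly.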

  4. Guidance: If you raise the classifier-free guidance scale, what usually changes first: diversity or condition adherence?

Answer

Condition adherence usually improves while diversity drops. If the scale becomes too large, structural plausibility can start to degrade as well.
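The trade-off follows from how classifier-free guidance combines two predictions. A minimal sketch, assuming the convention in which a scale of 1.0 reproduces the purely conditional prediction. Some implementations instead write (1 + w) * cond - w * uncond, so check which convention a given codebase uses.

```python
def guided_prediction(pred_uncond, pred_cond, guidance_scale):
    """Extrapolate from the unconditional toward the conditional prediction."""
    return [u + guidance_scale * (c - u) for u, c in zip(pred_uncond, pred_cond)]


uncond = [0.0, 1.0]
cond = [1.0, 1.0]

print(guided_prediction(uncond, cond, 0.0))  # [0.0, 1.0] -> unconditional
print(guided_prediction(uncond, cond, 1.0))  # [1.0, 1.0] -> conditional
print(guided_prediction(uncond, cond, 2.0))  # [2.0, 1.0] -> pushed past the conditional
```

Scales above 1.0 amplify whatever the condition changes, which strengthens adherence but also moves the prediction outside the range the model saw during training.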

  5. Interpretation: Why should the band-gap and space-group demos here be read as targeted generation illustrations rather than strict guarantees?

Answer

The notebook conditions on labels and analyzes decoded outputs, but it does not solve a fully constrained crystallographic inverse problem. The conditions steer the generator, yet the decoded structures still need downstream validation if exact property or symmetry guarantees matter.

  6. Symmetry quality: Why should we expect exact space-group control to be weaker than composition or rough density control in this notebook, and why do many decoded samples end up as P1?

Answer

The space-group target is only a soft label in a very small dataset, not a hard symmetry constraint built into the decoder. Small coordinate or lattice errors can break the intended symmetry, and the post hoc symmetry analysis is sensitive enough that nearly symmetric structures are often classified as P1.

  7. Course connection: What part of this notebook would have been hardest to understand without graph-networks.ipynb?

Answer

Usually the flattened crystal batching, graph message passing, and the distinction between atom-wise and crystal-wise tensors. The diffusion logic itself is new, but the graph representation borrows heavily from the CGCNN-style notebook.

  8. Conditioning vs screening: Give one example of a quantity used for conditioning and one used mainly for screening in this notebook.

Answer

A conditioning example is the requested density or band-gap target. A screening example is the decoded density histogram or the post hoc space-group analysis used to judge what the samples look like afterward.
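Screening by density is cheap because the decoded density follows directly from the cell contents and cell volume. The sketch below shows the arithmetic in plain Python. It is not the notebook's `safe_structure_density`, and the function names are illustrative. The rock-salt NaCl numbers are standard textbook values.

```python
import math

AMU_IN_GRAMS = 1.66053907e-24   # one atomic mass unit in grams
ANGSTROM3_IN_CM3 = 1.0e-24      # one cubic angstrom in cubic centimetres


def cell_volume(a, b, c, alpha, beta, gamma):
    """Unit-cell volume in cubic angstroms from lengths (angstroms) and angles (degrees)."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    return a * b * c * math.sqrt(1.0 - ca**2 - cb**2 - cg**2 + 2.0 * ca * cb * cg)


def density_g_per_cm3(total_mass_amu, volume_a3):
    """Mass density from the total atomic mass in the cell and the cell volume."""
    return total_mass_amu * AMU_IN_GRAMS / (volume_a3 * ANGSTROM3_IN_CM3)


# Rock-salt NaCl conventional cell: 4 Na + 4 Cl, a of about 5.64 angstroms.
mass = 4 * (22.990 + 35.45)
volume = cell_volume(5.64, 5.64, 5.64, 90, 90, 90)
print(round(density_g_per_cm3(mass, volume), 2))  # about 2.16 g/cm^3
```

A generated structure whose density lands far from the requested target can be discarded by this check alone, before any more expensive validation.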

  9. Looking ahead: After this notebook, when might you open mattergen-crystals.ipynb and when might you open chemeleon-crystals.ipynb?

Answer

Open MatterGen when you want the closest production-style continuation of direct crystal diffusion with pretrained checkpoints and scalar-property steering. Open Chemeleon when you want a task-oriented workflow centered on open-ended DNG or formula-conditioned CSP.

References
  1. Jain, A., Ong, S. P., Hautier, G., Chen, W., Richards, W. D., Dacek, S., Cholia, S., Gunter, D., Skinner, D., Ceder, G., & Persson, K. A. (2013). Commentary: The Materials Project: A materials genome approach to accelerating materials innovation. APL Materials, 1(1). doi:10.1063/1.4812323