
MatterGen-light, taught step by step

Open in Colab


This notebook assumes the diffusion ideas from diffusion-fundamentals.ipynb and turns them into a small crystal-generation pipeline.

Aims

You will build a compact, MatterGen-like workflow that:

  1. curates a crystal dataset directly from Materials Project,

  2. turns crystals into a ChemGraph-like batch format,

  3. corrupts coordinates, lattice, and atom types with separate forward processes,

  4. trains an unconditional base model,

  5. adds density + band-gap conditioning with an adapter-style fine-tune,

  6. samples crystals and inspects diffusion trajectories.
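Step 3 is where crystal diffusion differs most from image diffusion: fractional coordinates live on a torus, so Gaussian noise must be wrapped back into the unit cell. A minimal numpy sketch of that wrapped forward process (the helper names here are illustrative, not the notebook's own):

```python
import numpy as np

def wrap_centered(frac):
    # Map fractional coordinates into the centered cell [-0.5, 0.5).
    return ((frac + 0.5) % 1.0) - 0.5

def corrupt_frac_coords(frac0, sigma_t, rng):
    # Forward process for periodic coordinates: add Gaussian noise,
    # then wrap so the noised sample stays on the torus.
    noise = rng.normal(0.0, sigma_t, size=frac0.shape)
    return wrap_centered(frac0 + noise)

rng = np.random.default_rng(0)
frac0 = np.array([[0.45, -0.45, 0.0]])
frac_t = corrupt_frac_coords(frac0, sigma_t=0.3, rng=rng)
print(frac_t)  # always inside [-0.5, 0.5)
```

Lattice features and atom types get their own forward processes (plain Gaussian noise and categorical corruption, respectively), which is why the pipeline treats the three parts separately.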

Learning outcomes

By the end you should be able to:

  1. explain what parts of a crystal need separate corruption processes,

  2. read and manipulate a flattened crystal-graph batch representation,

  3. train a small unconditional crystal denoiser,

  4. add conditional signals without rebuilding the whole model,

  5. interpret unconditional samples, conditioned samples, and diffusion-path plots as complementary diagnostics.
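Outcome 2 refers to the standard trick for batching variable-size crystals: concatenate every crystal's atoms into flat arrays and keep a per-atom batch index recording which crystal each row belongs to. A hypothetical minimal version, using plain numpy dicts rather than the notebook's actual ChemGraph-like class:

```python
import numpy as np

def flatten_batch(crystals):
    """Concatenate per-atom arrays across crystals and record a batch index.

    Each crystal is a dict with per-atom `atomic_numbers` (n_i,) and
    `frac` (n_i, 3), plus one 6-dim `lattice` feature vector per crystal.
    """
    z = np.concatenate([c["atomic_numbers"] for c in crystals])
    frac = np.concatenate([c["frac"] for c in crystals], axis=0)
    lattice = np.stack([c["lattice"] for c in crystals])
    batch_idx = np.concatenate(
        [np.full(len(c["atomic_numbers"]), i) for i, c in enumerate(crystals)]
    )
    return {"z": z, "frac": frac, "lattice": lattice, "batch_idx": batch_idx}

crystals = [
    {"atomic_numbers": np.array([11, 17]), "frac": np.zeros((2, 3)), "lattice": np.zeros(6)},
    {"atomic_numbers": np.array([12, 8, 8]), "frac": np.zeros((3, 3)), "lattice": np.zeros(6)},
]
batch = flatten_batch(crystals)
print(batch["batch_idx"])  # [0 0 1 1 1]
```

Per-atom quantities (coordinates, types) live in the flat arrays; per-crystal quantities (lattice, conditioning labels) are indexed through `batch_idx`.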

How to read this notebook

  • Run it from top to bottom after the fundamentals notebook.

  • Many implementation-heavy cells are collapsed. You can ignore them on the first pass.

  • Focus first on the markdown explanations, the small visible code cells, and the plots / generated structures.

  • Then reopen hidden cells only when you want the full implementation details.

Runtime expectations

  • quick mode is the best first pass: it uses a small binary-crystal dataset, a smaller denoiser, and shorter training loops so the whole workflow stays notebook-scale.

  • full mode keeps the same teaching path but broadens the chemistry and training budget, so it is better once the notebook is already working on your machine.

  • On CPU, quick mode is usually in the 10–20 minute range once dependencies are installed and the dataset is cached; full mode can easily take roughly twice that.

  • A live Materials Project query adds a short data-fetch step; the bundled fallback stays local.

Key sources and papers

This notebook is a pedagogical MatterGen-light build rather than an official reproduction. The data/API path follows the Materials Project API, and the crystal-diffusion design choices follow the MatterGen approach.

The notebook keeps the same broad decomposition as production crystal diffusion systems, with separate treatment of atom identity, periodic coordinates, and lattice geometry, but scales the model and dataset down so the full workflow stays readable and runnable in a teaching setting.
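One concrete piece of that decomposition: lattice geometry can be represented as six unconstrained numbers, log edge lengths plus cosines of the cell angles, which is what the hidden helper cells later in this notebook do. A self-contained sketch of that encode/decode round trip (simplified: no clamping of decoded lengths or angles):

```python
import numpy as np

def lattice_to_features(lengths, angles_deg):
    # (a, b, c) -> log lengths; (alpha, beta, gamma) -> cosines.
    return np.concatenate([np.log(lengths), np.cos(np.deg2rad(angles_deg))])

def features_to_lattice(feat):
    # Invert the encoding; clip cosines so arccos stays in range.
    lengths = np.exp(feat[:3])
    angles_deg = np.rad2deg(np.arccos(np.clip(feat[3:], -1.0, 1.0)))
    return lengths, angles_deg

lengths = np.array([5.64, 5.64, 5.64])   # a rock-salt-like cubic cell
angles = np.array([90.0, 90.0, 90.0])
feat = lattice_to_features(lengths, angles)
back_lengths, back_angles = features_to_lattice(feat)
print(np.allclose(back_lengths, lengths), np.allclose(back_angles, angles))
```

The log/cosine transform means a Gaussian forward process on the features can never produce a negative edge length or an angle outside (0°, 180°), which keeps decoded lattices physically meaningful.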

0) Install the small dependency set

This cell installs the notebook dependencies. After it runs once, you can usually ignore it.

It also installs py3Dmol, which is only used if you switch the structure viewer from the default static previews to the optional interactive mode later in the notebook.

# @title
# Colab install cell.
# PyTorch usually already exists in Colab, so we only install the extras we use.
!pip install uv
!uv pip -q install mp-api pymatgen ase pandas matplotlib tqdm pillow monty py3Dmol

1) Setup and connect to Materials Project

This section does three things:

  • imports the packages we need,

  • chooses CPU or GPU,

  • asks for your Materials Project API key.

Why start here? Because this notebook queries Materials Project directly rather than reading a bundled archive. The dataset step is part of the workflow.

Design choice: that keeps the notebook closer to a real research workflow, where the data query is part of the experiment.

If you do not provide an API key, the notebook falls back to a small Materials Project corpus bundled in the repo, so it still teaches from real MP structures rather than a synthetic toy set.

If mp_api imports cleanly, the notebook uses it directly. If that import fails in a lightweight environment, the notebook falls back to direct REST requests against the same Materials Project API.

Try later: after the notebook works once, rerun only the dataset section with a different chemistry or structure family.

Fastest key paths:

  • paste the key into the short form cell just below for this notebook session,

  • or set MP_API_KEY / PMG_MAPI_KEY in your environment,

  • or store a Colab secret named MP_API_KEY.

The code below is intentionally short and practical. Read it as:

  • imports,

  • seed everything for reproducibility,

  • choose device,

  • fetch API key from the quick form, your environment, or a Colab secret.

# @title Optional: add your Materials Project API key for this session
import os

MP_API_KEY_INPUT = ""  # @param {type:"string"}

if MP_API_KEY_INPUT.strip():
    os.environ["MP_API_KEY"] = MP_API_KEY_INPUT.strip()
    os.environ["PMG_MAPI_KEY"] = MP_API_KEY_INPUT.strip()
    print(
        "Stored a Materials Project API key for this notebook session:",
        "*" * 8 + MP_API_KEY_INPUT.strip()[-4:],
    )
else:
    print(
        "Leave this blank to use MP_API_KEY / PMG_MAPI_KEY from your environment or a Colab secret named MP_API_KEY."
    )
# @title
import os
import io
import html
import math
import copy
import json
import hashlib
import time
import random
import re
import itertools
import warnings
from getpass import getpass
from pathlib import Path
from dataclasses import dataclass
import sys
import types

# Some lightweight environments ship without the optional bz2/lzma extensions.
# A few imports in the crystal stack transitively touch `bz2` and `lzma`, so we
# provide small compatibility shims when the compiled modules are unavailable.
def _compression_unavailable(*args, **kwargs):
    raise RuntimeError("compression support is unavailable in this Python build")

try:
    import bz2  # type: ignore
except Exception:
    bz2 = types.ModuleType("bz2")
    bz2.BZ2File = _compression_unavailable
    bz2.compress = _compression_unavailable
    bz2.decompress = _compression_unavailable
    bz2.open = _compression_unavailable
    sys.modules["bz2"] = bz2

try:
    import lzma  # type: ignore
except Exception:
    lzma = types.ModuleType("lzma")
    lzma.LZMAFile = _compression_unavailable
    lzma.compress = _compression_unavailable
    lzma.decompress = _compression_unavailable
    lzma.open = _compression_unavailable
    sys.modules["lzma"] = lzma


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import requests
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from pymatgen.core import Structure, Lattice, Element, Composition
from pymatgen.io.cif import CifWriter
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
try:
    from mp_api.client import MPRester
except Exception:
    MPRester = None
from monty.serialization import dumpfn, loadfn

from ase.visualize.plot import plot_atoms

try:
    import py3Dmol  # type: ignore
except Exception:
    py3Dmol = None

from IPython.display import display, HTML, Markdown, Image as IPythonImage

warnings.filterwarnings("ignore")

SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

def get_mp_api_key() -> str:
    key = os.environ.get("PMG_MAPI_KEY") or os.environ.get("MP_API_KEY")
    if key:
        os.environ["PMG_MAPI_KEY"] = key
        return key

    # Optional Colab secret named MP_API_KEY.
    try:
        from google.colab import userdata  # type: ignore
        key = userdata.get("MP_API_KEY")
    except Exception:
        key = None

    if not key:
        print(
            "No Materials Project API key found; using the bundled small Materials Project fallback corpus. "
            "Fill the form cell above or set MP_API_KEY to switch to live Materials Project data."
        )
        return ""

    os.environ["PMG_MAPI_KEY"] = key
    return key

MP_API_KEY = get_mp_api_key()
if MP_API_KEY:
    print("Using Materials Project key ending with:", "*" * 8 + MP_API_KEY[-4:])
else:
    print("No Materials Project API key detected; the bundled Materials Project fallback corpus will be used.")

Choose quick or full notebook mode

Use quick for a first pass through the entire notebook. It focuses on a small binary-crystal corpus so the chemistry, plots, and generated samples stay easy to read. Use full once you want a stronger demo with broader chemistry and a larger denoiser.

The mode controls the default dataset size, chemistry scope, batch size, network width, and training budgets. You can still override any of those values later if you want to explore manually.

# @title Choose quick or full notebook mode
DEMO_MODE = "quick"  # @param ["quick", "full"]

DEMO_PRESETS = {
    "quick": {
        "chemistry_scope": "2_elements",
        "max_structures": 20000,
        "max_atoms": 5,
        "batch_size": 16,
        "hidden_dim": 64,
        "num_blocks": 4,
        "base_epochs": 75,
        "base_patience": 25,
        "adapter_epochs": 75,
        "adapter_patience": 25,
    },
    "full": {
        "chemistry_scope": "2_to_4_elements",
        "max_structures": 20000,
        "max_atoms": 20,
        "batch_size": 64,
        "hidden_dim": 128,
        "num_blocks": 6,
        "base_epochs": 75,
        "base_patience": 25,
        "adapter_epochs": 75,
        "adapter_patience": 25,
    },
}

DEMO_CONFIG = DEMO_PRESETS[DEMO_MODE].copy()
CHEMISTRY_SCOPE = DEMO_CONFIG["chemistry_scope"]
MAX_STRUCTURES = DEMO_CONFIG["max_structures"]
MAX_ATOMS = DEMO_CONFIG["max_atoms"]
FORCE_MP_REFRESH = False
BATCH_SIZE = DEMO_CONFIG["batch_size"]
HIDDEN_DIM = DEMO_CONFIG["hidden_dim"]
NUM_BLOCKS = DEMO_CONFIG["num_blocks"]
BASE_MAX_EPOCHS = DEMO_CONFIG["base_epochs"]
BASE_PATIENCE = DEMO_CONFIG["base_patience"]
ADAPTER_MAX_EPOCHS = DEMO_CONFIG["adapter_epochs"]
ADAPTER_PATIENCE = DEMO_CONFIG["adapter_patience"]

print(f"Notebook mode: {DEMO_MODE}")
for key, value in DEMO_CONFIG.items():
    print(f"  {key}: {value}")
print("  force_mp_refresh:", FORCE_MP_REFRESH)

2) Build a small crystal corpus from Materials Project

Dataset creation stays deliberately simple.

The idea

We fetch a broad, stable, non-elemental crystal set from Materials Project and keep only a small number of structures so the rest of the notebook stays fast enough for Colab.

If a live API key is available, this section queries Materials Project directly and caches the curated result. It uses mp_api when available and otherwise falls back to direct REST requests to the same endpoint. If no key is available, it loads a bundled small real-MP fallback dataset from the repository and adapts it to the current filters.

You only choose:

  1. how many distinct elements typical crystals should have,

  2. how many structures you want in the teaching corpus,

  3. the maximum number of atoms per crystal,

  4. whether to use cached results if they already exist.

That is enough to get a realistic miniature dataset without turning the notebook into a database-engineering exercise.
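Those four knobs boil down to a simple post-filter over the fetched records. A minimal standalone sketch (the field names match the curated records this notebook builds, but this version uses plain dicts and skips the even-spacing selection):

```python
def filter_records(records, min_elements, max_elements, max_atoms, target):
    # Keep crystals within the element-count and atom-count limits,
    # then truncate to the requested corpus size.
    kept = [
        r for r in records
        if r["num_atoms"] <= max_atoms
        and min_elements <= r["num_elements"] <= max_elements
    ]
    return kept[:target]

records = [
    {"name": "NaCl", "num_atoms": 2, "num_elements": 2},
    {"name": "Si", "num_atoms": 2, "num_elements": 1},        # elemental: dropped
    {"name": "LiFePO4", "num_atoms": 28, "num_elements": 4},  # too large: dropped
]
print([r["name"] for r in filter_records(records, 2, 4, 20, 8)])  # ['NaCl']
```

The real helper additionally drops radioactive and noble-gas chemistries and spreads the kept records evenly over density and band gap, but the filtering logic is the same shape.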

The hidden implementation cell below contains the full Materials Project helper code.

You do not need to read it on your first pass.

For most users, the next visible dataset cell is enough: choose the chemistry scope, choose the dataset size, and fetch the corpus.

#@title Materials Project dataset helpers (implementation details) { display-mode: "form" }
VAL_FRACTION = 0.1

GEOM_LOG_LENGTH_CLAMP = (-0.5, 2.5)
GEOM_COS_CLAMP = (-0.98, 0.98)
GEOM_MAX_DIST = 12.0

MIN_DECODED_VOLUME = 5e-2
MIN_DECODED_LENGTH = 2.0
MAX_DECODED_LENGTH = 20.0
MIN_DECODED_ANGLE = 35.0
MAX_DECODED_ANGLE = 145.0

DEFAULT_MP_CACHE_DIR = Path("mp_curated_cache")
DEFAULT_MP_CACHE_DIR.mkdir(parents=True, exist_ok=True)

BUNDLED_MP_FALLBACK_PATH = DEFAULT_MP_CACHE_DIR / "bundled_real_mp_2_to_4_elements_small.json.gz"
BUNDLED_MP_FALLBACK_CONFIG = {
    "name": "bundled_real_mp_2_to_4_elements_small",
    "chemistry_scope": "2_to_4_elements",
    "min_num_elements": 2,
    "max_num_elements": 4,
    "max_atoms": 20,
    "max_structures": 128,
    "energy_above_hull_max": 0.05,
    "is_stable_only": True,
    "include_theoretical": False,
    "exclude_elements": ["H"],
}

CHEMISTRY_SCOPE_TO_NUM_ELEMENTS = {
    "2_elements": (2, 2),
    "3_elements": (3, 3),
    "2_to_4_elements": (2, 4),
    "2_to_5_elements": (2, 5),
}

def normalized_structure_anonymous_formula(structure: Structure) -> str:
    try:
        return str(structure.composition.anonymized_formula).replace(" ", "")
    except Exception:
        return ""

def canonicalize_config_for_cache(cfg: dict) -> dict:
    return json.loads(json.dumps(cfg, sort_keys=True))

def dataset_cache_path_from_config(cfg: dict) -> Path:
    canonical = canonicalize_config_for_cache(cfg)
    digest = hashlib.md5(json.dumps(canonical, sort_keys=True).encode()).hexdigest()[:10]
    safe_name = re.sub(r"[^a-zA-Z0-9_-]+", "_", str(cfg.get("name", "mp_general_dataset"))).strip("_") or "mp_general_dataset"
    return DEFAULT_MP_CACHE_DIR / f"{safe_name}_{digest}.json.gz"

def make_general_dataset_config(
    chemistry_scope: str = "2_to_4_elements",
    max_structures: int = 96,
    max_atoms: int = 16,
    force_refresh: bool = False,
):
    if chemistry_scope not in CHEMISTRY_SCOPE_TO_NUM_ELEMENTS:
        raise ValueError(f"Unknown chemistry_scope: {chemistry_scope}")
    min_el, max_el = CHEMISTRY_SCOPE_TO_NUM_ELEMENTS[chemistry_scope]
    target = int(max_structures)
    # Fetch somewhat more than we need so that post-filters still leave a healthy dataset.
    chunk_size = 250
    num_chunks = max(2, math.ceil((target * 3) / chunk_size))
    return {
        "name": f"general_{chemistry_scope}",
        "chemistry_scope": chemistry_scope,
        "min_num_elements": int(min_el),
        "max_num_elements": int(max_el),
        "max_atoms": int(max_atoms),
        "max_structures": int(max_structures),
        "energy_above_hull_max": 0.05,
        "is_stable_only": True,
        "include_theoretical": False,
        "exclude_elements": ["H"],
        "force_refresh": bool(force_refresh),
        "query_chunk_size": int(chunk_size),
        "query_num_chunks": int(num_chunks),
    }

def build_general_mp_query_from_config(cfg: dict):
    min_el = int(cfg["min_num_elements"])
    max_el = int(cfg["max_num_elements"])
    num_elements_filter = min_el if min_el == max_el else (min_el, max_el)

    return dict(
        num_sites=(2, int(cfg["max_atoms"])),
        num_elements=num_elements_filter,
        energy_above_hull=(0.0, float(cfg.get("energy_above_hull_max", 0.05))),
        is_stable=bool(cfg.get("is_stable_only", True)),
        theoretical=bool(cfg.get("include_theoretical", False)),
        exclude_elements=list(cfg.get("exclude_elements", ["H"])),
        num_chunks=int(cfg.get("query_num_chunks", 2)),
        chunk_size=int(cfg.get("query_chunk_size", 250)),
        all_fields=False,
        fields=[
            "material_id",
            "formula_pretty",
            "density",
            "structure",
            "energy_above_hull",
            "band_gap",
            "is_stable",
            "theoretical",
        ],
    )

def mp_doc_get(doc, key: str, default=None):
    if isinstance(doc, dict):
        return doc.get(key, default)
    return getattr(doc, key, default)

def structure_from_mp_doc(doc):
    structure = mp_doc_get(doc, "structure", None)
    if structure is None:
        return None
    if isinstance(structure, Structure):
        return structure
    if isinstance(structure, dict):
        return Structure.from_dict(structure)
    raise TypeError(f"Unsupported structure payload type: {type(structure)!r}")

def build_general_mp_rest_params(query: dict) -> dict:
    params = {}

    num_sites = query.get("num_sites")
    if num_sites is not None:
        if isinstance(num_sites, (tuple, list)):
            params["nsites_min"] = int(num_sites[0])
            params["nsites_max"] = int(num_sites[1])
        else:
            params["nsites_min"] = int(num_sites)
            params["nsites_max"] = int(num_sites)

    num_elements = query.get("num_elements")
    if num_elements is not None:
        if isinstance(num_elements, (tuple, list)):
            params["nelements_min"] = int(num_elements[0])
            params["nelements_max"] = int(num_elements[1])
        else:
            params["nelements_min"] = int(num_elements)
            params["nelements_max"] = int(num_elements)

    energy_above_hull = query.get("energy_above_hull")
    if energy_above_hull is not None:
        params["energy_above_hull_min"] = float(energy_above_hull[0])
        params["energy_above_hull_max"] = float(energy_above_hull[1])

    if query.get("is_stable") is not None:
        params["is_stable"] = bool(query["is_stable"])
    if query.get("theoretical") is not None:
        params["theoretical"] = bool(query["theoretical"])

    exclude_elements = query.get("exclude_elements")
    if exclude_elements:
        params["exclude_elements"] = ",".join(str(el) for el in exclude_elements)

    fields = query.get("fields")
    if fields:
        params["_fields"] = ",".join(fields)
    params["_all_fields"] = False
    return params

def fetch_general_summary_docs_via_rest(api_key: str, cfg: dict):
    query = build_general_mp_query_from_config(cfg)
    base_params = build_general_mp_rest_params(query)
    url = "https://api.materialsproject.org/materials/summary/"
    headers = {"X-API-KEY": api_key}
    docs = []
    seen = set()

    num_chunks = int(query.get("num_chunks", 2))
    chunk_size = int(query.get("chunk_size", 250))

    for chunk_idx in range(num_chunks):
        params = dict(base_params)
        params["_limit"] = chunk_size
        params["_skip"] = chunk_idx * chunk_size

        response = requests.get(url, params=params, headers=headers, timeout=60)
        response.raise_for_status()
        payload = response.json()
        page_docs = payload.get("data", [])
        if not page_docs:
            break

        for doc in page_docs:
            mpid = str(mp_doc_get(doc, "material_id", ""))
            if mpid in seen:
                continue
            seen.add(mpid)
            docs.append(doc)

        if len(page_docs) < chunk_size:
            break

    query = dict(query)
    query["transport"] = "rest"
    return docs, query

def is_teaching_friendly_structure(structure: Structure) -> bool:
    elements = getattr(structure.composition, "elements", [])
    if any(getattr(el, "is_radioactive", False) for el in elements):
        return False
    if any(getattr(el, "is_noble_gas", False) for el in elements):
        return False
    return True

def build_synthetic_records(cfg: dict, seed: int = SEED):
    rng = np.random.default_rng(seed)
    prototypes = [
        ("NaCl", Lattice.cubic(5.64), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]], 5.0),
        ("CsCl", Lattice.cubic(4.12), ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]], 4.4),
        ("Si", Lattice.cubic(5.43), ["Si", "Si"], [[0, 0, 0], [0.25, 0.25, 0.25]], 1.1),
        ("ZnS", Lattice.cubic(5.41), ["Zn", "S"], [[0, 0, 0], [0.25, 0.25, 0.25]], 3.2),
        ("MgO", Lattice.cubic(4.21), ["Mg", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]], 7.4),
        ("SrTiO3", Lattice.cubic(3.905), ["Sr", "Ti", "O", "O", "O"], [[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], 2.1),
        ("CaTiO3", Lattice.orthorhombic(5.44, 7.64, 5.38), ["Ca", "Ti", "O", "O", "O"], [[0, 0, 0], [0.5, 0.5, 0.5], [0.25, 0.25, 0], [0.25, 0, 0.25], [0, 0.25, 0.25]], 3.0),
        ("LiFePO4", Lattice.orthorhombic(10.33, 6.01, 4.69), ["Li", "Fe", "P", "O", "O", "O", "O", "O", "O", "O"], [[0.1, 0.1, 0.1], [0.3, 0.3, 0.3], [0.5, 0.5, 0.5], [0.05, 0.15, 0.25], [0.15, 0.25, 0.35], [0.25, 0.35, 0.45], [0.35, 0.45, 0.55], [0.45, 0.55, 0.65], [0.55, 0.65, 0.75], [0.65, 0.75, 0.85]], 3.4),
        ("Na2CO3", Lattice.monoclinic(7.1, 5.8, 6.4, 110), ["Na", "Na", "C", "O", "O", "O"], [[0, 0, 0], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25], [0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5]], 5.8),
        ("BaTiO3", Lattice.cubic(4.00), ["Ba", "Ti", "O", "O", "O"], [[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]], 2.3),
    ]

    target = int(cfg.get("max_structures", 24))
    min_el = int(cfg.get("min_num_elements", 2))
    max_el = int(cfg.get("max_num_elements", 5))
    max_atoms = int(cfg.get("max_atoms", 16))

    records = []
    skipped = {"synthetic": 0}
    i = 0
    while len(records) < target:
        name, lattice, species, frac_coords, base_gap = prototypes[i % len(prototypes)]
        scale = float(rng.uniform(0.96, 1.08))
        structure = Structure(Lattice(lattice.matrix * scale), species, frac_coords)
        num_elements = len({str(el) for el in structure.composition.elements})
        if not (min_el <= num_elements <= max_el) or len(structure) > max_atoms:
            i += 1
            skipped["synthetic"] += 1
            continue
        try:
            sga = SpacegroupAnalyzer(structure, symprec=0.15)
            sg_num = int(sga.get_space_group_number())
            sg_symbol = str(sga.get_space_group_symbol())
        except Exception:
            sg_num = 1
            sg_symbol = "P1"
        records.append({
            "name": f"{name}_{len(records):03d}",
            "formula": structure.composition.reduced_formula,
            "chemsys": "-".join(sorted({str(el) for el in structure.composition.elements})),
            "structure": structure,
            "atomic_numbers": [int(site.specie.Z) for site in structure],
            "num_atoms": int(len(structure)),
            "frac0": wrap_centered_np(np.asarray(structure.frac_coords, dtype=np.float32)).astype(np.float32),
            "lattice0": lattice_to_features(structure).astype(np.float32),
            "density": float(structure.density),
            "band_gap": float(max(0.0, base_gap + rng.normal(0.0, 0.2))),
            "spacegroup_number": sg_num,
            "spacegroup_symbol": sg_symbol,
        })
        i += 1

    return records, skipped, {"source": "synthetic_last_resort", "target_size": target}

def fetch_general_summary_docs(api_key: str, cfg: dict):
    query = build_general_mp_query_from_config(cfg)

    if MPRester is not None:
        try:
            docs = []
            seen = set()
            with MPRester(api_key) as mpr:
                for doc in mpr.materials.summary.search(**query):
                    mpid = str(mp_doc_get(doc, "material_id", ""))
                    if mpid in seen:
                        continue
                    docs.append(doc)
                    seen.add(mpid)
            query = dict(query)
            query["transport"] = "mp_api"
            return docs, query
        except Exception as exc:
            print(
                "`mp_api` query failed in this environment; retrying via direct REST requests to Materials Project."
            )
            print(f"Reason: {exc.__class__.__name__}: {exc}")

    return fetch_general_summary_docs_via_rest(api_key=api_key, cfg=cfg)

def select_evenly_spaced_records(records, target_size: int):
    records = list(records)
    if len(records) <= target_size:
        return records
    records = sorted(records, key=lambda r: (float(r["density"]), float(r["band_gap"]), r["formula"]))
    idxs = np.linspace(0, len(records) - 1, target_size).round().astype(int).tolist()
    # preserve order but drop duplicate indices
    dedup = []
    seen = set()
    for idx in idxs:
        if idx not in seen:
            dedup.append(records[idx])
            seen.add(idx)
    return dedup

def cache_payload_source_label(cache_payload, records) -> str:
    if isinstance(cache_payload, dict):
        mp_meta = cache_payload.get("mp_query_used", {}) or {}
        if mp_meta.get("source"):
            return str(mp_meta["source"])
    if records:
        record_source = records[0].get("source")
        if record_source:
            return str(record_source)
    return "unknown"

def cache_payload_is_real_mp(cache_payload, records) -> bool:
    source = cache_payload_source_label(cache_payload, records).lower()
    return "materials_project" in source or source in {"bundled_materials_project_fallback", "materials_project_live_query"}

def ensure_record_feature_fields(records):
    if not records:
        return records
    for r in records:
        structure = r.get("structure")
        if structure is None:
            continue
        if "lattice0" not in r:
            r["lattice0"] = lattice_to_features(structure).astype(np.float32)
        if "frac0" not in r:
            r["frac0"] = wrap_centered_np(np.asarray(structure.frac_coords, dtype=np.float32)).astype(np.float32)
        if "num_atoms" not in r:
            r["num_atoms"] = int(len(structure))
        if "formula" not in r:
            r["formula"] = structure.composition.reduced_formula
        if "chemsys" not in r:
            r["chemsys"] = "-".join(sorted({el.symbol for el in structure.composition.elements}))
        if "num_elements" not in r:
            r["num_elements"] = int(len(structure.composition.elements))
        if "spacegroup_number" not in r or "spacegroup_symbol" not in r:
            try:
                sga = SpacegroupAnalyzer(structure, symprec=0.15)
                r.setdefault("spacegroup_number", int(sga.get_space_group_number()))
                r.setdefault("spacegroup_symbol", str(sga.get_space_group_symbol()))
            except Exception:
                r.setdefault("spacegroup_number", 1)
                r.setdefault("spacegroup_symbol", "P1")
    return records

def filter_records_for_config(records, cfg: dict):
    min_el = int(cfg.get("min_num_elements", 2))
    max_el = int(cfg.get("max_num_elements", 5))
    max_atoms = int(cfg.get("max_atoms", 20))
    target = int(cfg.get("max_structures", len(records)))

    filtered = []
    for r in records:
        structure = r.get("structure")
        num_atoms = int(r.get("num_atoms", len(structure) if structure is not None else 0))
        num_elements = int(r.get("num_elements", len(structure.composition.elements) if structure is not None else 0))
        if structure is not None and not is_teaching_friendly_structure(structure):
            continue
        if num_atoms > max_atoms:
            continue
        if num_elements < min_el or num_elements > max_el:
            continue
        filtered.append(r)

    return select_evenly_spaced_records(filtered, min(target, len(filtered)))

def load_bundled_fallback_records(cfg: dict):
    if not BUNDLED_MP_FALLBACK_PATH.exists():
        raise FileNotFoundError(f"Bundled fallback dataset is missing: {BUNDLED_MP_FALLBACK_PATH}")

    payload = loadfn(BUNDLED_MP_FALLBACK_PATH)
    if isinstance(payload, dict) and "records" in payload:
        raw_records = payload["records"]
        skipped = payload.get("skipped", {})
        meta = dict(payload.get("mp_query_used", {}) or {})
    else:
        raw_records = payload
        skipped = {}
        meta = {}

    raw_records = ensure_record_feature_fields(list(raw_records))
    records = filter_records_for_config(raw_records, cfg)
    meta["source"] = "bundled_materials_project_fallback"
    meta["bundled_fallback_path"] = BUNDLED_MP_FALLBACK_PATH.name
    meta["bundled_record_count"] = len(raw_records)
    meta["adapted_record_count"] = len(records)
    return records, skipped, meta

def probe_materials_project_connection(api_key: str):
    if MPRester is not None:
        try:
            with MPRester(api_key) as mpr:
                docs = list(
                    mpr.materials.summary.search(
                        material_ids=["mp-149"],
                        all_fields=False,
                        fields=["material_id"],
                    )
                )
            return len(docs)
        except Exception:
            pass

    response = requests.get(
        "https://api.materialsproject.org/materials/summary/",
        params={
            "material_ids": "mp-149",
            "_fields": "material_id",
            "_all_fields": False,
            "_limit": 1,
        },
        headers={"X-API-KEY": api_key},
        timeout=30,
    )
    response.raise_for_status()
    payload = response.json()
    return len(payload.get("data", []))

def curate_general_summary_docs(docs, cfg: dict, seed: int = 7):
    rng = random.Random(seed)
    docs = list(docs)
    rng.shuffle(docs)

    records = []
    skipped = {
        "missing_structure": 0,
        "missing_band_gap": 0,
        "parse_error": 0,
        "too_many_atoms": 0,
        "too_few_elements": 0,
        "too_many_elements": 0,
        "unfriendly_chemistry": 0,
    }

    min_el = int(cfg["min_num_elements"])
    max_el = int(cfg["max_num_elements"])
    max_atoms = int(cfg["max_atoms"])

    for doc in tqdm(docs, desc="Curating MP structures"):
        try:
            structure = structure_from_mp_doc(doc)
            if structure is None:
                skipped["missing_structure"] += 1
                continue

            structure = structure.copy()
            try:
                structure = structure.get_primitive_structure()
            except Exception:
                pass

            if len(structure) > max_atoms:
                skipped["too_many_atoms"] += 1
                continue

            num_elements = len(structure.composition.elements)
            if num_elements < min_el:
                skipped["too_few_elements"] += 1
                continue
            if num_elements > max_el:
                skipped["too_many_elements"] += 1
                continue
            if not is_teaching_friendly_structure(structure):
                skipped["unfriendly_chemistry"] += 1
                continue

            band_gap = mp_doc_get(doc, "band_gap", None)
            if band_gap is None or not np.isfinite(float(band_gap)):
                skipped["missing_band_gap"] += 1
                continue

            try:
                sga = SpacegroupAnalyzer(structure, symprec=0.15)
                sg_num = int(sga.get_space_group_number())
                sg_symbol = str(sga.get_space_group_symbol())
            except Exception:
                sg_num = 1
                sg_symbol = "P1"

            atomic_numbers = np.array([site.specie.Z for site in structure], dtype=np.int64)
            frac0 = wrap_centered_np(np.asarray(structure.frac_coords, dtype=np.float32)).astype(np.float32)
            lattice0 = lattice_to_features(structure).astype(np.float32)
            density_value = mp_doc_get(doc, "density", None)
            density = float(density_value if density_value is not None else structure.density)
            band_gap = max(0.0, float(band_gap))
            chemsys = "-".join(sorted({el.symbol for el in structure.composition.elements}))

            records.append(
                {
                    "source": "materials_project",
                    "name": str(mp_doc_get(doc, "material_id", "unknown")),
                    "mpid": str(mp_doc_get(doc, "material_id", "unknown")),
                    "formula": str(mp_doc_get(doc, "formula_pretty", None) or structure.composition.reduced_formula),
                    "anonymous_formula": normalized_structure_anonymous_formula(structure),
                    "chemsys": chemsys,
                    "num_elements": int(num_elements),
                    "num_atoms": int(len(structure)),
                    "atomic_numbers": atomic_numbers,
                    "frac0": frac0,
                    "lattice0": lattice0,
                    "density": density,
                    "band_gap": band_gap,
                    "energy_above_hull": float(mp_doc_get(doc, "energy_above_hull")) if mp_doc_get(doc, "energy_above_hull", None) is not None else np.nan,
                    "spacegroup_number": sg_num,
                    "spacegroup_symbol": sg_symbol,
                    "structure": structure,
                }
            )
        except Exception:
            skipped["parse_error"] += 1
            continue

    records = select_evenly_spaced_records(records, int(cfg["max_structures"]))
    return records, skipped

def wrap_centered_np(frac: np.ndarray) -> np.ndarray:
    return ((frac + 0.5) % 1.0) - 0.5

def wrap_centered_torch(x: torch.Tensor) -> torch.Tensor:
    return torch.remainder(x + 0.5, 1.0) - 0.5

def lattice_to_features(structure: Structure) -> np.ndarray:
    a, b, c = structure.lattice.abc
    alpha, beta, gamma = structure.lattice.angles
    return np.array(
        [
            np.log(a),
            np.log(b),
            np.log(c),
            np.cos(np.deg2rad(alpha)),
            np.cos(np.deg2rad(beta)),
            np.cos(np.deg2rad(gamma)),
        ],
        dtype=np.float32,
    )

def features_to_lattice(features: np.ndarray) -> Lattice:
    features = np.asarray(features, dtype=np.float64)

    log_lengths = np.clip(features[:3], *GEOM_LOG_LENGTH_CLAMP)
    lengths = np.exp(log_lengths)
    lengths = np.clip(lengths, MIN_DECODED_LENGTH, MAX_DECODED_LENGTH)

    cos_angles = np.clip(features[3:], *GEOM_COS_CLAMP)
    angles = np.rad2deg(np.arccos(cos_angles))
    angles = np.clip(angles, MIN_DECODED_ANGLE, MAX_DECODED_ANGLE)

    try:
        lattice = Lattice.from_parameters(*lengths.tolist(), *angles.tolist())
    except Exception:
        side = float(np.clip(np.mean(lengths), MIN_DECODED_LENGTH, MAX_DECODED_LENGTH))
        lattice = Lattice.cubic(side)

    volume = float(getattr(lattice, "volume", np.nan))
    if (not np.isfinite(volume)) or volume <= MIN_DECODED_VOLUME:
        side = float(
            np.clip(
                np.cbrt(max(float(np.prod(lengths)), MIN_DECODED_VOLUME)),
                MIN_DECODED_LENGTH,
                MAX_DECODED_LENGTH,
            )
        )
        lattice = Lattice.cubic(side)

    return lattice

def denormalize_lattice_features_torch(features_norm: torch.Tensor) -> torch.Tensor:
    lat_mean_t = torch.as_tensor(lat_mean, dtype=features_norm.dtype, device=features_norm.device)
    lat_std_t = torch.as_tensor(lat_std, dtype=features_norm.dtype, device=features_norm.device)
    return features_norm * lat_std_t + lat_mean_t

def clamp_lattice_features_norm_torch(features_norm: torch.Tensor) -> torch.Tensor:
    features = denormalize_lattice_features_torch(features_norm).clone()

    features[:, :3] = torch.clamp(
        features[:, :3],
        min=float(np.log(MIN_DECODED_LENGTH)),
        max=float(np.log(MAX_DECODED_LENGTH)),
    )

    cos_min = float(np.cos(np.deg2rad(MAX_DECODED_ANGLE)))
    cos_max = float(np.cos(np.deg2rad(MIN_DECODED_ANGLE)))
    features[:, 3:] = torch.clamp(features[:, 3:], min=cos_min, max=cos_max)

    lat_mean_t = torch.as_tensor(lat_mean, dtype=features_norm.dtype, device=features_norm.device)
    lat_std_t = torch.as_tensor(lat_std, dtype=features_norm.dtype, device=features_norm.device)
    return (features - lat_mean_t) / torch.clamp(lat_std_t, min=1e-8)

def lattice_matrix_from_features_torch(features: torch.Tensor) -> torch.Tensor:
    log_lengths = torch.clamp(features[:, :3], min=GEOM_LOG_LENGTH_CLAMP[0], max=GEOM_LOG_LENGTH_CLAMP[1])
    lengths = torch.exp(log_lengths)
    cos_angles = torch.clamp(features[:, 3:], min=GEOM_COS_CLAMP[0], max=GEOM_COS_CLAMP[1])
    angles = torch.arccos(cos_angles)

    a, b, c = lengths[:, 0], lengths[:, 1], lengths[:, 2]
    alpha, beta, gamma = angles[:, 0], angles[:, 1], angles[:, 2]

    va = torch.stack([a, torch.zeros_like(a), torch.zeros_like(a)], dim=-1)
    vb = torch.stack([b * torch.cos(gamma), b * torch.sin(gamma), torch.zeros_like(b)], dim=-1)

    cx = c * torch.cos(beta)
    cy = c * (torch.cos(alpha) - torch.cos(beta) * torch.cos(gamma)) / torch.clamp(torch.sin(gamma), min=1e-6)
    cz_sq = torch.clamp(c**2 - cx**2 - cy**2, min=1e-8)
    vz = torch.stack([cx, cy, torch.sqrt(cz_sq)], dim=-1)

    return torch.stack([va, vb, vz], dim=1)

Choose a small teaching dataset

Keep this part light. The goal is not to perfectly curate materials data; it is to get a clean, varied mini-corpus that makes the later diffusion model easy to study.

We use a single broad query, and the main lever is how many distinct elements a crystal is allowed to contain.

#@title Choose the teaching dataset
# Most users only need the quick/full mode cell above.
# Edit the four variables below only if you want to override that preset manually.

CHEMISTRY_SCOPE = globals().get("CHEMISTRY_SCOPE", "2_to_4_elements")
MAX_STRUCTURES = int(globals().get("MAX_STRUCTURES", 96))
MAX_ATOMS = int(globals().get("MAX_ATOMS", 16))
FORCE_MP_REFRESH = bool(globals().get("FORCE_MP_REFRESH", False))

DATASET_CFG = make_general_dataset_config(
    chemistry_scope=CHEMISTRY_SCOPE,
    max_structures=MAX_STRUCTURES,
    max_atoms=MAX_ATOMS,
    force_refresh=FORCE_MP_REFRESH,
)

MP_CACHE_PATH = dataset_cache_path_from_config(DATASET_CFG)

print("Preset mode:", globals().get("DEMO_MODE", "manual"))
print("Dataset config:")
for k, v in DATASET_CFG.items():
    if "query_" in k:
        continue
    print(f"  {k}: {v}")
print("cache path:", MP_CACHE_PATH)

A good way to explore is:

  • start with DEMO_MODE = "quick" and the default 2-to-4-element corpus,

  • switch to DEMO_MODE = "full" once the notebook is working and you want a stronger denoiser,

  • if you specifically want chemically simpler live-MP runs, then try binaries by setting CHEMISTRY_SCOPE = "2_elements",

  • then compare how the dataset histograms and generated structures change as you broaden the chemistry again.

After each change, look at the dataset histograms and ask: did the property range broaden or narrow?

Fetch, cache, and summarize

This cell does four practical things:

  1. builds one broad MP query,

  2. downloads a manageable pool of summary docs,

  3. curates them into one small corpus,

  4. caches the result locally so reruns are fast.

# @title
# Make this cell safe even if earlier cells were skipped or re-run out of order.
if "Path" not in globals():
    from pathlib import Path

if "DEFAULT_MP_CACHE_DIR" not in globals():
    DEFAULT_MP_CACHE_DIR = Path("mp_curated_cache")
    DEFAULT_MP_CACHE_DIR.mkdir(parents=True, exist_ok=True)

if "DATASET_CFG" not in globals():
    if "make_general_dataset_config" not in globals():
        raise NameError(
            "DATASET_CFG is not defined because the Materials Project helper cell has not been run yet. "
            "Please run the hidden 'Materials Project dataset helpers' cell first."
        )
    DATASET_CFG = make_general_dataset_config()

if "MP_CACHE_PATH" not in globals() or MP_CACHE_PATH is None:
    if "dataset_cache_path_from_config" in globals():
        MP_CACHE_PATH = dataset_cache_path_from_config(DATASET_CFG)
    else:
        MP_CACHE_PATH = DEFAULT_MP_CACHE_DIR / "mp_general_dataset_fallback.json.gz"

FORCE_MP_REFRESH = bool(globals().get("FORCE_MP_REFRESH", DATASET_CFG.get("force_refresh", False)))
MAX_ATOMS = int(DATASET_CFG.get("max_atoms", 16))
MAX_STRUCTURES = int(DATASET_CFG.get("max_structures", 96))

print("Using dataset cache:", MP_CACHE_PATH)
print("Bundled fallback dataset:", BUNDLED_MP_FALLBACK_PATH)

records = None
skipped = {}
mp_query_used = {}
cache_payload = None
cache_source = None
used_cache = False
mp_probe_result = None
mp_probe_error = None

if MP_API_KEY:
    try:
        mp_probe_result = probe_materials_project_connection(MP_API_KEY)
        print(f"Live Materials Project probe succeeded ({mp_probe_result} document(s) returned).")
    except Exception as exc:
        mp_probe_error = exc
        print("Live Materials Project probe failed; the notebook will use cached or bundled data if needed.")
        print(f"Reason: {exc.__class__.__name__}: {exc}")
else:
    print("No API key provided; using cached or bundled Materials Project data only.")

if Path(MP_CACHE_PATH).exists() and not FORCE_MP_REFRESH:
    cache_payload = loadfn(MP_CACHE_PATH)
    if isinstance(cache_payload, dict) and "records" in cache_payload:
        candidate_records = cache_payload["records"]
        candidate_skipped = cache_payload.get("skipped", {})
        candidate_query = cache_payload.get("mp_query_used", {})
    else:
        candidate_records = cache_payload
        candidate_skipped = {}
        candidate_query = {}

    candidate_records = ensure_record_feature_fields(list(candidate_records))
    cache_source = cache_payload_source_label(cache_payload, candidate_records)

    if cache_payload_is_real_mp(cache_payload, candidate_records):
        print(f"Loading cached curated records from {MP_CACHE_PATH} [{cache_source}]")
        records = candidate_records
        skipped = candidate_skipped
        mp_query_used = candidate_query
        used_cache = True
    elif MP_API_KEY:
        print(
            f"Existing cache at {MP_CACHE_PATH} came from {cache_source}; regenerating from live Materials Project because an API key is available."
        )
    else:
        print(
            f"Existing cache at {MP_CACHE_PATH} came from {cache_source}; replacing it with the bundled real Materials Project fallback dataset."
        )

if not used_cache:
    live_fetch_error = None
    if MP_API_KEY:
        try:
            docs, mp_query_used = fetch_general_summary_docs(api_key=MP_API_KEY, cfg=DATASET_CFG)
            print("Fetched raw summary docs:", len(docs))
            records, skipped = curate_general_summary_docs(docs, cfg=DATASET_CFG, seed=SEED)
            mp_query_used = dict(mp_query_used)
            mp_query_used["source"] = "materials_project_live_query"
        except Exception as exc:
            live_fetch_error = exc
            print("Materials Project fetch failed; using the bundled real MP fallback dataset instead.")
            print(f"Reason: {exc.__class__.__name__}: {exc}")

    if records is None:
        try:
            records, skipped, mp_query_used = load_bundled_fallback_records(DATASET_CFG)
            if not MP_API_KEY:
                print("Loaded the bundled real Materials Project fallback dataset.")
            if live_fetch_error is not None:
                mp_query_used = dict(mp_query_used)
                mp_query_used["fallback_reason"] = f"{live_fetch_error.__class__.__name__}: {live_fetch_error}"
            if mp_probe_error is not None:
                mp_query_used = dict(mp_query_used)
                mp_query_used["probe_error"] = f"{mp_probe_error.__class__.__name__}: {mp_probe_error}"
        except Exception as fallback_exc:
            print("Bundled fallback unavailable; building a synthetic last-resort corpus instead.")
            print(f"Reason: {fallback_exc.__class__.__name__}: {fallback_exc}")
            records, skipped, mp_query_used = build_synthetic_records(DATASET_CFG, seed=SEED)
            mp_query_used = dict(mp_query_used)
            mp_query_used["fallback_reason"] = f"{fallback_exc.__class__.__name__}: {fallback_exc}"

    dumpfn(
        {
            "records": records,
            "skipped": skipped,
            "mp_query_used": mp_query_used,
            "dataset_cfg": canonicalize_config_for_cache(DATASET_CFG) if "canonicalize_config_for_cache" in globals() else DATASET_CFG,
        },
        MP_CACHE_PATH,
    )
    print(f"Cached {len(records)} curated records to {MP_CACHE_PATH}")

records = ensure_record_feature_fields(list(records))
print(f"loaded {len(records)} structures")
print("dataset source:", cache_payload_source_label({"mp_query_used": mp_query_used}, records))
if skipped:
    print("skipped counts:", skipped)

if len(records) < 32:
    print(
        "Warning: only",
        len(records),
        "structures were curated. The notebook can still run, but training and sampling may be weaker. "
        "Widen CHEMISTRY_SCOPE or raise MAX_STRUCTURES if you want a larger corpus.",
    )

query_rows = [{"query_key": key, "value": value} for key, value in (mp_query_used or {}).items() if key not in {"fields"}]
if mp_probe_result is not None:
    query_rows.append({"query_key": "mp_probe_result", "value": mp_probe_result})
if query_rows:
    display(pd.DataFrame(query_rows))

summary = pd.DataFrame(
    {
        "name": [r["name"] for r in records],
        "formula": [r["formula"] for r in records],
        "anonymous_formula": [r.get("anonymous_formula", "") for r in records],
        "chemsys": [r["chemsys"] for r in records],
        "num_elements": [r.get("num_elements", np.nan) for r in records],
        "num_atoms": [r["num_atoms"] for r in records],
        "density": [r["density"] for r in records],
        "band_gap": [r.get("band_gap", np.nan) for r in records],
        "energy_above_hull": [r.get("energy_above_hull", np.nan) for r in records],
    }
)
summary.head()

The next cell is a quick sanity check.

What to look for:

  • do you have enough structures to train a tiny demo model?

  • do the formulas and atom counts look reasonable?

  • do the density and band-gap histograms have a decent spread?

A broad histogram is useful here because later we will pick quantile-based low/high targets from these natural training-set distributions.
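To make the quantile idea concrete, here is a minimal sketch of how low/high targets can be read off a property distribution. The density values below are made up for illustration; the notebook itself draws them from the curated dataframe.

```python
import numpy as np

# Hypothetical density values standing in for summary["density"].
densities = np.array([2.1, 3.4, 4.0, 4.8, 5.5, 6.2, 7.9])

# Quantile-based low/high targets stay inside the training distribution,
# so the conditional model is never asked to extrapolate far outside
# what it has actually seen.
low_target = float(np.quantile(densities, 0.10))
high_target = float(np.quantile(densities, 0.90))
print(low_target, high_target)
```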

# @title
display(summary.describe(include="all").transpose())

fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))

axes[0].hist(summary["density"].dropna(), bins=24, alpha=0.7)
axes[0].set_title("Curated density distribution")
axes[0].set_xlabel("density (g/cm³)")

axes[1].hist(summary["band_gap"].dropna(), bins=24, alpha=0.7)
axes[1].set_title("Curated band-gap distribution")
axes[1].set_xlabel("band gap (eV)")

plt.tight_layout()
plt.show()

Property-space map

The one-dimensional histograms are useful, but the joint view matters too.

The panels below show how density and band gap vary together across the curated corpus (with the atom count as a color cue), alongside the crystal-size and space-group distributions.

num_atoms_arr = np.array([int(r["num_atoms"]) for r in records], dtype=np.int32)
fig, axes = plt.subplots(1, 3, figsize=(13.5, 3.8))

sc = axes[0].scatter(
    summary["density"],
    summary["band_gap"],
    c=num_atoms_arr,
    cmap="viridis",
    s=40,
    alpha=0.85,
    edgecolor="white",
    linewidth=0.3,
)
axes[0].set_xlabel("density (g/cm³)")
axes[0].set_ylabel("band gap (eV)")
axes[0].set_title("Density vs band gap")
plt.colorbar(sc, ax=axes[0], label="number of atoms")

axes[1].hist(num_atoms_arr, bins=range(int(num_atoms_arr.min()), int(num_atoms_arr.max()) + 2), alpha=0.8, color="tab:orange")
axes[1].set_xlabel("number of atoms")
axes[1].set_ylabel("count")
axes[1].set_title("Crystal size distribution")

spacegroup_numbers = np.array([int(r.get("spacegroup_number", 0)) for r in records], dtype=np.int32)
spacegroup_counts = pd.Series(spacegroup_numbers).value_counts().sort_index()
axes[2].bar(spacegroup_counts.index.astype(int), spacegroup_counts.values, width=1.0)
axes[2].set_xlabel("space-group number")
axes[2].set_ylabel("count")
axes[2].set_title("Space-group spread")

plt.tight_layout()
plt.show()

Try this: narrow the chemistry scope to binaries and see whether the density-band-gap cloud becomes tighter or looser.

Question: if the cloud collapses too much, what does the diffusion model lose later on?

Answer

A narrower cloud makes the task easier to fit but gives the generator less variety to learn from, so conditional generation can become less interesting and less robust.

Pause and inspect the dataset

Before training anything, ask:

  • Does the chemistry scope look like what you intended?

  • Are the atom counts small enough for Colab?

  • Do the density and band-gap histograms cover enough range to make conditioning interesting?

Try this: switch from 2_to_4_elements to 2_elements and compare the histograms.
Question: does narrowing the chemistry make generation easier, or does it also make the learning signal less diverse?

Answer

Narrowing the chemistry often makes the regression targets and composition statistics easier to fit, but it also weakens the diversity of the learning signal. The model then sees a smaller slice of crystal space, so generation can become cleaner yet less interesting and less transferable.

3) Build a MatterGen-like batch representation

MatterGen does not treat each crystal as a padded dense tensor. Instead, it flattens all atoms across the batch and keeps a vector that says which crystal each atom belongs to.

We copy that idea.

Why this matters

A batch of crystals becomes:

  • a single list of atoms,

  • atom-wise tensors like frac0 and atom_tokens0,

  • crystal-wise tensors like lattice0, scalar conditions, composition conditions, and space-group labels,

  • a batch_idx vector telling us which atoms belong to which crystal.

This is the same trick used by many graph neural networks.

Mathematical picture

If crystal $b$ has $n_b$ atoms, the full batch has

$$N_{\text{atoms}} = \sum_{b=1}^{B} n_b$$

atom nodes.
The model works on those $N_{\text{atoms}}$ nodes, while still knowing which crystal each node came from.
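As a concrete sketch, with made-up atom counts, the flattening and the batch-index bookkeeping look like this (numpy's repeat plays the role of torch.repeat_interleave in the collate function later on):

```python
import numpy as np

# Three hypothetical crystals with 2, 4, and 3 atoms respectively.
num_atoms = np.array([2, 4, 3])

# Total node count: N_atoms = sum over crystals of n_b.
n_total = int(num_atoms.sum())

# batch_idx repeats each crystal index n_b times, so every flattened
# atom row remembers which crystal it came from.
batch_idx = np.repeat(np.arange(len(num_atoms)), num_atoms)
print(n_total, batch_idx)  # 9 [0 0 1 1 1 1 2 2 2]
```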

The conditioning split we use

We treat the conditioning information in three pieces:

$$\mathbf{c}_{\text{scalar}} = [\rho, E_g], \qquad \mathbf{c}_{\text{comp}} \in \mathbb{R}^{K}, \qquad c_{\text{sg}} \in \{1,\dots,230\},$$

where

  • $\rho$ is density,

  • $E_g$ is band gap,

  • $\mathbf{c}_{\text{comp}}$ is a composition-fraction vector over the element vocabulary in our tiny dataset,

  • $c_{\text{sg}}$ is the space-group number.

This is a nice teaching compromise: scalar properties stay simple, while composition and symmetry get their own conditioning channels.

The first short cell builds the token vocabulary for atom types. The second one defines the dataset object and the collate function that produces a flattened CrystalBatch.

# @title
all_atomic_numbers = sorted({int(z) for r in records for z in r["atomic_numbers"]})
z_to_token = {z: i for i, z in enumerate(all_atomic_numbers)}
token_to_z = {i: z for z, i in z_to_token.items()}
NUM_ATOM_CLASSES = len(all_atomic_numbers)
MASK_TOKEN = NUM_ATOM_CLASSES  # absorbing-mask diffusion state

print("number of atom classes:", NUM_ATOM_CLASSES)
print("MASK token id:", MASK_TOKEN)
print("first 20 atomic numbers in the tiny corpus:", all_atomic_numbers[:20])

lat_all = np.stack([r["lattice0"] for r in records], axis=0)
lat_mean = lat_all.mean(axis=0).astype(np.float32)
lat_std = (lat_all.std(axis=0) + 1e-6).astype(np.float32)

CONDITION_NAMES = ["density", "band_gap"]
CONDITION_INDEX = {name: i for i, name in enumerate(CONDITION_NAMES)}
MAX_SPACEGROUP_NUMBER = 230

def infer_spacegroup_info(structure: Structure):
    try:
        sga = SpacegroupAnalyzer(structure, symprec=0.1)
        return int(sga.get_space_group_number()), str(sga.get_space_group_symbol())
    except Exception:
        return 0, "unknown"

def composition_fraction_vector_from_atomic_numbers(atomic_numbers):
    vec = np.zeros(NUM_ATOM_CLASSES, dtype=np.float32)
    atomic_numbers = np.asarray(atomic_numbers, dtype=np.int64)
    unique_z, counts = np.unique(atomic_numbers, return_counts=True)
    total = max(int(counts.sum()), 1)
    for z, count in zip(unique_z.tolist(), counts.tolist()):
        if int(z) in z_to_token:
            vec[z_to_token[int(z)]] = float(count) / float(total)
    return vec

cond_all = np.stack(
    [np.array([r["density"], r["band_gap"]], dtype=np.float32) for r in records],
    axis=0,
)
cond_mean = cond_all.mean(axis=0).astype(np.float32)
cond_std = (cond_all.std(axis=0) + 1e-6).astype(np.float32)

composition_all = []
spacegroup_numbers = []

for r in records:
    r["atom_tokens0"] = np.array([z_to_token[int(z)] for z in r["atomic_numbers"]], dtype=np.int64)
    r["lattice0_norm"] = ((r["lattice0"] - lat_mean) / lat_std).astype(np.float32)

    r["conditions"] = np.array([r["density"], r["band_gap"]], dtype=np.float32)
    r["conditions_norm"] = ((r["conditions"] - cond_mean) / cond_std).astype(np.float32)

    comp_vec = composition_fraction_vector_from_atomic_numbers(r["atomic_numbers"])
    r["composition_cond"] = comp_vec.astype(np.float32)
    composition_all.append(comp_vec)

    sg_num, sg_symbol = infer_spacegroup_info(r["structure"])
    r["spacegroup_number"] = int(np.clip(sg_num, 0, MAX_SPACEGROUP_NUMBER))
    r["spacegroup_symbol"] = sg_symbol
    spacegroup_numbers.append(r["spacegroup_number"])

composition_mean = np.stack(composition_all, axis=0).mean(axis=0).astype(np.float32)
composition_mean = composition_mean / np.clip(composition_mean.sum(), 1e-8, None)

spacegroup_mode = int(pd.Series(spacegroup_numbers).mode().iloc[0]) if len(spacegroup_numbers) else 1
SPACEGROUP_NUM_TO_SYMBOL = {}
for r in records:
    SPACEGROUP_NUM_TO_SYMBOL[int(r["spacegroup_number"])] = str(r["spacegroup_symbol"])
SPACEGROUP_SYMBOL_TO_NUM = {v: k for k, v in SPACEGROUP_NUM_TO_SYMBOL.items()}

rng = np.random.default_rng(SEED)
perm = rng.permutation(len(records))
val_size = min(len(records) - 1, max(8, int(len(records) * VAL_FRACTION))) if len(records) > 1 else 0
val_idx = set(perm[:val_size].tolist())

train_data = [records[i] for i in range(len(records)) if i not in val_idx]
val_data = [records[i] for i in range(len(records)) if i in val_idx]

print("train size:", len(train_data), "| val size:", len(val_data))
print("condition means:", {k: float(v) for k, v in zip(CONDITION_NAMES, cond_mean)})
print("condition stds:", {k: float(v) for k, v in zip(CONDITION_NAMES, cond_std)})
print("default space-group condition:", spacegroup_mode, SPACEGROUP_NUM_TO_SYMBOL.get(spacegroup_mode, "unknown"))

top_comp = np.argsort(-composition_mean)[: min(8, len(composition_mean))]
print(
    "mean composition prior:",
    {Element.from_Z(token_to_z[int(i)]).symbol: float(composition_mean[int(i)]) for i in top_comp if composition_mean[int(i)] > 0}
)
# @title
@dataclass
class CrystalBatch:
    frac0: torch.Tensor                    # [N_atoms_total, 3]
    atom_tokens0: torch.Tensor             # [N_atoms_total]
    lattice0: torch.Tensor                 # [B, 6]
    continuous_conditions: torch.Tensor    # [B, 2] -> [density, band_gap]
    composition_conditions: torch.Tensor   # [B, NUM_ATOM_CLASSES]
    spacegroup_conditions: torch.Tensor    # [B]
    num_atoms: torch.Tensor                # [B]
    batch_idx: torch.Tensor                # [N_atoms_total]
    names: list
    formulas: list

    def to(self, device):
        return CrystalBatch(
            frac0=self.frac0.to(device),
            atom_tokens0=self.atom_tokens0.to(device),
            lattice0=self.lattice0.to(device),
            continuous_conditions=self.continuous_conditions.to(device),
            composition_conditions=self.composition_conditions.to(device),
            spacegroup_conditions=self.spacegroup_conditions.to(device),
            num_atoms=self.num_atoms.to(device),
            batch_idx=self.batch_idx.to(device),
            names=self.names,
            formulas=self.formulas,
        )

class CrystalGraphDataset(Dataset):
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        r = self.items[idx]
        return {
            "name": r["name"],
            "formula": r["formula"],
            "num_atoms": r["num_atoms"],
            "frac0": r["frac0"].copy(),
            "atom_tokens0": r["atom_tokens0"].copy(),
            "lattice0": r["lattice0_norm"].copy(),
            "continuous_conditions": r["conditions_norm"].copy(),
            "composition_conditions": r["composition_cond"].copy(),
            "spacegroup_conditions": int(r["spacegroup_number"]),
        }

def collate_crystals(items):
    num_atoms = torch.tensor([x["num_atoms"] for x in items], dtype=torch.long)
    frac0 = torch.tensor(np.concatenate([x["frac0"] for x in items], axis=0), dtype=torch.float32)
    atom_tokens0 = torch.tensor(np.concatenate([x["atom_tokens0"] for x in items], axis=0), dtype=torch.long)
    lattice0 = torch.tensor(np.stack([x["lattice0"] for x in items], axis=0), dtype=torch.float32)
    continuous_conditions = torch.tensor(
        np.stack([x["continuous_conditions"] for x in items], axis=0),
        dtype=torch.float32,
    )
    composition_conditions = torch.tensor(
        np.stack([x["composition_conditions"] for x in items], axis=0),
        dtype=torch.float32,
    )
    spacegroup_conditions = torch.tensor(
        [x["spacegroup_conditions"] for x in items],
        dtype=torch.long,
    )
    batch_idx = torch.repeat_interleave(torch.arange(len(items), dtype=torch.long), num_atoms)
    return CrystalBatch(
        frac0=frac0,
        atom_tokens0=atom_tokens0,
        lattice0=lattice0,
        continuous_conditions=continuous_conditions,
        composition_conditions=composition_conditions,
        spacegroup_conditions=spacegroup_conditions,
        num_atoms=num_atoms,
        batch_idx=batch_idx,
        names=[x["name"] for x in items],
        formulas=[x["formula"] for x in items],
    )

BATCH_SIZE = int(globals().get("BATCH_SIZE", 12))
train_loader = DataLoader(
    CrystalGraphDataset(train_data),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_crystals,
)
val_loader = DataLoader(
    CrystalGraphDataset(val_data),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_crystals,
)

batch_example = next(iter(train_loader))
print(
    batch_example.frac0.shape,
    batch_example.atom_tokens0.shape,
    batch_example.lattice0.shape,
    batch_example.continuous_conditions.shape,
    batch_example.composition_conditions.shape,
    batch_example.spacegroup_conditions.shape,
)

Try this: print a small batch and verify that atoms from the same crystal share the same batch_idx.
Question: why is this representation more natural than padding every crystal up to the maximum number of atoms?

Answer

The flattened representation keeps every atom as a real node rather than a padded placeholder, so computation scales with the true number of atoms instead of the maximum size in the batch. It also matches graph-message-passing libraries naturally, because batch_idx tells the model which crystal each atom belongs to without wasting memory on padding.
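One practical payoff of batch_idx is per-crystal pooling without any padding. The sketch below uses hypothetical scalar atom features; graph libraries do the same thing with a scatter-mean over the flattened node tensor.

```python
import numpy as np

# Five atoms split across two hypothetical crystals (3 + 2 atoms).
batch_idx = np.array([0, 0, 0, 1, 1])
atom_features = np.array([1.0, 2.0, 3.0, 10.0, 20.0])

# Scatter-style mean pooling: accumulate features into per-crystal
# slots, then divide by each crystal's atom count.
num_crystals = int(batch_idx.max()) + 1
sums = np.zeros(num_crystals)
np.add.at(sums, batch_idx, atom_features)
counts = np.bincount(batch_idx, minlength=num_crystals)
means = sums / counts
print(means)  # [ 2. 15.]
```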

4) Visualize a few real crystals with structure viewers

Before building the model, it is useful to reconnect the tensors to actual structures.

The default path uses static ASE previews because they render reliably in Colab, JupyterLab, and exported notebooks. If you want to rotate the crystals interactively, you can switch the viewer mode to py3Dmol in the small form cell below.

Try this: compare one dense structure and one more open structure. Can you already guess which one might have the higher density just from the packing?

Use the small form cell below to choose between:

  • static: the safest option for Colab, JupyterLab, and exported notebooks,

  • py3Dmol: an interactive 3D viewer with rotation and zoom.

If py3Dmol is unavailable for any reason, the notebook falls back to the static previews automatically.

# @title Viewer settings
STRUCTURE_VIEWER_MODE = "static"  # @param ["static", "py3Dmol"]
PY3DMOL_WIDTH = 360  # @param {type:"integer"}
PY3DMOL_HEIGHT = 300  # @param {type:"integer"}

if STRUCTURE_VIEWER_MODE == "py3Dmol":
    if py3Dmol is None:
        print("py3Dmol is not available in this session, so the notebook will fall back to static ASE previews.")
    else:
        print("Interactive py3Dmol previews are enabled.")
else:
    print("Static ASE previews are enabled.")
# @title
ase_adaptor = AseAtomsAdaptor()

def structure_to_ase_atoms(structure: Structure):
    return ase_adaptor.get_atoms(structure)

def structure_to_cif_string(structure: Structure) -> str:
    return str(CifWriter(structure))

def show_structures_ase(structures, labels=None, columns=2, rotation="20x,30y,0z", dpi=180):
    labels = labels or [f"structure {i}" for i in range(len(structures))]
    if len(structures) == 0:
        raise ValueError("Need at least one structure to preview.")

    columns = max(1, min(int(columns), len(structures)))
    rows = int(math.ceil(len(structures) / columns))
    fig, axes = plt.subplots(
        rows,
        columns,
        figsize=(3.5 * columns, 3.8 * rows),
        squeeze=False,
        facecolor="white",
    )
    axes = axes.ravel()

    for ax in axes[len(structures):]:
        ax.set_axis_off()

    for ax, structure, label in zip(axes, structures, labels):
        ax.set_facecolor("white")
        try:
            atoms = structure_to_ase_atoms(structure)
            plot_atoms(atoms, ax, rotation=rotation, radii=0.35, show_unit_cell=2)
        except Exception as exc:
            ax.text(
                0.5,
                0.5,
                f"Preview failed\n{type(exc).__name__}: {exc}",
                ha="center",
                va="center",
                fontsize=9,
                wrap=True,
            )
        ax.set_title(str(label), fontsize=9, wrap=True, pad=8)
        ax.set_axis_off()

    plt.tight_layout()
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", dpi=dpi, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    return IPythonImage(data=buffer.getvalue(), format="png")

def show_structures_py3dmol(structures, labels=None, columns=2, width=360, height=300):
    if py3Dmol is None:
        raise RuntimeError("py3Dmol is not installed in this Python session.")

    labels = labels or [f"structure {i}" for i in range(len(structures))]
    if len(structures) == 0:
        raise ValueError("Need at least one structure to preview.")

    cards = []
    for structure, label in zip(structures, labels):
        viewer = py3Dmol.view(width=int(width), height=int(height))
        viewer.addModel(structure_to_cif_string(structure), "cif")
        viewer.setStyle({
            "sphere": {"scale": 0.32, "colorscheme": "Jmol"},
            "stick": {"radius": 0.14, "colorscheme": "Jmol"},
        })
        viewer.addUnitCell()
        viewer.setBackgroundColor("white")
        viewer.zoomTo()
        cards.append(
            '<div style="border:1px solid #ddd;border-radius:10px;padding:8px;background:white;">'
            f'<div style="font-size:13px;font-weight:600;margin-bottom:6px;">{html.escape(str(label))}</div>'
            f'{viewer._make_html()}'
            '</div>'
        )

    columns = max(1, min(int(columns), len(structures)))
    return HTML(
        '<div style="display:grid;'
        f'grid-template-columns:repeat({columns}, minmax({int(width)}px, 1fr));'
        'gap:14px;align-items:start;">'
        + ''.join(cards)
        + '</div>'
    )

def show_structures(
    structures,
    labels=None,
    columns=2,
    rotation="20x,30y,0z",
    dpi=180,
    viewer_mode=None,
    py3dmol_width=None,
    py3dmol_height=None,
):
    mode = STRUCTURE_VIEWER_MODE if viewer_mode is None else viewer_mode
    if mode == "py3Dmol":
        if py3Dmol is None:
            print("py3Dmol is unavailable in this session, so the notebook is falling back to static ASE previews.")
        else:
            try:
                return show_structures_py3dmol(
                    structures,
                    labels=labels,
                    columns=columns,
                    width=PY3DMOL_WIDTH if py3dmol_width is None else py3dmol_width,
                    height=PY3DMOL_HEIGHT if py3dmol_height is None else py3dmol_height,
                )
            except Exception as exc:
                print(f"Interactive viewer failed ({type(exc).__name__}: {exc}); falling back to static ASE previews.")

    return show_structures_ase(
        structures,
        labels=labels,
        columns=columns,
        rotation=rotation,
        dpi=dpi,
    )

example_ids = list(range(min(4, len(records))))
real_structures = [records[i]["structure"] for i in example_ids]
real_labels = [
    f'{records[i]["formula"]} | density={records[i]["density"]:.2f} | gap={records[i]["band_gap"]:.2f} eV'
    for i in example_ids
]
display(show_structures(real_structures, real_labels))

5) Define the three forward corruption processes

This is the core conceptual step.

We do not use one generic noise process for everything.
Instead, we use a separate corruption process for each part of the crystal:

  1. fractional coordinates,

  2. lattice parameters,

  3. atom identities.

5.1 Fractional coordinates: wrapped continuous diffusion

Fractional coordinates live on a torus, not on ordinary Euclidean space.
So we corrupt them continuously and then wrap them back into $[-0.5, 0.5)$:

$$\mathbf{x}_t = \operatorname{wrap}\!\left( \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol\epsilon \right)$$
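To make the wrap concrete, here is a minimal numpy sketch of the wrapping map; the notebook's `wrap_centered_torch` plays the same role on tensors, and the numbers below are illustrative:

```python
import numpy as np

def wrap_centered(x):
    # Map any real coordinate into the half-open interval [-0.5, 0.5),
    # mirroring the wrap applied after the Gaussian kick in the forward process.
    return (x + 0.5) % 1.0 - 0.5

x0 = np.array([0.45, -0.48, 0.10])
noisy = x0 + np.array([0.2, -0.1, 0.0])   # pretend Gaussian perturbation
print(wrap_centered(noisy))               # 0.65 wraps to -0.35, -0.58 wraps to 0.42
```

The key point: an atom that drifts past a cell boundary reappears on the other side, so distances stay meaningful on the torus.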

5.2 Lattice: continuous diffusion in a 6D descriptor

We represent the lattice with:

$$(\log a, \log b, \log c, \cos \alpha, \cos \beta, \cos \gamma)$$

and diffuse that continuously.
Why this parameterization? Log-lengths and cosines are easy to normalize, and exponentiating the logs guarantees positive cell lengths after decoding.
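A minimal numpy round trip through this 6D descriptor; `encode` and `decode` here are illustrative stand-ins for the notebook's lattice helpers such as `features_to_lattice`:

```python
import numpy as np

def encode(a, b, c, alpha, beta, gamma):
    # Cell parameters -> (log a, log b, log c, cos alpha, cos beta, cos gamma).
    ang = np.radians([alpha, beta, gamma])
    return np.concatenate([np.log([a, b, c]), np.cos(ang)])

def decode(feat):
    # Exponentiating the first three entries keeps lengths positive even
    # if diffusion noise drives the features negative.
    lengths = np.exp(feat[:3])
    angles = np.degrees(np.arccos(np.clip(feat[3:], -1.0, 1.0)))
    return lengths, angles

feat = encode(4.0, 4.0, 6.5, 90.0, 90.0, 120.0)  # a hexagonal-like cell
lengths, angles = decode(feat)
print(lengths, angles)  # recovers (4, 4, 6.5) and (90, 90, 120)
```

Note the `clip` on the cosines: noised features can leave $[-1, 1]$, and clipping keeps the decoded angles well defined.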

5.3 Atom types: absorbing-mask diffusion

For atom identities, we use a simple discrete process with a special MASK token:

$$q(a_t = a_0 \mid a_0) = \bar\alpha_t, \qquad q(a_t = \text{MASK} \mid a_0) = 1 - \bar\alpha_t.$$

This is a pedagogical stand-in for MatterGen’s discrete corruption logic.
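A quick numpy sanity check of the absorbing process at a single survival probability; the `MASK` id and token values here are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
MASK = -1                                 # illustrative mask token id
a0 = rng.integers(0, 5, size=10000)       # 10k toy atom tokens from 5 classes

# One-shot absorbing corruption: each atom independently keeps its
# identity with probability alpha_bar_t, otherwise it collapses to MASK.
alpha_bar_t = 0.3
keep = rng.random(a0.shape) < alpha_bar_t
a_t = np.where(keep, a0, MASK)
print((a_t == MASK).mean())               # close to 1 - alpha_bar_t = 0.7
```

Unlike Gaussian noise, a masked token destroys identity information completely; the denoiser must infer it back from geometry and context.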

Try this: pick one real crystal and inspect the same state at an early, middle, and late timestep. Which part looks hardest to noise gracefully: positions, lattice, or atom identity?

Read the code below by matching it to the three bullets above. The main things to notice are the schedule definitions and the different corruption rules for continuous versus discrete variables.


T = 1000  # number of diffusion steps

# Continuous schedules, indexed by 1..T, with alpha_bar[0] = 1 for convenience.
cont_betas = torch.linspace(1e-4, 2e-2, T, dtype=torch.float32)
cont_alphas = 1.0 - cont_betas
cont_alpha_bars = torch.cat([torch.ones(1, dtype=torch.float32), torch.cumprod(cont_alphas, dim=0)], dim=0)

# Discrete absorbing-mask schedule, stronger so that x_T is mostly masked.
disc_betas = torch.linspace(0.02, 0.15, T, dtype=torch.float32)
disc_alphas = 1.0 - disc_betas
disc_alpha_bars = torch.cat([torch.ones(1, dtype=torch.float32), torch.cumprod(disc_alphas, dim=0)], dim=0)

def extract_schedule(arr: torch.Tensor, t: torch.Tensor, batch_idx: torch.Tensor | None, target: torch.Tensor) -> torch.Tensor:
    # arr indexed by timestep in 0..T
    out = arr.to(target.device)[t]
    if batch_idx is not None:
        out = out[batch_idx]
    while out.ndim < target.ndim:
        out = out.unsqueeze(-1)
    return out

def q_sample_wrapped_coords(x0: torch.Tensor, t_graph: torch.Tensor, batch_idx: torch.Tensor):
    sqrt_ab = extract_schedule(cont_alpha_bars.sqrt(), t_graph, batch_idx, x0)
    sqrt_one_minus_ab = extract_schedule((1.0 - cont_alpha_bars).sqrt(), t_graph, batch_idx, x0)
    eps = torch.randn_like(x0)
    x_unwrapped = sqrt_ab * x0 + sqrt_one_minus_ab * eps
    x_t = wrap_centered_torch(x_unwrapped)
    # wrapped epsilon target: nearest periodic residual
    eps_target = wrap_centered_torch(x_t - sqrt_ab * x0) / torch.clamp(sqrt_one_minus_ab, min=1e-6)
    return x_t, eps_target

def q_sample_lattice(x0: torch.Tensor, t_graph: torch.Tensor):
    sqrt_ab = extract_schedule(cont_alpha_bars.sqrt(), t_graph, None, x0)
    sqrt_one_minus_ab = extract_schedule((1.0 - cont_alpha_bars).sqrt(), t_graph, None, x0)
    eps = torch.randn_like(x0)
    x_t = sqrt_ab * x0 + sqrt_one_minus_ab * eps
    return x_t, eps

def q_sample_atom_types(x0: torch.Tensor, t_graph: torch.Tensor, batch_idx: torch.Tensor):
    keep_prob = extract_schedule(disc_alpha_bars, t_graph, batch_idx, x0.float()).squeeze(-1)
    keep = (torch.rand_like(keep_prob) < keep_prob)
    x_t = torch.where(keep, x0, torch.full_like(x0, MASK_TOKEN))
    return x_t, keep

def x0_to_structure(frac_centered: np.ndarray, lattice_norm: np.ndarray, atom_tokens: np.ndarray) -> Structure:
    frac = ((frac_centered + 0.5) % 1.0).astype(np.float64)
    lattice_features = lattice_norm * lat_std + lat_mean
    lattice = features_to_lattice(lattice_features)
    species = [Element.from_Z(int(token_to_z[int(tok)])).symbol for tok in atom_tokens]
    return Structure(lattice=lattice, species=species, coords=frac, coords_are_cartesian=False)

Try this: increase or decrease T.
Question: what do you expect to happen if the number of diffusion steps is too small? Too large?

Answer

If T is too small, the forward process never reaches a genuinely high-noise regime, so the reverse model learns a shallow denoising problem and sampling can look biased or brittle. If T is too large, the learning problem becomes unnecessarily long and noisy, which slows training and can make the reverse chain harder to stabilize in a small notebook setting.
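One way to make this quantitative: compute the terminal $\bar\alpha_T$ of the linear schedule defined above for a few step counts. A small numpy sketch; the loop variable is named `steps` so it does not clobber the notebook's global `T`:

```python
import numpy as np

# Terminal alpha_bar for the linear beta schedule beta in [1e-4, 2e-2].
# Small step counts leave a visible fraction of the clean signal at x_T,
# so the forward process never reaches a genuinely high-noise regime.
for steps in (100, 500, 1000):
    betas = np.linspace(1e-4, 2e-2, steps)
    alpha_bar_end = np.prod(1.0 - betas)
    print(steps, alpha_bar_end)
```

At `steps=100` a substantial fraction of the clean signal survives at the final timestep, while at `steps=1000` it is essentially gone, which matches the answer above.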

Visual intuition: corrupting one real crystal in three different ways

The forward process is easier to understand when you look at one structure and perturb it channel by channel.

The next plot compares coordinate noise, lattice noise, and atom-type masking on the same training crystal.

# Use a representative crystal from the training split.
example_record = train_data[0]

@torch.no_grad()
def forward_corruption_triptych(record, timestep: int = 60):
    frac0 = torch.tensor(record["frac0"], dtype=torch.float32, device=device)
    lattice0 = torch.tensor(record["lattice0_norm"][None], dtype=torch.float32, device=device)
    atom0 = torch.tensor(record["atom_tokens0"], dtype=torch.long, device=device)
    batch_idx = torch.zeros(len(atom0), dtype=torch.long, device=device)
    t_graph = torch.tensor([int(timestep)], dtype=torch.long, device=device)

    frac_t, _ = q_sample_wrapped_coords(frac0, t_graph, batch_idx)
    lattice_t, _ = q_sample_lattice(lattice0, t_graph)
    atom_t, _ = q_sample_atom_types(atom0, t_graph, batch_idx)

    clean = x0_to_structure(record["frac0"], record["lattice0_norm"], record["atom_tokens0"])
    coord_only = x0_to_structure(frac_t.detach().cpu().numpy(), record["lattice0_norm"], record["atom_tokens0"])
    lattice_only = x0_to_structure(record["frac0"], lattice_t[0].detach().cpu().numpy(), record["atom_tokens0"])
    atom_tokens = atom_t.detach().cpu().numpy().copy()
    frac_np = frac_t.detach().cpu().numpy()
    lattice_np = lattice_t[0].detach().cpu().numpy()
    valid = atom_tokens != MASK_TOKEN
    if valid.sum() == 0:
        valid = np.zeros_like(atom_tokens, dtype=bool)
        valid[0] = True
        atom_tokens[0] = 0
    masked = x0_to_structure(frac_np[valid], lattice_np, atom_tokens[valid])
    return clean, coord_only, lattice_only, masked

clean, coord_only, lattice_only, masked = forward_corruption_triptych(example_record, timestep=70)
labels = [
    f"clean | {clean.composition.reduced_formula}",
    "coordinate noise",
    "lattice noise",
    "coordinate + lattice + atom masking",
]
display(show_structures([clean, coord_only, lattice_only, masked], labels=labels, columns=2))

Question: which corruption do you expect to be hardest to undo, and why?

Answer

Atom masking is often hardest because it removes discrete identity information, while coordinates and lattice still leave geometric clues behind.

6) Build a compact MatterGen-like denoiser

The real MatterGen denoiser is stronger than ours, but the interface is the same: the network sees a noised crystal and predicts

  • coordinate noise,

  • lattice noise,

  • atom logits.

Inputs to the denoiser

For each atom, the model sees:

  • the current noisy fractional coordinate,

  • the current noisy atom token,

  • the current noisy lattice (broadcast from crystal level to atom level),

  • the timestep embedding,

  • the number of atoms in the crystal,

  • optionally a condition embedding.

Output heads

The model predicts three things:

$$\hat{\epsilon}_{\text{coord}}, \qquad \hat{\epsilon}_{\text{lattice}}, \qquad \hat{\ell}_{\text{atom}}$$

where $\hat{\ell}_{\text{atom}}$ are logits over atom classes.

How we encode the conditions

We now use three small condition encoders:

  1. a scalar-property encoder for density and band gap,

  2. a composition encoder for the element-fraction vector,

  3. a space-group embedding.

Those are fused into one graph-level condition vector before message passing.

Why message passing?

Atoms are not independent. Their local environment matters.
So we build a graph with periodic edges and update atom features by message passing.
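The aggregation rule inside each interaction block reduces to a scatter-add followed by a degree normalization. A minimal numpy sketch with toy shapes, `np.add.at` standing in for torch's `index_add_`:

```python
import numpy as np

# Mean aggregation over incoming edges: messages flow along (src -> dst)
# and each atom averages what it receives. Values here are illustrative.
h = np.array([[1.0], [2.0], [3.0]])   # 3 atoms, 1 feature each
src = np.array([0, 2, 0])             # message senders
dst = np.array([1, 1, 2])             # message receivers
messages = h[src]                     # toy messages: just sender features

agg = np.zeros_like(h)
np.add.at(agg, dst, messages)         # scatter-add, like index_add_
deg = np.zeros((3, 1))
np.add.at(deg, dst, 1.0)              # count incoming edges per atom
agg = agg / np.clip(deg, 1.0, None)   # mean; clip avoids divide-by-zero
print(agg)                            # atom 1 averages atoms 0 and 2
```

In the real block the messages come from an edge MLP over sender, receiver, and geometric edge features, but the scatter-and-normalize skeleton is the same.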

A useful mental model

This denoiser borrows the spirit of a UNet without its encoder/decoder shape: a stack of blocks that repeatedly mixes local geometry, noisy atom identity, lattice context, time information, and optional conditioning.

Try this: halve HIDDEN_DIM and rerun a short training job. What gets worse first: the losses, the samples, or both?

The next cell is intentionally collapsed because it is implementation-heavy. On first pass, you can skip it and just remember:

  1. embed atoms, coordinates, lattice, time, and atom-count information,

  2. build periodic edges,

  3. apply several interaction blocks,

  4. decode three heads.

Open it later if you want the exact message-passing details.

# @title
class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, t_scaled: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=t_scaled.device, dtype=torch.float32) / max(half - 1, 1)
        )
        args = t_scaled.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if emb.shape[-1] < self.dim:
            emb = F.pad(emb, (0, 1))
        return self.mlp(emb)

class InteractionBlock(nn.Module):
    def __init__(self, hidden_dim: int, use_adapter: bool = False):
        super().__init__()
        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2 + 16, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.use_adapter = use_adapter
        if use_adapter:
            self.adapter = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim, bias=False),
            )
            nn.init.zeros_(self.adapter[-1].weight)

    def forward(self, h, src, dst, edge_feat, cond_per_atom=None, use_uncond_per_atom=None):
        m = self.edge_mlp(torch.cat([h[src], h[dst], edge_feat], dim=-1))
        agg = torch.zeros_like(h)
        agg.index_add_(0, dst, m)
        deg = torch.zeros(h.shape[0], 1, device=h.device)
        deg.index_add_(0, dst, torch.ones(dst.shape[0], 1, device=h.device))
        agg = agg / deg.clamp(min=1.0)
        h = h + self.node_mlp(torch.cat([h, agg], dim=-1))
        if self.use_adapter and cond_per_atom is not None and use_uncond_per_atom is not None:
            adapt = self.adapter(torch.cat([h, cond_per_atom], dim=-1))
            h = h + (~use_uncond_per_atom).float() * adapt
        return h

class MiniMatterGenDenoiser(nn.Module):
    def __init__(
        self,
        num_atom_classes: int,
        hidden_dim: int = 128,
        num_blocks: int = 4,
        max_atoms: int = MAX_ATOMS,
        conditional: bool = False,
        cond_dim: int = len(CONDITION_NAMES),
    ):
        super().__init__()
        self.num_atom_classes = num_atom_classes
        self.conditional = conditional
        self.hidden_dim = hidden_dim
        self.cond_dim = cond_dim

        self.atom_embed = nn.Embedding(num_atom_classes + 1, hidden_dim)
        self.coord_proj = nn.Linear(3, hidden_dim)
        self.lattice_proj = nn.Linear(6, hidden_dim)
        self.time_embed = SinusoidalTimeEmbedding(hidden_dim)
        self.num_atoms_embed = nn.Embedding(max_atoms + 1, hidden_dim)

        if conditional:
            self.cond_scalar_embed = nn.Sequential(
                nn.Linear(cond_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.cond_comp_embed = nn.Sequential(
                nn.Linear(num_atom_classes, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.cond_spacegroup_embed = nn.Embedding(MAX_SPACEGROUP_NUMBER + 1, hidden_dim)
            self.cond_fuse = nn.Sequential(
                nn.Linear(hidden_dim * 3, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.uncond_embedding = nn.Parameter(torch.zeros(1, hidden_dim))

        self.blocks = nn.ModuleList(
            [InteractionBlock(hidden_dim, use_adapter=conditional) for _ in range(num_blocks)]
        )

        self.coord_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 3),
        )
        self.atom_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, num_atom_classes),
        )
        self.lattice_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 6),
        )

    def build_periodic_edges(self, frac_t: torch.Tensor, lattice_t: torch.Tensor, num_atoms: torch.Tensor):
        src_all, dst_all, edge_all = [], [], []
        offset = 0

        for g, n in enumerate(num_atoms.tolist()):
            if n <= 1:
                offset += n
                continue

            frac_g = frac_t[offset : offset + n]
            lattice_g = lattice_t[g]

            idx = torch.arange(n, device=frac_t.device)
            ii, jj = torch.meshgrid(idx, idx, indexing="ij")
            mask = ii != jj
            ii = ii[mask]
            jj = jj[mask]

            dfrac = wrap_centered_torch(frac_g[jj] - frac_g[ii])
            dcart = torch.matmul(dfrac, lattice_g)
            dcart = torch.clamp(dcart, min=-GEOM_MAX_DIST, max=GEOM_MAX_DIST)
            dist = torch.norm(dcart, dim=-1, keepdim=True).clamp(max=GEOM_MAX_DIST)

            edge_feat = torch.cat(
                [
                    dfrac,
                    dcart,
                    dist,
                    (dist**2).clamp(max=GEOM_MAX_DIST**2),
                    torch.sin(dist),
                    torch.cos(dist),
                    torch.exp(-dist),
                ],
                dim=-1,
            )
            if edge_feat.shape[-1] < 16:
                edge_feat = F.pad(edge_feat, (0, 16 - edge_feat.shape[-1]))
            edge_feat = edge_feat[:, :16]

            src_all.append(ii + offset)
            dst_all.append(jj + offset)
            edge_all.append(edge_feat)
            offset += n

        if len(src_all) == 0:
            src = torch.zeros(0, dtype=torch.long, device=frac_t.device)
            dst = torch.zeros(0, dtype=torch.long, device=frac_t.device)
            edge_feat = torch.zeros(0, 16, dtype=torch.float32, device=frac_t.device)
        else:
            src = torch.cat(src_all, dim=0)
            dst = torch.cat(dst_all, dim=0)
            edge_feat = torch.cat(edge_all, dim=0)

        return src, dst, edge_feat

    def forward(
        self,
        frac_t: torch.Tensor,
        lattice_t: torch.Tensor,
        atom_t: torch.Tensor,
        num_atoms: torch.Tensor,
        batch_idx: torch.Tensor,
        t_graph: torch.Tensor,
        continuous_cond: torch.Tensor | None = None,
        composition_cond: torch.Tensor | None = None,
        spacegroup_cond: torch.Tensor | None = None,
        use_uncond_embedding: torch.Tensor | None = None,
    ):
        B = num_atoms.shape[0]
        t_scaled = t_graph.float() / float(T)

        lattice_features_t = denormalize_lattice_features_torch(lattice_t)
        lattice_matrix_t = lattice_matrix_from_features_torch(lattice_features_t)

        graph_ctx = (
            self.time_embed(t_scaled)
            + self.lattice_proj(lattice_t)
            + self.num_atoms_embed(num_atoms)
        )

        cond_per_graph = None
        cond_per_atom = None
        use_uncond_per_atom = None

        if self.conditional:
            if continuous_cond is None or composition_cond is None or spacegroup_cond is None:
                cond_per_graph = self.uncond_embedding.expand(B, -1)
                use_uncond_embedding = torch.ones(B, 1, dtype=torch.bool, device=frac_t.device)
            else:
                if use_uncond_embedding is None:
                    use_uncond_embedding = torch.zeros(B, 1, dtype=torch.bool, device=frac_t.device)
                scalar_embed = self.cond_scalar_embed(continuous_cond)
                comp_embed = self.cond_comp_embed(composition_cond)
                sg_embed = self.cond_spacegroup_embed(
                    torch.clamp(spacegroup_cond.long(), min=0, max=MAX_SPACEGROUP_NUMBER)
                )
                cond_embed = self.cond_fuse(torch.cat([scalar_embed, comp_embed, sg_embed], dim=-1))
                uncond_embed = self.uncond_embedding.expand(B, -1)
                cond_per_graph = torch.where(use_uncond_embedding, uncond_embed, cond_embed)

            graph_ctx = graph_ctx + cond_per_graph
            cond_per_atom = cond_per_graph[batch_idx]
            use_uncond_per_atom = use_uncond_embedding[batch_idx]

        h = self.atom_embed(atom_t) + self.coord_proj(frac_t) + graph_ctx[batch_idx]

        src, dst, edge_feat = self.build_periodic_edges(frac_t, lattice_matrix_t, num_atoms)
        for block in self.blocks:
            h = block(h, src, dst, edge_feat, cond_per_atom=cond_per_atom, use_uncond_per_atom=use_uncond_per_atom)

        pooled = torch.zeros(B, self.hidden_dim, device=h.device)
        pooled.index_add_(0, batch_idx, h)
        pooled = pooled / num_atoms.unsqueeze(-1)

        pred_coord_eps = self.coord_head(h)
        pred_atom_logits = self.atom_head(h)
        pred_lattice_eps = self.lattice_head(torch.cat([pooled, graph_ctx], dim=-1))
        return pred_coord_eps, pred_lattice_eps, pred_atom_logits

This short visible cell sets the model size and instantiates the unconditional base model. These are the easiest architectural knobs to play with.

HIDDEN_DIM = int(globals().get("HIDDEN_DIM", 256))
NUM_BLOCKS = int(globals().get("NUM_BLOCKS", 8))

base_model = MiniMatterGenDenoiser(
    NUM_ATOM_CLASSES,
    hidden_dim=HIDDEN_DIM,
    num_blocks=NUM_BLOCKS,
    conditional=False,
).to(device)

num_params = sum(p.numel() for p in base_model.parameters())
print(f"base model parameters: {num_params:,}")
print(f"Using HIDDEN_DIM={HIDDEN_DIM}, NUM_BLOCKS={NUM_BLOCKS}")

Try this: change HIDDEN_DIM or NUM_BLOCKS.
Question: which do you think matters more in this notebook: wider layers or more message-passing steps?

Answer

In a notebook-scale crystal model, wider layers usually matter first because they control how much geometric and chemical information can be stored at each atom. More message-passing steps help too, but if the hidden state is too small, extra propagation just moves around a weak signal.

7) Define the multi-part training objective

At a random timestep $t$, we corrupt a clean crystal and train the denoiser to undo the corruption.

Loss

We use a weighted sum of three terms:

$$\mathcal{L} = \lambda_{\text{coord}} \, \mathcal{L}_{\text{coord}} + \lambda_{\text{lat}} \, \mathcal{L}_{\text{lat}} + \lambda_{\text{atom}} \, \mathcal{L}_{\text{atom}}.$$

For coordinates and lattice we use a robust Smooth L1 loss:

$$\mathcal{L}_{\text{coord}} = \operatorname{SmoothL1}\!\left(\hat{\epsilon}_{\text{coord}}, \epsilon_{\text{coord}}\right), \qquad \mathcal{L}_{\text{lat}} = \operatorname{SmoothL1}\!\left(\hat{\epsilon}_{\text{lat}}, \epsilon_{\text{lat}}\right),$$

and for atoms we use cross-entropy on the clean atom identity:

$$\mathcal{L}_{\text{atom}} = \operatorname{CE}\!\left(\hat{\ell}_{\text{atom}}, a_0\right).$$

Why Smooth L1 here?

Small toy crystal datasets can produce noisy batches. Smooth L1 makes the continuous losses less brittle than plain MSE.
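A quick numeric comparison of the two penalties on one outlier residual (a sketch; the `beta=0.25` here matches the `ROBUST_BETA` used in the training cell):

```python
import numpy as np

def smooth_l1(r, beta=0.25):
    # Quadratic near zero, linear past the beta threshold, so a single
    # large residual cannot dominate a batch the way it does under MSE.
    r = np.abs(r)
    return np.where(r < beta, 0.5 * r**2 / beta, r - 0.5 * beta)

residuals = np.array([0.05, 0.1, 3.0])   # one outlier residual
print(smooth_l1(residuals))              # outlier term grows like |r|
print(0.5 * residuals**2)                # squared term grows like r^2
```

On the outlier, the squared penalty is larger than the Smooth L1 penalty, which is exactly the robustness the prose above describes.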

Try this: double ATOM_LOSS_WEIGHT for one run. Do you get cleaner compositions at the cost of worse geometry?

The short cell below contains the main user-facing training knobs. The long helper cell after it handles evaluation, EMA-smoothed logging, early stopping, and checkpointing.

P_UNCOND = 0.2
COORD_LOSS_WEIGHT = 1.0
LATTICE_LOSS_WEIGHT = 1.0
ATOM_LOSS_WEIGHT = 1.5

ROBUST_BETA = 0.25
GRAD_CLIP_NORM = 0.5
CHECKPOINT_DIR = Path("mattergen_checkpoints")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
# @title
def compute_training_losses(model: nn.Module, batch: CrystalBatch, training: bool = True):
    batch = batch.to(device)
    B = batch.num_atoms.shape[0]
    t_graph = torch.randint(1, T + 1, (B,), device=device)

    frac_t, coord_eps_target = q_sample_wrapped_coords(batch.frac0, t_graph, batch.batch_idx)
    lattice_t, lattice_eps_target = q_sample_lattice(batch.lattice0, t_graph)
    atom_t, keep_mask = q_sample_atom_types(batch.atom_tokens0, t_graph, batch.batch_idx)

    continuous_cond = None
    composition_cond = None
    spacegroup_cond = None
    use_uncond = None

    if getattr(model, "conditional", False):
        continuous_cond = batch.continuous_conditions
        composition_cond = batch.composition_conditions
        spacegroup_cond = batch.spacegroup_conditions
        if training:
            use_uncond = (torch.rand(B, 1, device=device) < P_UNCOND)
        else:
            use_uncond = torch.zeros(B, 1, device=device, dtype=torch.bool)

    pred_coord_eps, pred_lattice_eps, pred_atom_logits = model(
        frac_t=frac_t,
        lattice_t=lattice_t,
        atom_t=atom_t,
        num_atoms=batch.num_atoms,
        batch_idx=batch.batch_idx,
        t_graph=t_graph,
        continuous_cond=continuous_cond,
        composition_cond=composition_cond,
        spacegroup_cond=spacegroup_cond,
        use_uncond_embedding=use_uncond,
    )

    coord_loss = F.smooth_l1_loss(pred_coord_eps, coord_eps_target, beta=ROBUST_BETA)
    lattice_loss = F.smooth_l1_loss(pred_lattice_eps, lattice_eps_target, beta=ROBUST_BETA)
    atom_loss = F.cross_entropy(pred_atom_logits, batch.atom_tokens0)

    total = (
        COORD_LOSS_WEIGHT * coord_loss
        + LATTICE_LOSS_WEIGHT * lattice_loss
        + ATOM_LOSS_WEIGHT * atom_loss
    )
    metrics = {
        "loss": float(total.detach().cpu()),
        "coord_loss": float(coord_loss.detach().cpu()),
        "lattice_loss": float(lattice_loss.detach().cpu()),
        "atom_loss": float(atom_loss.detach().cpu()),
        "masked_atom_fraction": float((atom_t == MASK_TOKEN).float().mean().detach().cpu()),
    }
    return total, metrics

@torch.no_grad()
def evaluate_model(model: nn.Module, loader: DataLoader):
    model.eval()
    logs = []
    for batch in loader:
        _, metrics = compute_training_losses(model, batch, training=False)
        logs.append(metrics)
    return pd.DataFrame(logs).mean().to_dict()

def save_training_checkpoint(
    path: Path,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    epoch: int,
    history_rows,
    run_name: str,
    best_val_loss: float,
    best_epoch: int,
):
    np_state = np.random.get_state()
    payload = {
        "run_name": run_name,
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict() if scheduler is not None else None,
        "history": history_rows,
        "best_val_loss": best_val_loss,
        "best_epoch": best_epoch,
        "rng_state_torch": torch.get_rng_state(),
        "rng_state_numpy": {
            "bit_generator": np_state[0],
            "state": np_state[1].tolist(),
            "pos": int(np_state[2]),
            "has_gauss": int(np_state[3]),
            "cached_gaussian": float(np_state[4]),
        },
        "rng_state_python": random.getstate(),
    }
    torch.save(payload, path)

def restore_rng_states_from_checkpoint(ckpt: dict):
    if ckpt.get("rng_state_torch") is not None:
        torch.set_rng_state(ckpt["rng_state_torch"])

    np_state = ckpt.get("rng_state_numpy")
    if isinstance(np_state, dict) and "state" in np_state:
        np.random.set_state((
            np_state["bit_generator"],
            np.array(np_state["state"], dtype=np.uint32),
            int(np_state["pos"]),
            int(np_state["has_gauss"]),
            float(np_state["cached_gaussian"]),
        ))

    py_state = ckpt.get("rng_state_python")
    if py_state is not None:
        random.setstate(py_state)

def load_model_checkpoint(
    model: nn.Module,
    path: str | Path,
    optimizer=None,
    scheduler=None,
    map_location=None,
    restore_rng_state: bool = False,
):
    ckpt = torch.load(path, map_location=map_location or device, weights_only=False)
    model.load_state_dict(ckpt["model_state"])
    if optimizer is not None and ckpt.get("optimizer_state") is not None:
        optimizer.load_state_dict(ckpt["optimizer_state"])
    if scheduler is not None and ckpt.get("scheduler_state") is not None:
        scheduler.load_state_dict(ckpt["scheduler_state"])
    if restore_rng_state:
        restore_rng_states_from_checkpoint(ckpt)
    return ckpt

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int,
    lr: float,
    run_name: str,
    weight_decay: float = 1e-4,
    warmup_frac: float = 0.1,
    ema_decay: float = 0.95,
    patience: int | None = 6,
    min_delta: float = 1e-3,
    checkpoint_dir: str | Path = CHECKPOINT_DIR,
    restore_best_at_end: bool = True,
):
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    total_steps = max(1, epochs * len(train_loader))
    warmup_steps = max(1, int(total_steps * warmup_frac))

    def lr_lambda(step: int):
        if step < warmup_steps:
            return float(step + 1) / float(warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    history = []
    global_step = 0
    ema_loss = None
    best_val_loss = float("inf")
    best_epoch = 0
    epochs_since_improvement = 0
    stopped_early = False

    last_ckpt_path = checkpoint_dir / f"{run_name}_last.pt"
    best_ckpt_path = checkpoint_dir / f"{run_name}_best.pt"

    for epoch in range(1, epochs + 1):
        model.train()
        train_logs = []
        epoch_ema_values = []

        pbar = tqdm(train_loader, desc=f"epoch {epoch}/{epochs}")
        for batch in pbar:
            optimizer.zero_grad(set_to_none=True)
            loss, metrics = compute_training_losses(model, batch, training=True)

            if not torch.isfinite(loss):
                print(f"Skipping non-finite batch at epoch {epoch}")
                continue

            loss.backward()
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRAD_CLIP_NORM)
            optimizer.step()
            scheduler.step()
            global_step += 1

            batch_loss = float(metrics["loss"])
            ema_loss = batch_loss if ema_loss is None else (ema_decay * ema_loss + (1.0 - ema_decay) * batch_loss)
            epoch_ema_values.append(float(ema_loss))

            metrics["ema_loss"] = float(ema_loss)
            metrics["grad_norm"] = float(grad_norm.detach().cpu()) if torch.is_tensor(grad_norm) else float(grad_norm)
            metrics["lr"] = float(optimizer.param_groups[0]["lr"])
            train_logs.append(metrics)

            pbar.set_postfix(
                loss=f'{metrics["loss"]:.3f}',
                ema=f'{metrics["ema_loss"]:.3f}',
                atom=f'{metrics["atom_loss"]:.3f}',
                lr=f'{metrics["lr"]:.1e}'
            )

        if len(train_logs) == 0:
            raise RuntimeError("All training batches were skipped. Try lowering the learning rate.")

        train_mean = pd.DataFrame(train_logs).mean().to_dict()
        val_mean = evaluate_model(model, val_loader)

        row = {
            "epoch": epoch,
            "train_loss": train_mean["loss"],
            "train_loss_ema": float(np.mean(epoch_ema_values)),
            "train_coord_loss": train_mean["coord_loss"],
            "train_lattice_loss": train_mean["lattice_loss"],
            "train_atom_loss": train_mean["atom_loss"],
            "train_grad_norm": train_mean.get("grad_norm", float("nan")),
            "train_lr": train_mean.get("lr", float("nan")),
            "val_loss": val_mean["loss"],
            "val_coord_loss": val_mean["coord_loss"],
            "val_lattice_loss": val_mean["lattice_loss"],
            "val_atom_loss": val_mean["atom_loss"],
        }

        improved = row["val_loss"] < (best_val_loss - min_delta)
        row["is_best"] = bool(improved)

        if improved:
            best_val_loss = row["val_loss"]
            best_epoch = epoch
            epochs_since_improvement = 0
            save_training_checkpoint(
                best_ckpt_path,
                model,
                optimizer,
                scheduler,
                epoch=epoch,
                history_rows=history + [row],
                run_name=run_name,
                best_val_loss=best_val_loss,
                best_epoch=best_epoch,
            )
        else:
            epochs_since_improvement += 1

        row["epochs_since_improvement"] = epochs_since_improvement
        history.append(row)

        save_training_checkpoint(
            last_ckpt_path,
            model,
            optimizer,
            scheduler,
            epoch=epoch,
            history_rows=history,
            run_name=run_name,
            best_val_loss=best_val_loss,
            best_epoch=best_epoch,
        )

        print(row)

        if patience is not None and epochs_since_improvement >= patience:
            print(
                f"Early stopping at epoch {epoch}. "
                f"Best val_loss={best_val_loss:.4f} from epoch {best_epoch}."
            )
            stopped_early = True
            break

    if restore_best_at_end and best_ckpt_path.exists():
        _ = load_model_checkpoint(model, best_ckpt_path, map_location=device)

    history_df = pd.DataFrame(history)
    run_info = {
        "run_name": run_name,
        "best_checkpoint": str(best_ckpt_path),
        "last_checkpoint": str(last_ckpt_path),
        "best_val_loss": float(best_val_loss),
        "best_epoch": int(best_epoch),
        "stopped_early": bool(stopped_early),
        "epochs_completed": int(len(history_df)),
    }
    print("run info:", run_info)
    return history_df, run_info

8) Train the unconditional base model

This is the first real training phase.
We train without any property conditioning so the model first learns a generic crystal distribution.

What to watch

  • train loss tells you whether optimization is working,

  • EMA-smoothed train loss is easier to read than the raw batch loss,

  • validation loss tells you whether the model is improving on held-out structures,

  • early stopping prevents wasting time after the curve plateaus.

Good first-run defaults

The quick preset keeps the denoiser, batch size, and epoch budget small enough for a first pass through the full notebook.
The full preset strengthens the model and training budget without changing the overall workflow.
If you manually widen the dataset further, it usually makes sense to raise the epoch cap again.

The cell below launches the full training loop. Read the arguments as “how long to train”, “learning rate”, and “when to stop if validation stops improving.”

BASE_MAX_EPOCHS = int(globals().get("BASE_MAX_EPOCHS", 200))
BASE_PATIENCE = int(globals().get("BASE_PATIENCE", 50))

base_history, base_run = train_model(
    base_model,
    train_loader,
    val_loader,
    epochs=BASE_MAX_EPOCHS,
    lr=1e-4,
    run_name="base_model",
    patience=BASE_PATIENCE,
    ema_decay=0.95,
)
print(base_run)
base_history.tail()

This plot is your first health check. Ideally, the validation loss should go down early and then level off.

plt.figure(figsize=(8, 4))
plt.plot(base_history["epoch"], base_history["train_loss"], alpha=0.35, label="train")
plt.plot(base_history["epoch"], base_history["train_loss_ema"], label="train (EMA)")
plt.plot(base_history["epoch"], base_history["val_loss"], label="val")
plt.axvline(base_run["best_epoch"], linestyle="--", alpha=0.6, label=f'best epoch={base_run["best_epoch"]}')
plt.xlabel("epoch")
plt.ylabel("loss")
plt.title("Unconditional base model learning curve")
plt.legend()
plt.show()

Reflection

If the validation loss barely moves, check the dataset and model size.
If training is unstable, reduce the learning rate or broaden the dataset slightly.

Try this: lower the learning rate to 5e-5 and compare the curve.
Question: why might a tiny, highly filtered dataset make the diffusion problem harder rather than easier?

Answer

A tiny, highly filtered dataset can actually make the problem harder because the model sees too few structural variations to learn robust denoising rules. Instead of learning a broad crystal prior, it can overfit a narrow set of motifs and become fragile when sampling starts from heavily corrupted states.

9) Reverse sampling and classifier-free guidance

Now that the base model is trained, we need the reverse diffusion loop that turns pure noise and masked atom types back into a candidate crystal.

Two uses of the same sampler

  • For the unconditional base model, each reverse step is just a single denoiser call.

  • For the conditional adapter that we train later, we reuse the same loop and add classifier-free guidance.

Guidance rule for the conditional case

Once the conditional model exists, we run the denoiser twice at each step:

  • once without the condition,

  • once with the condition.

Then we combine the predictions as

\hat{y}_{\text{guided}} = \hat{y}_{\text{uncond}} + s \left(\hat{y}_{\text{cond}} - \hat{y}_{\text{uncond}}\right),

where $s$ is the guidance scale.

Interpretation

  • $s = 1$ means “just use the conditional prediction,”

  • larger $s$ pushes the sample harder toward the requested targets,

  • too large a value can produce worse or less stable structures.
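Before the full sampler, here is the guidance rule in isolation as a tiny, self-contained sketch (plain NumPy; the real sampler applies the same formula separately to the coordinate, lattice, and atom-logit predictions):

```python
import numpy as np

def cfg_combine(pred_uncond: np.ndarray, pred_cond: np.ndarray, s: float) -> np.ndarray:
    """Classifier-free guidance: move a distance s along the cond - uncond direction."""
    return pred_uncond + s * (pred_cond - pred_uncond)

uncond = np.array([0.0, 1.0])
cond = np.array([1.0, 3.0])

print(cfg_combine(uncond, cond, 0.0))  # -> [0. 1.]  (ignore the condition)
print(cfg_combine(uncond, cond, 1.0))  # -> [1. 3.]  (just the conditional prediction)
print(cfg_combine(uncond, cond, 2.0))  # -> [2. 5.]  (extrapolate past it)
```

Note that $s > 1$ extrapolates beyond the conditional prediction, which is exactly why very large scales can push samples off the learned data manifold.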

The sampling implementation is a bit long, so it is collapsed below.

Try this later: after training the adapter, sample with guidance scales 0.0, 1.0, 2.0, and 4.0. At what point does stronger guidance stop helping?

# @title
@torch.no_grad()
def guided_predictions(
    model: MiniMatterGenDenoiser,
    frac_t: torch.Tensor,
    lattice_t: torch.Tensor,
    atom_t: torch.Tensor,
    num_atoms: torch.Tensor,
    batch_idx: torch.Tensor,
    t_graph: torch.Tensor,
    continuous_cond: torch.Tensor | None,
    composition_cond: torch.Tensor | None,
    spacegroup_cond: torch.Tensor | None,
    guidance_scale: float,
):
    if not model.conditional or continuous_cond is None or composition_cond is None or spacegroup_cond is None:
        return model(
            frac_t=frac_t,
            lattice_t=lattice_t,
            atom_t=atom_t,
            num_atoms=num_atoms,
            batch_idx=batch_idx,
            t_graph=t_graph,
            continuous_cond=None,
            composition_cond=None,
            spacegroup_cond=None,
            use_uncond_embedding=None,
        )

    B = num_atoms.shape[0]
    uncond_mask = torch.ones(B, 1, dtype=torch.bool, device=frac_t.device)
    cond_mask = torch.zeros(B, 1, dtype=torch.bool, device=frac_t.device)

    pred_coord_u, pred_lattice_u, pred_atom_u = model(
        frac_t=frac_t,
        lattice_t=lattice_t,
        atom_t=atom_t,
        num_atoms=num_atoms,
        batch_idx=batch_idx,
        t_graph=t_graph,
        continuous_cond=continuous_cond,
        composition_cond=composition_cond,
        spacegroup_cond=spacegroup_cond,
        use_uncond_embedding=uncond_mask,
    )
    pred_coord_c, pred_lattice_c, pred_atom_c = model(
        frac_t=frac_t,
        lattice_t=lattice_t,
        atom_t=atom_t,
        num_atoms=num_atoms,
        batch_idx=batch_idx,
        t_graph=t_graph,
        continuous_cond=continuous_cond,
        composition_cond=composition_cond,
        spacegroup_cond=spacegroup_cond,
        use_uncond_embedding=cond_mask,
    )

    pred_coord = pred_coord_u + guidance_scale * (pred_coord_c - pred_coord_u)
    pred_lattice = pred_lattice_u + guidance_scale * (pred_lattice_c - pred_lattice_u)
    pred_atom = pred_atom_u + guidance_scale * (pred_atom_c - pred_atom_u)
    return pred_coord, pred_lattice, pred_atom

def ddpm_reverse_step_continuous(x_t, pred_eps, t_graph, batch_idx=None, wrap=False):
    beta_t = extract_schedule(torch.cat([torch.zeros(1), cont_betas]).to(x_t.device), t_graph, batch_idx, x_t)
    alpha_t = extract_schedule(torch.cat([torch.ones(1), cont_alphas]).to(x_t.device), t_graph, batch_idx, x_t)
    alpha_bar_t = extract_schedule(cont_alpha_bars.to(x_t.device), t_graph, batch_idx, x_t)

    noise = torch.randn_like(x_t)
    mask_t = (t_graph if batch_idx is None else t_graph[batch_idx]).view(-1, *([1] * (x_t.ndim - 1)))
    noise = torch.where(mask_t > 1, noise, torch.zeros_like(noise))
    mean = (x_t - (beta_t / torch.clamp(torch.sqrt(1.0 - alpha_bar_t), min=1e-6)) * pred_eps) / torch.clamp(torch.sqrt(alpha_t), min=1e-6)
    x_prev = mean + torch.sqrt(beta_t) * noise

    if wrap:
        x_prev = wrap_centered_torch(x_prev)
    return x_prev

def discrete_reverse_step_absorbing(atom_t, pred_atom_logits, t_graph, batch_idx):
    pred_x0 = torch.distributions.Categorical(logits=pred_atom_logits).sample()

    t_per_atom = t_graph[batch_idx]
    ab_t = disc_alpha_bars.to(atom_t.device)[t_per_atom]
    ab_prev = disc_alpha_bars.to(atom_t.device)[torch.clamp(t_per_atom - 1, min=0)]

    denom = torch.clamp(1.0 - ab_t, min=1e-6)
    reveal_prob = torch.where(
        t_per_atom > 1,
        torch.clamp((ab_prev - ab_t) / denom, min=0.0, max=1.0),
        torch.ones_like(ab_t),
    )

    masked = atom_t == MASK_TOKEN
    reveal = masked & (torch.rand_like(reveal_prob) < reveal_prob)
    atom_prev = atom_t.clone()
    atom_prev[reveal] = pred_x0[reveal]
    atom_prev = torch.where(t_per_atom == 1, pred_x0, atom_prev)
    return atom_prev

def sample_num_atoms(batch_size: int):
    choices = np.array([r["num_atoms"] for r in train_data], dtype=np.int64)
    sampled = np.random.choice(choices, size=batch_size, replace=True)
    return torch.tensor(sampled, dtype=torch.long, device=device)

def normalize_composition_vector(vec: np.ndarray) -> np.ndarray:
    vec = np.asarray(vec, dtype=np.float32).reshape(-1)
    vec = np.clip(vec, 0.0, None)
    total = float(vec.sum())
    if total <= 1e-8:
        return composition_mean.copy()
    return (vec / total).astype(np.float32)

def composition_vector_from_formula(formula: str) -> np.ndarray:
    comp = Composition(formula)
    vec = np.zeros(NUM_ATOM_CLASSES, dtype=np.float32)
    total = 0.0
    for el, amt in comp.get_el_amt_dict().items():
        z = Element(el).Z
        if int(z) in z_to_token:
            vec[z_to_token[int(z)]] = float(amt)
            total += float(amt)
    if total <= 1e-8:
        raise ValueError(
            f"Formula {formula!r} does not overlap with the current tiny dataset vocabulary. "
            "Choose a formula using elements seen in the curated dataset."
        )
    return normalize_composition_vector(vec)

def composition_vector_to_formula_hint(vec: np.ndarray, top_k: int = 6) -> str:
    vec = np.asarray(vec, dtype=np.float32)
    idx = np.argsort(-vec)[:top_k]
    parts = []
    for i in idx:
        if vec[int(i)] <= 1e-4:
            continue
        sym = Element.from_Z(token_to_z[int(i)]).symbol
        parts.append(f"{sym}:{float(vec[int(i)]):.2f}")
    return ", ".join(parts) if parts else "mean composition"

def resolve_spacegroup_number(target_spacegroup) -> int:
    if target_spacegroup is None:
        return int(spacegroup_mode)
    if isinstance(target_spacegroup, str):
        text = target_spacegroup.strip()
        if text.isdigit():
            return int(np.clip(int(text), 0, MAX_SPACEGROUP_NUMBER))
        if text in SPACEGROUP_SYMBOL_TO_NUM:
            return int(SPACEGROUP_SYMBOL_TO_NUM[text])
        raise ValueError(
            f"Unknown space-group symbol {target_spacegroup!r} for this tiny dataset. "
            "Use an integer 1..230 or a symbol already present in the curated records."
        )
    return int(np.clip(int(target_spacegroup), 0, MAX_SPACEGROUP_NUMBER))

def build_condition_tensors(
    batch_size: int,
    target_density: float | None = None,
    target_band_gap: float | None = None,
    target_formula: str | None = None,
    target_composition: np.ndarray | list | tuple | None = None,
    target_spacegroup: int | str | None = None,
    target_record: dict | None = None,
    target_conditions: dict | list | tuple | np.ndarray | None = None,
):
    if (
        target_conditions is None
        and target_density is None
        and target_band_gap is None
        and target_formula is None
        and target_composition is None
        and target_spacegroup is None
        and target_record is None
    ):
        return None

    continuous_values = cond_mean.astype(np.float32).copy()
    composition_values = composition_mean.astype(np.float32).copy()
    spacegroup_value = int(spacegroup_mode)

    if target_record is not None:
        continuous_values[CONDITION_INDEX["density"]] = float(target_record.get("density", continuous_values[0]))
        continuous_values[CONDITION_INDEX["band_gap"]] = float(target_record.get("band_gap", continuous_values[1]))
        if "composition_cond" in target_record:
            composition_values = normalize_composition_vector(target_record["composition_cond"])
        else:
            composition_values = composition_fraction_vector_from_atomic_numbers(target_record["atomic_numbers"])
        spacegroup_value = int(target_record.get("spacegroup_number", spacegroup_value))

    if isinstance(target_conditions, dict):
        for name, value in target_conditions.items():
            if name in CONDITION_INDEX and value is not None:
                continuous_values[CONDITION_INDEX[name]] = float(value)
    elif target_conditions is not None:
        arr = np.asarray(target_conditions, dtype=np.float32).reshape(-1)
        if len(arr) != len(CONDITION_NAMES):
            raise ValueError(f"target_conditions must have length {len(CONDITION_NAMES)}")
        continuous_values[:] = arr

    if target_density is not None:
        continuous_values[CONDITION_INDEX["density"]] = float(target_density)
    if target_band_gap is not None:
        continuous_values[CONDITION_INDEX["band_gap"]] = float(target_band_gap)

    if target_formula is not None:
        composition_values = composition_vector_from_formula(target_formula)
    if target_composition is not None:
        composition_values = normalize_composition_vector(np.asarray(target_composition, dtype=np.float32))

    if target_spacegroup is not None:
        spacegroup_value = resolve_spacegroup_number(target_spacegroup)

    continuous_norm = (continuous_values - cond_mean) / cond_std

    return {
        "continuous": torch.tensor(
            np.repeat(continuous_norm[None, :], batch_size, axis=0),
            dtype=torch.float32,
            device=device,
        ),
        "composition": torch.tensor(
            np.repeat(composition_values[None, :], batch_size, axis=0),
            dtype=torch.float32,
            device=device,
        ),
        "spacegroup": torch.full(
            (batch_size,),
            int(np.clip(spacegroup_value, 0, MAX_SPACEGROUP_NUMBER)),
            dtype=torch.long,
            device=device,
        ),
    }

def pretty_condition_target(cond_tensors: dict | None):
    if cond_tensors is None:
        return {"mode": "unconditional"}

    scalar_vals = cond_tensors["continuous"][0].detach().cpu().numpy() * cond_std + cond_mean
    comp_vals = cond_tensors["composition"][0].detach().cpu().numpy()
    sg_num = int(cond_tensors["spacegroup"][0].detach().cpu().item())
    return {
        "density": float(scalar_vals[CONDITION_INDEX["density"]]),
        "band_gap": float(scalar_vals[CONDITION_INDEX["band_gap"]]),
        "composition_hint": composition_vector_to_formula_hint(comp_vals),
        "spacegroup_number": sg_num,
        "spacegroup_symbol": SPACEGROUP_NUM_TO_SYMBOL.get(sg_num, "unknown"),
    }

@torch.no_grad()
def sample_crystals(
    model: MiniMatterGenDenoiser,
    batch_size: int = 8,
    target_density: float | None = None,
    target_band_gap: float | None = None,
    target_formula: str | None = None,
    target_composition: np.ndarray | list | tuple | None = None,
    target_spacegroup: int | str | None = None,
    target_record: dict | None = None,
    target_conditions: dict | list | tuple | np.ndarray | None = None,
    guidance_scale: float = 2.0,
    num_atoms: torch.Tensor | None = None,
    record: bool = False,
):
    model.eval()

    if num_atoms is None:
        num_atoms = sample_num_atoms(batch_size)
    else:
        num_atoms = num_atoms.to(device)

    batch_idx = torch.repeat_interleave(torch.arange(batch_size, device=device), num_atoms)
    n_total = int(num_atoms.sum().item())

    frac_t = wrap_centered_torch(torch.randn(n_total, 3, device=device))
    lattice_t = torch.randn(batch_size, 6, device=device)
    atom_t = torch.full((n_total,), MASK_TOKEN, dtype=torch.long, device=device)

    cond_tensors = build_condition_tensors(
        batch_size=batch_size,
        target_density=target_density,
        target_band_gap=target_band_gap,
        target_formula=target_formula,
        target_composition=target_composition,
        target_spacegroup=target_spacegroup,
        target_record=target_record,
        target_conditions=target_conditions,
    )

    records = []
    checkpoints = sorted(set(np.linspace(T, 1, 10).round().astype(int).tolist()), reverse=True)
    for step in range(T, 0, -1):
        t_graph = torch.full((batch_size,), step, dtype=torch.long, device=device)

        pred_coord_eps, pred_lattice_eps, pred_atom_logits = guided_predictions(
            model=model,
            frac_t=frac_t,
            lattice_t=lattice_t,
            atom_t=atom_t,
            num_atoms=num_atoms,
            batch_idx=batch_idx,
            t_graph=t_graph,
            continuous_cond=None if cond_tensors is None else cond_tensors["continuous"],
            composition_cond=None if cond_tensors is None else cond_tensors["composition"],
            spacegroup_cond=None if cond_tensors is None else cond_tensors["spacegroup"],
            guidance_scale=guidance_scale,
        )

        frac_t = ddpm_reverse_step_continuous(frac_t, pred_coord_eps, t_graph, batch_idx=batch_idx, wrap=True)
        lattice_t = ddpm_reverse_step_continuous(lattice_t, pred_lattice_eps, t_graph, batch_idx=None, wrap=False)
        lattice_t = clamp_lattice_features_norm_torch(lattice_t)
        atom_t = discrete_reverse_step_absorbing(atom_t, pred_atom_logits, t_graph, batch_idx)

        if record and (step in checkpoints or step == 1):
            records.append(
                {
                    "step": step,
                    "frac_t": frac_t.detach().cpu().clone(),
                    "lattice_t": lattice_t.detach().cpu().clone(),
                    "atom_t": atom_t.detach().cpu().clone(),
                    "num_atoms": num_atoms.detach().cpu().clone(),
                    "batch_idx": batch_idx.detach().cpu().clone(),
                    "target_conditions": pretty_condition_target(cond_tensors),
                }
            )

    return frac_t.cpu(), lattice_t.cpu(), atom_t.cpu(), num_atoms.cpu(), records

def decode_sampled_structures(frac_t, lattice_t, atom_t, num_atoms):
    structures = []
    offset = 0
    for i, n in enumerate(num_atoms.tolist()):
        frac_i = frac_t[offset : offset + n].numpy()
        atom_i = atom_t[offset : offset + n].numpy()
        offset += n

        valid = atom_i != MASK_TOKEN
        if valid.sum() == 0:
            valid[0] = True
            atom_i[0] = 0

        structure = x0_to_structure(frac_i[valid], lattice_t[i].numpy(), atom_i[valid])
        structures.append(structure)
    return structures

def safe_structure_density(structure: Structure) -> float:
    try:
        volume = float(structure.volume)
        if (not np.isfinite(volume)) or volume <= 1e-8:
            return float("nan")
        density = float(structure.density)
        return density if np.isfinite(density) else float("nan")
    except Exception:
        return float("nan")

def safe_spacegroup_info(structure: Structure):
    try:
        sga = SpacegroupAnalyzer(structure, symprec=0.1)
        return int(sga.get_space_group_number()), str(sga.get_space_group_symbol())
    except Exception:
        return np.nan, "unknown"

def safe_min_pair_distance(structure: Structure) -> float:
    try:
        dm = np.asarray(structure.distance_matrix, dtype=float)
    except Exception:
        return float("nan")
    if dm.ndim != 2 or dm.shape[0] == 0:
        return float("nan")
    mask = np.isfinite(dm) & (dm > 1e-8)
    if not mask.any():
        return float("nan")
    return float(dm[mask].min())


def lightweight_validity_dict(structure: Structure) -> dict:
    try:
        num_atoms = int(len(structure))
    except Exception:
        num_atoms = 0

    try:
        volume = float(structure.volume)
    except Exception:
        volume = float("nan")

    density = safe_structure_density(structure)
    min_pair_distance = safe_min_pair_distance(structure)

    try:
        lengths = np.asarray(structure.lattice.abc, dtype=float)
        angles = np.asarray(structure.lattice.angles, dtype=float)
    except Exception:
        lengths = np.array([np.nan, np.nan, np.nan], dtype=float)
        angles = np.array([np.nan, np.nan, np.nan], dtype=float)

    valid_volume = bool(np.isfinite(volume) and volume > 1e-3)
    valid_density = bool(np.isfinite(density) and 0.2 <= density <= 25.0)
    distance_ok = bool(np.isfinite(min_pair_distance) and min_pair_distance >= 0.6)
    lengths_ok = bool(np.all(np.isfinite(lengths)) and np.all((lengths >= 2.0) & (lengths <= 25.0)))
    angles_ok = bool(np.all(np.isfinite(angles)) and np.all((angles >= 20.0) & (angles <= 160.0)))
    atom_count_ok = bool(num_atoms >= 2)

    reasons = []
    if not atom_count_ok:
        reasons.append("too few atoms")
    if not valid_volume:
        reasons.append("bad cell volume")
    if not valid_density:
        reasons.append("bad density")
    if not distance_ok:
        reasons.append("atoms too close")
    if not lengths_ok:
        reasons.append("bad lattice lengths")
    if not angles_ok:
        reasons.append("bad lattice angles")

    volume_per_atom = float(volume / max(num_atoms, 1)) if np.isfinite(volume) else float("nan")
    lightweight_valid = len(reasons) == 0

    return {
        "min_pair_distance": min_pair_distance,
        "volume_per_atom": volume_per_atom,
        "valid_volume": valid_volume,
        "valid_density": valid_density,
        "distance_ok": distance_ok,
        "lattice_lengths_ok": lengths_ok,
        "lattice_angles_ok": angles_ok,
        "lightweight_valid": lightweight_valid,
        "failure_reason": "ok" if lightweight_valid else "; ".join(reasons),
    }


def summarize_structures(structures):
    rows = []
    for idx, s in enumerate(structures):
        try:
            formula = s.composition.reduced_formula
        except Exception:
            formula = "INVALID"

        try:
            volume = float(s.volume)
        except Exception:
            volume = float("nan")

        density = safe_structure_density(s)
        sg_num, sg_symbol = safe_spacegroup_info(s)
        validity = lightweight_validity_dict(s)
        rows.append(
            {
                "sample_id": idx,
                "formula": formula,
                "num_atoms": len(s),
                "density": density,
                "volume": volume,
                "spacegroup_number": sg_num,
                "spacegroup_symbol": sg_symbol,
                **validity,
            }
        )
    return pd.DataFrame(rows)


def validity_report(summary_df: pd.DataFrame, label: str) -> pd.DataFrame:
    if summary_df.empty:
        return pd.DataFrame(
            [{
                "label": label,
                "n_samples": 0,
                "valid_count": 0,
                "valid_fraction": float("nan"),
                "median_density": float("nan"),
                "median_min_pair_distance": float("nan"),
                "most_common_issue": "no samples",
            }]
        )

    invalid_df = summary_df.loc[~summary_df["lightweight_valid"]]
    if invalid_df.empty:
        most_common_issue = "all passed"
    else:
        most_common_issue = invalid_df["failure_reason"].value_counts().idxmax()

    return pd.DataFrame(
        [{
            "label": label,
            "n_samples": int(len(summary_df)),
            "valid_count": int(summary_df["lightweight_valid"].sum()),
            "valid_fraction": float(summary_df["lightweight_valid"].mean()),
            "median_density": float(summary_df["density"].median()),
            "median_min_pair_distance": float(summary_df["min_pair_distance"].median()),
            "most_common_issue": most_common_issue,
        }]
    )

10) Unconditional generation

Let’s sample from the base model before we teach the network any target properties at all.

This is the cleanest test of whether the unconditional prior has learned a plausible crystal distribution.

Lightweight validity checks

Every generated-structure summary in this notebook includes a small geometric sanity screen:

  • finite positive cell volume,

  • density in a broad but reasonable range,

  • minimum pair distance above 0.6 Å,

  • lattice lengths and angles inside a broad physical window.

These are fast notebook checks, not a substitute for relaxation, symmetry cleanup, oxidation-state analysis, or DFT validation. They are still useful because they catch the most obvious broken decodes before you inspect galleries or save CIF files.

The first cell generates and summarizes samples. The second one shows a few of them as static ASE previews.

frac_u, lattice_u, atom_u, num_u, record_u = sample_crystals(
    base_model,
    batch_size=6,
    guidance_scale=1.0,
    record=True,
)
uncond_structures = decode_sampled_structures(frac_u, lattice_u, atom_u, num_u)
uncond_summary = summarize_structures(uncond_structures)
uncond_validity = validity_report(uncond_summary, "unconditional")
display(uncond_summary)
display(uncond_validity)
if (~uncond_summary["lightweight_valid"]).any():
    display(uncond_summary.loc[~uncond_summary["lightweight_valid"], ["sample_id", "formula", "min_pair_distance", "failure_reason"]])
display(
    show_structures(
        uncond_structures[:4],
        [f'{s.composition.reduced_formula} | density={safe_structure_density(s):.2f}' for s in uncond_structures[:4]],
    )
)

What to look for

  • Are the compositions and densities at least vaguely reasonable?

  • Do different samples look visibly different?

  • Are the unit cells non-degenerate?

  • Do the samples look like they came from one broad crystal family rather than collapsing to a single motif?

If not, the usual fixes are:

  • a slightly larger or cleaner dataset,

  • more training,

  • a better-matched distribution over num_atoms,

  • or a slightly larger denoiser.

Once the unconditional prior looks usable, we can ask the next question: can we steer it without destroying plausibility?

11) Add a conditional adapter for density, band gap, composition, and space group

Now that we have seen what the base model can generate on its own, we can move closer to the MatterGen workflow:

  1. start from the unconditional base model,

  2. copy its weights into a conditional model,

  3. fine-tune only the condition-specific pieces and output heads.

Conditioning package

We use four targets, but they enter through three encoders:

  • scalar properties:

    \mathbf{c}_{\text{scalar}} = [\text{density}, \text{band gap}]
  • composition fractions over the dataset vocabulary:

    \mathbf{c}_{\text{comp}} \in \mathbb{R}^{K}, \qquad \sum_i c_i = 1
  • a discrete space-group label:

    c_{\text{sg}} \in \{1, \dots, 230\}

Classifier-free dropout during training

Sometimes we hide the whole condition package during training.
That lets us use classifier-free guidance at sampling time later.

If $p_{\text{uncond}}$ is the dropout probability, the model learns both:

  • a conditional denoiser,

  • and an unconditional fallback denoiser.
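That dropout decision can be sketched in a few lines (a NumPy stand-in; in the real training loop this choice is made once per graph, and a learned unconditional embedding is swapped in wherever the mask is True — `p_uncond` here is an assumed name):

```python
import numpy as np

def sample_uncond_mask(batch_size: int, p_uncond: float, rng: np.random.Generator) -> np.ndarray:
    # True -> hide the whole condition package for that graph this step,
    # so the network also learns an unconditional fallback denoiser.
    return rng.random(batch_size) < p_uncond

rng = np.random.default_rng(0)
mask = sample_uncond_mask(8, 0.1, rng)
print(int(mask.sum()), "of", mask.size, "graphs trained unconditionally this batch")
```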

Important interpretation note

Composition conditioning in this notebook is soft: it encourages the generated atom distribution toward a requested composition vector, but it does not guarantee an exact formula.
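One way to quantify that softness after sampling is to compare achieved element fractions against the requested vector. A self-contained sketch over toy token arrays (`NUM_CLASSES` is a stand-in for the notebook's `NUM_ATOM_CLASSES`; the toy tokens stand in for decoded atom types):

```python
import numpy as np

def achieved_fractions(atom_tokens: np.ndarray, num_classes: int) -> np.ndarray:
    # Count each atom-type token and normalize to fractions.
    counts = np.bincount(atom_tokens, minlength=num_classes).astype(np.float64)
    return counts / max(counts.sum(), 1.0)

NUM_CLASSES = 4
target = np.array([0.5, 0.5, 0.0, 0.0])       # requested composition vector
sampled = np.array([0, 0, 1, 1, 1])           # toy decoded atom types
achieved = achieved_fractions(sampled, NUM_CLASSES)
print(achieved)                                # -> [0.4 0.6 0.  0. ]
print("L1 gap to target:", float(np.abs(achieved - target).sum()))
```

A small L1 gap means the soft conditioning steered the stoichiometric mix close to the request, even if the exact formula differs.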

The code below builds the conditional model, matches its width and depth to the base model so weight transfer is meaningful, and then freezes most parameters so the fine-tune behaves like a lightweight adapter stage.

cond_model = MiniMatterGenDenoiser(NUM_ATOM_CLASSES, hidden_dim=HIDDEN_DIM, num_blocks=NUM_BLOCKS, conditional=True).to(device)

# Load base weights wherever the parameter shapes match.
base_state = base_model.state_dict()
cond_state = cond_model.state_dict()
for k, v in base_state.items():
    if k in cond_state and cond_state[k].shape == v.shape:
        cond_state[k] = v
cond_model.load_state_dict(cond_state)

# Optional: freeze everything except the conditional pieces and the output heads.
for name, param in cond_model.named_parameters():
    trainable = (
        ("cond_scalar_embed" in name)
        or ("cond_comp_embed" in name)
        or ("cond_spacegroup_embed" in name)
        or ("cond_fuse" in name)
        or ("uncond_embedding" in name)
        or ("adapter" in name)
        or ("coord_head" in name)
        or ("atom_head" in name)
        or ("lattice_head" in name)
    )
    param.requires_grad = trainable

num_trainable = sum(p.numel() for p in cond_model.parameters() if p.requires_grad)
num_total = sum(p.numel() for p in cond_model.parameters())
print(f"trainable parameters: {num_trainable:,} / {num_total:,}")

Then we train the adapter. This phase is usually faster than training from scratch because most of the geometric representation has already been learned.

ADAPTER_MAX_EPOCHS = int(globals().get("ADAPTER_MAX_EPOCHS", 64))
ADAPTER_PATIENCE = int(globals().get("ADAPTER_PATIENCE", 12))

adapter_history, adapter_run = train_model(
    cond_model,
    train_loader,
    val_loader,
    epochs=ADAPTER_MAX_EPOCHS,
    lr=1e-4,
    run_name="scalar_composition_spacegroup_adapter",
    patience=ADAPTER_PATIENCE,
    ema_decay=0.95,
)
print(adapter_run)
adapter_history.tail()
plt.figure(figsize=(8, 4))
plt.plot(adapter_history["epoch"], adapter_history["train_loss"], alpha=0.35, label="train")
plt.plot(adapter_history["epoch"], adapter_history["train_loss_ema"], label="train (EMA)")
plt.plot(adapter_history["epoch"], adapter_history["val_loss"], label="val")
plt.axvline(adapter_run["best_epoch"], linestyle="--", alpha=0.6, label=f'best epoch={adapter_run["best_epoch"]}')
plt.xlabel("epoch")
plt.ylabel("loss")
plt.title("Conditional adapter learning curve")
plt.legend()
plt.show()

Try this: unfreeze the whole network instead of only the adapter-related pieces.
Try this too: keep scalar conditioning on, but ablate either the composition encoder or the space-group embedding.
Question: which targets seem easiest for the model to learn, and which ones most strongly affect sample plausibility?

Answer

Scalar targets such as density are often the easiest to learn because they are smooth and low-dimensional. Composition and space group usually affect plausibility more strongly: they constrain chemistry and symmetry more directly, so mistakes there can make a structure look unrealistic even when the scalar targets are roughly correct.

12) Conditional generation with scalar, composition, and space-group targets

Now we ask the adapter model to steer the unconditional prior toward different crystal-property regimes.

Important interpretation note

  • For density, we can directly summarize the generated structures after decoding.

  • For band gap, this notebook uses the MP band-gap labels during training but does not recompute the generated band gap afterward.

  • For composition, the target is a soft composition vector. The model is encouraged toward that stoichiometric mix, but exact formulas are not guaranteed.

  • For space group, we can run a symmetry analyzer on decoded structures, but for noisy/generated outputs this should be read as a rough diagnostic rather than a guaranteed exact match.

So the conditional examples are best interpreted as targeted generation demos, not as strict constrained generation.

Try this: change only one target at a time. For example, fix composition and space group but sweep density. That makes it easier to see which condition the model is actually responding to.
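The “one knob at a time” idea can be sketched generically. Here `generate_and_score` is a hypothetical stand-in for the full sample-decode-summarize round trip (in this notebook that would be `sample_crystals` followed by `summarize_structures`); the toy scorer just echoes the target with noise so the pattern is runnable on its own:

```python
import numpy as np

def sweep_one_target(targets, generate_and_score):
    # Hold every other condition fixed; vary a single target and record the result.
    return [
        {"target": float(t), "median_density": float(generate_and_score(t))}
        for t in targets
    ]

rng = np.random.default_rng(0)
toy_scorer = lambda t: t + rng.normal(0.0, 0.05)  # stand-in, not a real model

for row in sweep_one_target([2.0, 4.0, 6.0], toy_scorer):
    print(row)
```

With the real model, a monotone trend in the achieved median density as the target rises is the clearest sign that the condition is actually being used.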

First we pick low/high targets from the training-set property distributions.

train_density_real = np.array([r["density"] for r in train_data], dtype=np.float32)
train_gap_real = np.array([r["band_gap"] for r in train_data], dtype=np.float32)

low_density_target = float(np.quantile(train_density_real, 0.2))
high_density_target = float(np.quantile(train_density_real, 0.8))
low_gap_target = float(np.quantile(train_gap_real, 0.2))
high_gap_target = float(np.quantile(train_gap_real, 0.8))

print("low target density:", low_density_target)
print("high target density:", high_density_target)
print("low target band gap:", low_gap_target)
print("high target band gap:", high_gap_target)

Then we sample four groups of structures: low-density, high-density, low-band-gap, and high-band-gap. After that, we will add a composition + space-group example anchored to one training crystal.

frac_low_d, lattice_low_d, atom_low_d, num_low_d, _ = sample_crystals(
    cond_model,
    batch_size=6,
    target_density=low_density_target,
    guidance_scale=2.0,
)
low_density_structures = decode_sampled_structures(frac_low_d, lattice_low_d, atom_low_d, num_low_d)
low_density_summary = summarize_structures(low_density_structures)

frac_high_d, lattice_high_d, atom_high_d, num_high_d, _ = sample_crystals(
    cond_model,
    batch_size=6,
    target_density=high_density_target,
    guidance_scale=2.0,
)
high_density_structures = decode_sampled_structures(frac_high_d, lattice_high_d, atom_high_d, num_high_d)
high_density_summary = summarize_structures(high_density_structures)

frac_low_g, lattice_low_g, atom_low_g, num_low_g, _ = sample_crystals(
    cond_model,
    batch_size=6,
    target_band_gap=low_gap_target,
    guidance_scale=2.0,
)
low_gap_structures = decode_sampled_structures(frac_low_g, lattice_low_g, atom_low_g, num_low_g)
low_gap_summary = summarize_structures(low_gap_structures)

frac_high_g, lattice_high_g, atom_high_g, num_high_g, _ = sample_crystals(
    cond_model,
    batch_size=6,
    target_band_gap=high_gap_target,
    guidance_scale=2.0,
)
high_gap_structures = decode_sampled_structures(frac_high_g, lattice_high_g, atom_high_g, num_high_g)
high_gap_summary = summarize_structures(high_gap_structures)

conditional_validity = pd.concat([
    validity_report(low_density_summary, "low density target"),
    validity_report(high_density_summary, "high density target"),
    validity_report(low_gap_summary, "low band-gap target"),
    validity_report(high_gap_summary, "high band-gap target"),
], ignore_index=True)

This plot helps you compare the targets against the empirical training distribution, and—where possible—the generated samples.

fig, axes = plt.subplots(1, 2, figsize=(10, 3.8))

axes[0].hist(train_density_real, bins=24, alpha=0.35, label="train densities")
axes[0].axvline(low_density_target, linestyle="--", label="low density target")
axes[0].axvline(high_density_target, linestyle="--", label="high density target")
axes[0].scatter(low_density_summary["density"], np.full(len(low_density_summary), -0.02), marker="x", label="generated low density")
axes[0].scatter(high_density_summary["density"], np.full(len(high_density_summary), -0.05), marker="x", label="generated high density")
axes[0].set_xlabel("density (g/cm³)")
axes[0].set_title("Density-conditioned sampling")
axes[0].legend(fontsize=8)

axes[1].hist(train_gap_real, bins=24, alpha=0.35, label="train band gaps")
axes[1].axvline(low_gap_target, linestyle="--", label="low band-gap target")
axes[1].axvline(high_gap_target, linestyle="--", label="high band-gap target")
axes[1].set_xlabel("band gap (eV)")
axes[1].set_title("Band-gap targets used for conditioning")
axes[1].legend(fontsize=8)

plt.tight_layout()
plt.show()

print("Low-density conditioned samples")
display(low_density_summary)
print("High-density conditioned samples")
display(high_density_summary)

print("Low-band-gap conditioned samples")
display(low_gap_summary)
print("High-band-gap conditioned samples")
display(high_gap_summary)
print("Lightweight validity report across conditional sweeps")
display(conditional_validity)

Finally, inspect a few generated structures from each scalar-property regime.

print("Low-density conditioned examples")
display(
    show_structures(
        low_density_structures[:4],
        [f'{s.composition.reduced_formula} | density={safe_structure_density(s):.2f}' for s in low_density_structures[:4]],
    )
)

print("High-density conditioned examples")
display(
    show_structures(
        high_density_structures[:4],
        [f'{s.composition.reduced_formula} | density={safe_structure_density(s):.2f}' for s in high_density_structures[:4]],
    )
)

print("Low-band-gap conditioned examples")
display(
    show_structures(
        low_gap_structures[:4],
        [f'{s.composition.reduced_formula} | target gap={low_gap_target:.2f} eV' for s in low_gap_structures[:4]],
    )
)

print("High-band-gap conditioned examples")
display(
    show_structures(
        high_gap_structures[:4],
        [f'{s.composition.reduced_formula} | target gap={high_gap_target:.2f} eV' for s in high_gap_structures[:4]],
    )
)

Composition + space-group conditioning demo

A simple way to demonstrate the richer conditioning is to borrow a target composition and space group from one real crystal.

In the next cell we choose one anchor crystal from the training set and ask the model to sample new structures with:

  • the same target composition,

  • the same target space-group label,

  • the same number of atoms,

  • and roughly the same density.

This is not strict constrained generation, but it is a good teaching experiment because you can compare “what was requested” versus “what came out”.
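Scoring "requested versus produced" needs nothing more than reduced formulas. A toy sketch with hand-parsed formula strings (the notebook's `comp_sg_summary` does the same comparison via pymatgen's `Composition`):

```python
import re
from functools import reduce
from math import gcd

def parse_formula(formula: str) -> dict:
    """Count elements in a simple formula string such as 'Ba2Ti2O6'."""
    counts: dict = {}
    for element, number in re.findall(r"([A-Z][a-z]?)(\d*)", formula):
        counts[element] = counts.get(element, 0) + (int(number) if number else 1)
    return counts

def reduced_formula(formula: str) -> tuple:
    """Reduce element counts by their gcd, so 'Ba2Ti2O6' matches 'BaTiO3'."""
    counts = parse_formula(formula)
    g = reduce(gcd, counts.values())
    return tuple(sorted((el, n // g) for el, n in counts.items()))

requested = "BaTiO3"
generated = ["Ba2Ti2O6", "BaTiO3", "SrTiO3"]
matches = [reduced_formula(f) == reduced_formula(requested) for f in generated]
print("formula-match fraction:", sum(matches) / len(matches))  # 2 of 3 match
```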

anchor_record = train_data[0]
print("Anchor record")
print({
    "formula": anchor_record["formula"],
    "density": float(anchor_record["density"]),
    "band_gap": float(anchor_record["band_gap"]),
    "spacegroup_number": int(anchor_record["spacegroup_number"]),
    "spacegroup_symbol": anchor_record["spacegroup_symbol"],
    "num_atoms": int(anchor_record["num_atoms"]),
})

anchor_formula = Composition(anchor_record["formula"]).reduced_formula
anchor_num_atoms = torch.full((6,), int(anchor_record["num_atoms"]), dtype=torch.long)

frac_comp_sg, lattice_comp_sg, atom_comp_sg, num_comp_sg, _ = sample_crystals(
    cond_model,
    batch_size=6,
    target_formula=anchor_record["formula"],
    target_spacegroup=int(anchor_record["spacegroup_number"]),
    target_density=float(anchor_record["density"]),
    guidance_scale=3.0,
    num_atoms=anchor_num_atoms,
)

comp_sg_structures = decode_sampled_structures(frac_comp_sg, lattice_comp_sg, atom_comp_sg, num_comp_sg)
comp_sg_summary = summarize_structures(comp_sg_structures).assign(
    formula_match=lambda df: df["formula"].eq(anchor_formula),
    spacegroup_match=lambda df: df["spacegroup_number"].eq(int(anchor_record["spacegroup_number"])),
)
comp_sg_validity = validity_report(comp_sg_summary, "composition + space-group target")
display(comp_sg_summary)
display(comp_sg_validity)
print("Formula-match fraction:", float(comp_sg_summary["formula_match"].mean()))
print("Space-group-match fraction:", float(comp_sg_summary["spacegroup_match"].mean()))

Things to notice:

  • Do the generated formulas look related to the target composition?

  • Does the inferred space group sometimes match the requested one?

  • What happens if you keep the composition fixed but change the target density or guidance scale?

A fun follow-up is to swap target_formula=anchor_record["formula"] for a hand-written formula such as "SrTiO3" or "BaTiO3", provided those elements appear in your tiny curated dataset.

display(
    show_structures(
        [anchor_record["structure"]] + comp_sg_structures[:3],
        labels=[
            f'target anchor | {anchor_record["formula"]} | sg={anchor_record["spacegroup_symbol"]}',
            *[
                f'{s.composition.reduced_formula} | sg={safe_spacegroup_info(s)[1]}'
                for s in comp_sg_structures[:3]
            ],
        ],
        columns=2,
    )
)

Try this: change guidance_scale from 2.0 to 1.0 or 3.0.
Try this too: keep target_formula fixed and swap only the target space group, or vice versa.
Question: what tradeoff do you observe between stronger conditioning and structural plausibility?

Answer

As guidance increases, the model usually follows the requested condition more strongly, but diversity drops first and geometry can start to degrade if the guidance becomes too aggressive. The sweet spot is the smallest scale that visibly moves the samples toward the target without making them collapse or distort.
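The combination rule behind this tradeoff is a one-line extrapolation. A minimal numeric sketch with synthetic predictions (the notebook's `guided_predictions` applies this same rule per output head):

```python
import numpy as np

rng = np.random.default_rng(0)
eps_uncond = rng.normal(size=5)   # model output with conditions dropped
eps_cond = eps_uncond + 0.5       # conditional output, shifted toward the target

for scale in (1.0, 2.0, 3.0):
    # Classifier-free guidance extrapolates past the conditional prediction:
    eps_guided = eps_uncond + scale * (eps_cond - eps_uncond)
    shift = (eps_guided - eps_uncond).mean()
    print(f"scale={scale}: mean shift from unconditional = {shift:.2f}")
```

At `scale=1.0` the guided prediction equals the conditional one; larger scales push further in the conditional direction, which is exactly why diversity shrinks first and geometry degrades next.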

13) Show the diffusion trajectory as a strip of snapshots

Instead of a GIF, we render a single long image containing selected timesteps.

Why do this?

  • it is easy to compare steps side by side,

  • it is stable in notebooks and exported documents,

  • and it makes it easier to ask “what changed between these two times?”

We will show two trajectories:

  1. a forward trajectory for a real crystal,

  2. a reverse trajectory starting from a maximally corrupted version of that crystal.
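The forward half of these strips needs no model at all. Fractional coordinates live on a torus, so corruption is "add noise, then wrap". A toy sketch of the idea behind `q_sample_wrapped_coords` (the real version uses a time-dependent noise schedule):

```python
import numpy as np

rng = np.random.default_rng(1)
frac0 = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])  # clean fractional coordinates

# Toy forward corruption: add Gaussian noise, then wrap back onto the unit cell.
for sigma in (0.02, 0.1, 0.4):
    frac_t = (frac0 + sigma * rng.normal(size=frac0.shape)) % 1.0
    in_cell = bool((frac_t >= 0).all() and (frac_t < 1).all())
    print(f"sigma={sigma}: still in [0, 1): {in_cell}")
```

The wrap is why small negative perturbations of an atom at 0.0 reappear near 1.0: that is periodicity, not an error.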

Try this: choose a different anchor crystal for the trajectory strip. Does the reverse process look smoother for some structures than for others?

The helper code that decodes intermediate states and draws the strip is tucked away below. Open it later if you want to understand exactly how the trajectory images are constructed.

# @title
def decode_state_to_structure(frac_centered: np.ndarray, lattice_norm: np.ndarray, atom_tokens: np.ndarray) -> Structure:
    atom_tokens = np.asarray(atom_tokens, dtype=np.int64)
    frac_centered = np.asarray(frac_centered, dtype=np.float32)
    valid = atom_tokens != MASK_TOKEN

    if valid.sum() == 0:
        valid = np.zeros_like(atom_tokens, dtype=bool)
        valid[0] = True
        atom_tokens = atom_tokens.copy()
        atom_tokens[0] = 0

    return x0_to_structure(frac_centered[valid], lattice_norm, atom_tokens[valid])

def render_structure_strip(structures, titles, out_path: str, rotation: str = "20x,30y,0z"):
    n = len(structures)
    fig, axes = plt.subplots(1, n, figsize=(2.8 * n, 3.2), squeeze=False)
    axes = axes[0]

    for ax, structure, title in zip(axes, structures, titles):
        atoms = structure_to_ase_atoms(structure)
        plot_atoms(atoms, ax, rotation=rotation, radii=0.35, show_unit_cell=2)
        ax.set_title(title, fontsize=9)
        ax.set_axis_off()

    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    return out_path

def show_png(path: str):
    return IPythonImage(filename=path)

@torch.no_grad()
def make_forward_diffusion_strip(record, out_path="forward_diffusion_strip.png", num_frames=10):
    frac0 = torch.tensor(record["frac0"], dtype=torch.float32, device=device)
    lattice0 = torch.tensor(record["lattice0_norm"][None], dtype=torch.float32, device=device)
    atom0 = torch.tensor(record["atom_tokens0"], dtype=torch.long, device=device)
    batch_idx = torch.zeros(len(atom0), dtype=torch.long, device=device)

    timesteps = torch.linspace(0, T, steps=num_frames).round().long().tolist()
    structures, titles = [], []

    for step in timesteps:
        t_graph = torch.tensor([int(step)], dtype=torch.long, device=device)

        if step == 0:
            frac_t = frac0
            lattice_t = lattice0[0]
            atom_t = atom0
        else:
            frac_t, _ = q_sample_wrapped_coords(frac0, t_graph, batch_idx)
            lattice_t, _ = q_sample_lattice(lattice0, t_graph)
            atom_t, _ = q_sample_atom_types(atom0, t_graph, batch_idx)
            lattice_t = lattice_t[0]

        structure_t = decode_state_to_structure(
            frac_t.detach().cpu().numpy(),
            lattice_t.detach().cpu().numpy(),
            atom_t.detach().cpu().numpy(),
        )
        masked_fraction = float((atom_t == MASK_TOKEN).float().mean().detach().cpu())
        structures.append(structure_t)
        titles.append(f"t={step}\nmasked={masked_fraction:.0%}")

    return render_structure_strip(structures, titles, out_path=out_path)

@torch.no_grad()
def reconstruct_example_with_record(model, record, guidance_scale=2.0, out_path="reverse_diffusion_strip.png", num_frames=10):
    model.eval()

    num_atoms = torch.tensor([record["num_atoms"]], dtype=torch.long, device=device)
    batch_idx = torch.zeros(record["num_atoms"], dtype=torch.long, device=device)

    frac0 = torch.tensor(record["frac0"], dtype=torch.float32, device=device)
    lattice0 = torch.tensor(record["lattice0_norm"][None], dtype=torch.float32, device=device)
    atom0 = torch.tensor(record["atom_tokens0"], dtype=torch.long, device=device)
    cond_tensors = build_condition_tensors(batch_size=1, target_record=record)

    t_init = torch.tensor([T], dtype=torch.long, device=device)
    frac_t, _ = q_sample_wrapped_coords(frac0, t_init, batch_idx)
    lattice_t, _ = q_sample_lattice(lattice0, t_init)
    atom_t, _ = q_sample_atom_types(atom0, t_init, batch_idx)
    lattice_t = lattice_t.clone()

    checkpoint_steps = sorted(set(np.linspace(T, 1, num_frames).round().astype(int).tolist()), reverse=True)
    strip_structures, strip_titles = [], []

    start_structure = decode_state_to_structure(
        frac_t.detach().cpu().numpy(),
        lattice_t[0].detach().cpu().numpy(),
        atom_t.detach().cpu().numpy(),
    )
    strip_structures.append(start_structure)
    strip_titles.append(f"start\nt={T}")

    for step in range(T, 0, -1):
        t_graph = torch.tensor([step], dtype=torch.long, device=device)

        pred_coord_eps, pred_lattice_eps, pred_atom_logits = guided_predictions(
            model=model,
            frac_t=frac_t,
            lattice_t=lattice_t,
            atom_t=atom_t,
            num_atoms=num_atoms,
            batch_idx=batch_idx,
            t_graph=t_graph,
            continuous_cond=cond_tensors["continuous"],
            composition_cond=cond_tensors["composition"],
            spacegroup_cond=cond_tensors["spacegroup"],
            guidance_scale=guidance_scale,
        )

        frac_t = ddpm_reverse_step_continuous(frac_t, pred_coord_eps, t_graph, batch_idx=batch_idx, wrap=True)
        lattice_t = ddpm_reverse_step_continuous(lattice_t, pred_lattice_eps, t_graph, batch_idx=None, wrap=False)
        lattice_t = clamp_lattice_features_norm_torch(lattice_t)
        atom_t = discrete_reverse_step_absorbing(atom_t, pred_atom_logits, t_graph, batch_idx)

        if step in checkpoint_steps or step == 1:
            structure_t = decode_state_to_structure(
                frac_t.detach().cpu().numpy(),
                lattice_t[0].detach().cpu().numpy(),
                atom_t.detach().cpu().numpy(),
            )
            strip_structures.append(structure_t)
            strip_titles.append(f"t={step-1}")

    strip_path = render_structure_strip(strip_structures, strip_titles, out_path=out_path)
    final_structure = x0_to_structure(
        frac_t.detach().cpu().numpy(),
        lattice_t[0].detach().cpu().numpy(),
        atom_t.detach().cpu().numpy(),
    )
    return strip_path, final_structure

example_record = train_data[0]
print(
    "chosen example:",
    example_record["name"],
    "|",
    example_record["formula"],
    "| density =",
    example_record["density"],
    "| band gap =",
    example_record["band_gap"],
    "| space group =",
    example_record["spacegroup_symbol"],
)
display(
    show_structures(
        [example_record["structure"]],
        [f'real example | {example_record["formula"]} | density={example_record["density"]:.2f} | gap={example_record["band_gap"]:.2f} eV | sg={example_record["spacegroup_symbol"]}'],
        columns=1,
    )
)

This first cell shows how a real example is gradually destroyed by the forward process.

forward_strip = make_forward_diffusion_strip(
    example_record,
    out_path="forward_diffusion_strip.png",
    num_frames=10,
)
display(show_png(forward_strip))

This second cell shows a reverse trajectory under the learned model.

reverse_strip, reconstructed_structure = reconstruct_example_with_record(
    cond_model,
    example_record,
    guidance_scale=2.0,
    out_path="reverse_diffusion_strip.png",
    num_frames=10,
)
display(show_png(reverse_strip))
print("Final reconstruction-like sample")
display(
    show_structures(
        [reconstructed_structure],
        [f'reverse result | density={safe_structure_density(reconstructed_structure):.2f}'],
        columns=1,
    )
)

What to notice in the strips

In the forward strip:

  • coordinates get noisier,

  • the lattice drifts,

  • atom identities become increasingly masked.

In the reverse strip:

  • the model gradually proposes a coherent lattice,

  • atom identities are revealed,

  • coordinates settle into a more structured arrangement.

Question: which of the three state components seems hardest for the model to reconstruct?

Answer

Atom identity is often the hardest piece to reconstruct because masking removes discrete chemical information entirely, while noisy coordinates and lattices still leave geometric hints behind. In practice, you often see lattice and position coherence return before composition fully settles.
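The "information removed entirely" point is easy to see in the absorbing-mask forward process itself. A toy sketch with a made-up `MASK` value and a linear `t / T` masking schedule (stand-ins for the notebook's `MASK_TOKEN` and `q_sample_atom_types`):

```python
import numpy as np

rng = np.random.default_rng(2)
T = 10
MASK = -1                            # stand-in for the notebook's MASK_TOKEN
atoms0 = np.array([8, 8, 22, 56])    # atomic numbers, e.g. O, O, Ti, Ba

# Absorbing-mask forward process: each atom is independently replaced by MASK
# with probability t / T, so discrete identity is destroyed, not just perturbed.
for t in (0, 5, 10):
    masked = rng.random(atoms0.shape) < t / T
    atoms_t = np.where(masked, MASK, atoms0)
    print(f"t={t:2d}: masked fraction = {masked.mean():.2f}, atoms = {atoms_t}")
```

By `t = T` every identity is the mask token, whereas noisy coordinates at `t = T` still carry faint geometric structure for the model to latch onto.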

14) Save generated structures and keep the useful artifacts

At the end of the notebook we save:

  • generated CIF files,

  • checkpoint files,

  • trajectory-strip images.

That makes it easy to inspect results outside the notebook or compare different runs.

The helper is short enough that it is worth reading: it just loops over structures and writes them as CIFs.

def save_structures_as_cifs(structures, out_dir: str):
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    for i, structure in enumerate(structures):
        file_path = out_path / f"sample_{i:03d}.cif"
        CifWriter(structure).write_file(str(file_path))
    return out_path

out_uncond = save_structures_as_cifs(uncond_structures, "generated_unconditional_cifs")
out_low_density = save_structures_as_cifs(low_density_structures, "generated_low_density_cifs")
out_high_density = save_structures_as_cifs(high_density_structures, "generated_high_density_cifs")
out_low_gap = save_structures_as_cifs(low_gap_structures, "generated_low_gap_cifs")
out_high_gap = save_structures_as_cifs(high_gap_structures, "generated_high_gap_cifs")

print("saved unconditional CIFs to:", out_uncond)
print("saved low-density CIFs to:", out_low_density)
print("saved high-density CIFs to:", out_high_density)
print("saved low-gap CIFs to:", out_low_gap)
print("saved high-gap CIFs to:", out_high_gap)
print("forward trajectory strip:", forward_strip)
print("reverse trajectory strip:", reverse_strip)
print("cached MP dataset:", MP_CACHE_PATH)
print("base best checkpoint:", base_run["best_checkpoint"])
print("base last checkpoint:", base_run["last_checkpoint"])
print("adapter best checkpoint:", adapter_run["best_checkpoint"])
print("adapter last checkpoint:", adapter_run["last_checkpoint"])

15) What is faithful to MatterGen, and what is simplified?

Faithful ideas

This notebook really does implement the main design pattern:

  • a ChemGraph-like flattened representation,

  • separate corruption processes for coordinates / lattice / atoms,

  • a graph denoiser with three output heads,

  • unconditional base training first,

  • adapter-style conditional fine-tuning,

  • classifier-free guidance for sampling.

Simplifications

It is still intentionally smaller than the real system:

  • the denoiser is a compact message-passing network rather than full GemNet-T,

  • the discrete atom diffusion is a simplified absorbing-mask variant,

  • the training corpus is tiny compared with production-scale materials generation,

  • band-gap conditioning is demonstrated as a label-conditioned task, not a validated property-prediction pipeline,

  • composition conditioning is a soft fraction-vector target rather than exact formula control,

  • space-group conditioning is a lightweight symmetry label rather than a full equivariant crystallographic constraint.

Those simplifications are what make the notebook teachable.

Suggested experiments

  1. Dataset experiments

    • switch CHEMIS...

After this notebook, compare the design against sota-crystals-comparison.ipynb to see where MatterGen and Chemeleon2 sit relative to the scratch-built model.

Crystal Exercises

  1. Representation: Why is batch_idx more natural than padding every crystal to the same number of atoms?

Answer

batch_idx keeps the computation aligned with the true number of atoms in each crystal, avoids wasted padded nodes, and matches graph neural network tooling directly.
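A tiny sketch of the contrast: per-crystal pooling over a flattened batch needs only `batch_idx`, with no padded nodes anywhere (a minimal NumPy stand-in for the scatter operations a GNN library would use):

```python
import numpy as np

# Two crystals with 2 and 3 atoms, flattened into one node array.
node_features = np.array([1.0, 2.0, 10.0, 20.0, 30.0])
batch_idx = np.array([0, 0, 1, 1, 1])  # which crystal each node belongs to

# Per-crystal mean via batch_idx: every node participates, none are padding.
num_graphs = int(batch_idx.max()) + 1
sums = np.bincount(batch_idx, weights=node_features, minlength=num_graphs)
counts = np.bincount(batch_idx, minlength=num_graphs)
print("per-crystal means:", sums / counts)  # [ 1.5 20. ]
```

With padding, the second crystal would force the first to carry a dummy third atom, and every pooling op would need an extra mask to ignore it.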

  2. Corruption design: Why do we use different forward corruption processes for coordinates, lattice, and atom identity instead of one generic noise rule?

Answer

Those three channels live in different spaces: coordinates are periodic continuous variables, lattice parameters are constrained continuous variables, and atom identities are discrete. Using one corruption rule for all three would ignore those structural differences and make denoising less natural.

  3. Training strategy: What is the point of training the unconditional base model before the conditional adapter?

Answer

The base model learns a generic crystal prior first. The adapter then learns how to steer that prior toward requested properties without having to relearn all of crystal plausibility from scratch.

  4. Guidance: If you raise the classifier-free guidance scale, what usually changes first: diversity or condition adherence?

Answer

Condition adherence usually improves while diversity drops. If the scale becomes too large, structural plausibility can start to degrade as well.

  5. Interpretation: Why should the band-gap and space-group demos here be read as targeted generation illustrations rather than strict guarantees?

Answer

The notebook conditions on labels and analyzes decoded outputs, but it does not solve a fully constrained crystallographic inverse problem. The conditions steer the generator, yet the decoded structures still need downstream validation if exact property or symmetry guarantees matter.

References
  1. Jain, A., Ong, S. P., Hautier, G., Chen, W., Richards, W. D., Dacek, S., Cholia, S., Gunter, D., Skinner, D., Ceder, G., & Persson, K. A. (2013). Commentary: The Materials Project: A materials genome approach to accelerating materials innovation. APL Materials, 1(1). doi:10.1063/1.4812323