diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 90e8efe5..69c9cf7b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -67,6 +67,7 @@ jobs: - { python: '3.13', resolution: highest } - { python: '3.13', resolution: lowest-direct } model: + - { name: chgnet, test_path: "tests/models/test_chgnet.py" } - { name: fairchem, test_path: "tests/models/test_fairchem.py" } - { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" } - { name: graphpes, test_path: "tests/models/test_graphpes.py" } diff --git a/docs/conf.py b/docs/conf.py index b42066ac..43b8fc5e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -62,6 +62,7 @@ ] autodoc_mock_imports = [ + "chgnet", "fairchem", "mace", "mattersim", diff --git a/examples/scripts/1_Introduction/1.4_CHGNet.py b/examples/scripts/1_Introduction/1.4_CHGNet.py new file mode 100644 index 00000000..a2d92e23 --- /dev/null +++ b/examples/scripts/1_Introduction/1.4_CHGNet.py @@ -0,0 +1,149 @@ +"""CHGNet model example for TorchSim.""" + +# /// script +# dependencies = ["chgnet>=0.4.2", "mace-torch>=0.3.12"] +# /// + +import os +import warnings + +import torch +from ase import Atoms +from ase.build import bulk +from mace.calculators.foundations_models import mace_mp + +import torch_sim as ts +from torch_sim.models.chgnet import CHGNetModel +from torch_sim.models.mace import MaceModel, MaceUrls + + +# Silence warnings +warnings.filterwarnings("ignore") +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" + +# Set device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +dtype = torch.float32 + +print("CHGNet Example for TorchSim") +print("=" * 40) + +# Create CHGNet model +model = CHGNetModel( + device=device, + dtype=dtype, + compute_forces=True, + compute_stress=True, +) + +# Create test systems +al_atoms = bulk("Al", "fcc", a=4.05, cubic=True) +c_atoms = bulk("C", "diamond", a=3.57, cubic=True) +mg_atoms = bulk("Mg", "hcp", a=3.21, c=5.21) +a_perovskite = 3.84 +ca_tio3_atoms = Atoms( + ["Ca", "Ti", "O", "O", "O"], + positions=[ + [0, 0, 0], + [a_perovskite / 2, a_perovskite / 2, a_perovskite / 2], + [a_perovskite / 2, 0, 0], + [0, a_perovskite / 2, 0], + [0, 0, a_perovskite / 2], + ], + cell=[a_perovskite, a_perovskite, a_perovskite], + pbc=True, +) + +# Convert to TorchSim state +state = ts.io.atoms_to_state([al_atoms, c_atoms, mg_atoms], device, dtype) + +# Load MACE model for comparison +raw_mace_mp = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) +mace_model = MaceModel( + model=raw_mace_mp, + device=device, + dtype=dtype, + compute_forces=True, + compute_stress=True, +) +mace_available = True + +# Single comprehensive results table +print( + "\nCHGNet vs MACE Results " + "(E: Total Energy, F: Maximum Force, S: Maximum Stress, M: Maximum Magnetic Moment)" +) +print("=" * 87) +print( + f"{'System':<10} {'CHGNet E':<12} {'CHGNet F':<12} {'CHGNet S':<12} " + f"{'CHGNet M':<12} {'MACE E':<12} {'MACE F':<12}" +) +print("-" * 87) + +# Test equilibrium structures +for i, system_name in enumerate(["Al", "C", "Mg"]): + single_state = ts.io.atoms_to_state([[al_atoms, c_atoms, mg_atoms][i]], device, dtype) + + # CHGNet results + chgnet_result = model.forward(single_state) + chgnet_energy = chgnet_result["energy"].item() + chgnet_force = torch.norm(chgnet_result["forces"], dim=1).max().item() + chgnet_stress = torch.norm(chgnet_result["stress"], dim=(1, 2)).max().item() + chgnet_magmom = ( + torch.norm(chgnet_result.get("magnetic_moments", torch.zeros(1, 3)), dim=-1) + .max() + .item() + ) + + # MACE results + mace_result = mace_model.forward(single_state) + mace_energy = mace_result["energy"].item() + mace_force = torch.norm(mace_result["forces"], dim=1).max().item() + print( + f"{system_name:<10} {chgnet_energy:<12.3f} {chgnet_force:<12.3f} " + f"{chgnet_stress:<12.3f} {chgnet_magmom:<12.3f} {mace_energy:<12.3f} " + f"{mace_force:<12.3f}" + ) + +# Test optimization on displaced structures +for atoms, system_name in zip( + [al_atoms, c_atoms, ca_tio3_atoms], ["Al", "C", "CaTiO3"], strict=False +): + single_state = ts.io.atoms_to_state([atoms], device, dtype) + displacement = torch.randn_like(single_state.positions) * 0.1 + displaced_state = single_state.clone() + displaced_state.positions = single_state.positions + displacement + + # CHGNet optimization + chgnet_optimized = ts.optimize( + displaced_state, model, optimizer=ts.optimizers.Optimizer.fire, max_steps=100 + ) + chgnet_final = model.forward(chgnet_optimized) + chgnet_final_energy = chgnet_final["energy"].item() + chgnet_final_force = torch.norm(chgnet_final["forces"], dim=1).max().item() + chgnet_final_stress = torch.norm(chgnet_final["stress"], dim=(1, 2)).max().item() + chgnet_final_magmom = ( + torch.norm(chgnet_final.get("magnetic_moments", torch.zeros(1, 3)), dim=-1) + .max() + .item() + ) + + # MACE optimization + mace_optimized = ts.optimize( + displaced_state, + mace_model, + optimizer=ts.optimizers.Optimizer.fire, + max_steps=100, + ) + mace_final = mace_model.forward(mace_optimized) + mace_final_energy = mace_final["energy"].item() + mace_final_force = torch.norm(mace_final["forces"], dim=1).max().item() + print( + f"{system_name + '_opt':<10} {chgnet_final_energy:<12.3f} " + f"{chgnet_final_force:<12.3f} {chgnet_final_stress:<12.3f} " + f"{chgnet_final_magmom:<12.3f} {mace_final_energy:<12.3f} " + f"{mace_final_force:<12.3f}" + ) + +print("=" * 87) +print("CHGNet example completed successfully!") diff --git a/pyproject.toml b/pyproject.toml index 850f35b6..2e8c0fd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ sevenn = ["sevenn>=0.11.0"] graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"] nequip = ["nequip>=0.12.0"] fairchem = ["fairchem-core>=2.7"] +chgnet = ["chgnet>=0.4.2"] docs = [ "autodoc_pydantic==2.2.0", "furo==2024.8.6", diff --git a/tests/models/test_chgnet.py b/tests/models/test_chgnet.py new file mode 100644 index 00000000..d4214671 --- /dev/null +++ b/tests/models/test_chgnet.py @@ -0,0 +1,60 @@ +import traceback + +import pytest +import torch + +from tests.conftest import DEVICE +from tests.models.conftest import ( + make_model_calculator_consistency_test, + make_validate_model_outputs_test, +) + + +try: + from chgnet.model.dynamics import CHGNetCalculator as CHGNetAseCalculator + + from torch_sim.models.chgnet import CHGNetModel +except (ImportError, ValueError): + pytest.skip( + f"CHGNet not installed: {traceback.format_exc()}", allow_module_level=True + ) + + +DTYPE = torch.float32 + + +@pytest.fixture +def ts_chgnet_model() -> CHGNetModel: + """Create a TorchSim CHGNet model for testing.""" + return CHGNetModel( + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def ase_chgnet_calculator() -> CHGNetAseCalculator: + """Create an ASE CHGNet calculator for testing.""" + # Use the official CHGNet calculator + return CHGNetAseCalculator() + + +test_chgnet_model_outputs = make_validate_model_outputs_test( + model_fixture_name="ts_chgnet_model", dtype=DTYPE +) + +test_chgnet_consistency = make_model_calculator_consistency_test( + test_name="chgnet", + model_fixture_name="ts_chgnet_model", + calculator_fixture_name="ase_chgnet_calculator", + sim_state_names=("si_sim_state", "cu_sim_state", "mg_sim_state", "ti_sim_state"), + dtype=DTYPE, + energy_rtol=1e-4, + energy_atol=1e-4, + force_rtol=1e-4, + force_atol=1e-4, + stress_rtol=1e-3, + stress_atol=1e-3, +) diff --git a/torch_sim/models/chgnet.py b/torch_sim/models/chgnet.py new file mode 100644 index 00000000..ecace2a3 --- /dev/null +++ b/torch_sim/models/chgnet.py @@ -0,0 +1,204 @@ +"""Wrapper for CHGNet model in TorchSim. + +This module provides a TorchSim wrapper of the CHGNet model for computing +energies, forces, and stresses for atomistic systems. It integrates the CHGNet model +with TorchSim's simulation framework, handling batched computations for multiple +systems simultaneously. + +The implementation supports various features including: + +* Computing energies, forces, and stresses +* Handling periodic boundary conditions (PBC) +* Native batching support for multiple systems +* Magnetic moment calculations + +Notes: + This module depends on the CHGNet package and implements the ModelInterface + for compatibility with the broader TorchSim framework. +""" + +import traceback +import warnings +from typing import Any + +import torch + +import torch_sim as ts +from torch_sim.models.interface import ModelInterface +from torch_sim.typing import StateDict + + +try: + from chgnet.model.model import CHGNet +except (ImportError, ModuleNotFoundError) as exc: + warnings.warn(f"CHGNet import failed: {traceback.format_exc()}", stacklevel=2) + + class CHGNetModel(ModelInterface): + """CHGNet model wrapper for torch-sim. + + This class is a placeholder for the CHGNetModel class. + It raises an ImportError if CHGNet is not installed. + """ + + def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: + """Dummy init for type checking.""" + raise err + + +class CHGNetModel(ModelInterface): + """Computes energies for multiple systems using a CHGNet model. + + This class wraps a CHGNet model to compute energies, forces, and stresses for + atomic systems within the TorchSim framework. It supports batched calculations + for multiple systems and handles the necessary transformations between + TorchSim's data structures and CHGNet's expected inputs. + + Attributes: + model (CHGNet): The underlying CHGNet neural network model. + _memory_scales_with (str): Memory scaling metric, set to "n_atoms_x_density". + """ + + def __init__( + self, + model: CHGNet | None = None, + *, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + compute_forces: bool = True, + compute_stress: bool = True, + ) -> None: + """Initialize the CHGNet model for energy and force calculations. + + Sets up the CHGNet model for energy, force, and stress calculations within + the TorchSim framework. The model can be initialized with a pre-loaded CHGNet + instance or will load the default pre-trained model. + + Args: + model (CHGNet | None): The CHGNet neural network model instance. + If None, loads the default pre-trained model. + device (torch.device | None): The device to run computations on. + Defaults to CUDA if available, otherwise CPU. + dtype (torch.dtype): The data type for tensor operations. + Defaults to torch.float64. + compute_forces (bool): Whether to compute forces. Defaults to True. + compute_stress (bool): Whether to compute stress. Defaults to True. + + Raises: + ImportError: If CHGNet is not installed. + """ + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + self._memory_scales_with = "n_atoms_x_density" + + # Load model + if model is None: + self.model = CHGNet.load() + else: + self.model = model + + # Move model to device + self.model = self.model.to(self._device) + if hasattr(self.model, "eval"): + self.model = self.model.eval() + + def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: + """Compute energies, forces, and stresses for the given atomic systems. + + Processes the provided state information and computes energies, forces, and + stresses using the underlying CHGNet model. Handles batched calculations for + multiple systems and constructs the necessary data structures. + + Args: + state (SimState | StateDict): State object containing positions, cell, + and other system information. Can be either a SimState object or a + dictionary with the relevant fields. + + Returns: + dict[str, torch.Tensor]: Computed properties: + - 'energy': System energies with shape [n_systems] + - 'forces': Atomic forces with shape [n_atoms, 3] if compute_forces=True + - 'stress': System stresses with shape [n_systems, 3, 3] if + compute_stress=True + - 'magnetic_moments': Magnetic moments with shape [n_atoms, 3] if + available in CHGNet output + + """ + # Handle state dict + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) + + # Convert SimState to list of pymatgen Structures + structures = sim_state.to_structures() + + # Use CHGNet's batching support + chgnet_results = self.model.predict_structure(structures) + + # Handle both single and multiple structures + if len(structures) == 1: + # Single structure returns a single dict + chgnet_results = [chgnet_results] + + # Convert results to TorchSim format + results: dict[str, torch.Tensor] = {} + + # Process energy (CHGNet returns energy per atom, multiply by number of atoms) + energies = [] + for i, result in enumerate(chgnet_results): + chgnet_energy_per_atom = ( + result["e"].item() if hasattr(result["e"], "item") else result["e"] + ) + + # Get number of atoms in this structure + structure = structures[i] + total_atoms = len(structure) + + # Multiply by number of atoms to get total energy + total_energy = chgnet_energy_per_atom * total_atoms + + energies.append( + torch.tensor(total_energy, device=self.device, dtype=self.dtype) + ) + + results["energy"] = torch.stack(energies) + + # Process forces + if self.compute_forces: + forces_list = [ + torch.tensor(result["f"], device=self.device, dtype=self.dtype) + for result in chgnet_results + ] + forces = torch.cat(forces_list, dim=0) + results["forces"] = forces + + # Process stress + if self.compute_stress: + stresses = torch.stack( + [ + torch.tensor(result["s"], device=self.device, dtype=self.dtype) + for result in chgnet_results + ] + ) + # Convert from GPa to eV/A^3 to match ASE calculator convention + # stress_weight converts GPa to eV/A^3 (approximately 1/160.21) + stress_weight = 0.006241509125883258 + results["stress"] = stresses * stress_weight + + # Process magnetic moments (if available) + if "m" in chgnet_results[0]: + magnetic_moments_list = [ + torch.tensor(result["m"], device=self.device, dtype=self.dtype) + for result in chgnet_results + ] + # Concatenate along atom dimension, similar to forces + magnetic_moments = torch.cat(magnetic_moments_list, dim=0) + results["magnetic_moments"] = magnetic_moments + + return results