From 3ee36e676f3914a8a33b470255f10bd95be415b8 Mon Sep 17 00:00:00 2001 From: Bruno Cucco Date: Wed, 15 Oct 2025 16:36:06 -0500 Subject: [PATCH 1/5] Add CHGNet model support: - Integrates CHGNet (Crystal Hamiltonian Graph Neural Network), a neural network potential developed by Ceder's Group. Similar to MACE, CHGNet is a message-passing network trained on the Materials Project database, however, it uses a different architecture. Some of these differences include: (1) CHGNet includes magnetic moment calculations, (2) uses charge-informed graph features, (3) trained on a different set of Materials Project data. - CHGNetModel wrapper inherits from ModelInterface - Add comprehensive test suite - Create example script demonstrating CHGNet usage with MACE comparison - Add chgnet>=0.4.2 as optional dependency in pyproject.toml - Support batched calculations, PBC, and TorchSim workflows - Include proper error handling for missing dependencies --- examples/scripts/1_Introduction/1.4_CHGNet.py | 140 +++++++++++ pyproject.toml | 1 + tests/models/test_chgnet.py | 222 ++++++++++++++++++ torch_sim/models/chgnet.py | 193 +++++++++++++++ 4 files changed, 556 insertions(+) create mode 100644 examples/scripts/1_Introduction/1.4_CHGNet.py create mode 100644 tests/models/test_chgnet.py create mode 100644 torch_sim/models/chgnet.py diff --git a/examples/scripts/1_Introduction/1.4_CHGNet.py b/examples/scripts/1_Introduction/1.4_CHGNet.py new file mode 100644 index 00000000..8f4481db --- /dev/null +++ b/examples/scripts/1_Introduction/1.4_CHGNet.py @@ -0,0 +1,140 @@ +"""CHGNet model example for TorchSim.""" + +# /// script +# dependencies = ["chgnet>=0.4.2"] +# /// + +import warnings +import os +import torch +import torch_sim as ts +from ase.build import bulk +from ase import Atoms +from torch_sim.models.chgnet import CHGNetModel + +# Silence warnings +warnings.filterwarnings("ignore") +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" + +# Set device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +dtype = torch.float32 + +print("CHGNet Example for TorchSim") +print("=" * 40) + +# Create CHGNet model +model = CHGNetModel( + device=device, + dtype=dtype, + compute_forces=True, + compute_stress=True, +) + +# Create test systems +al_atoms = bulk("Al", "fcc", a=4.05, cubic=True) +c_atoms = bulk("C", "diamond", a=3.57, cubic=True) +mg_atoms = bulk("Mg", "hcp", a=3.21, c=5.21) +a_perovskite = 3.84 +ca_tio3_atoms = Atoms(['Ca', 'Ti', 'O', 'O', 'O'], + positions=[[0, 0, 0], [a_perovskite/2, a_perovskite/2, a_perovskite/2], + [a_perovskite/2, 0, 0], [0, a_perovskite/2, 0], [0, 0, a_perovskite/2]], + cell=[a_perovskite, a_perovskite, a_perovskite], pbc=True) + +# Convert to TorchSim state +state = ts.io.atoms_to_state([al_atoms, c_atoms, mg_atoms], device, dtype) + +# Load MACE model for comparison +try: + from torch_sim.models.mace import MaceModel, MaceUrls + from mace.calculators.foundations_models import mace_mp + + raw_mace_mp = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) + mace_model = MaceModel( + model=raw_mace_mp, + device=device, + dtype=dtype, + compute_forces=True, + compute_stress=True, + ) + mace_available = True +except ImportError: + mace_available = False + print("MACE not available for comparison") + +# In this table we compare CHGNet and MACE on the equilibrium structures +print("\nTABLE 1: Equilibrium Structures") +print("=" * 80) +print(f"{'System':<10} {'CHGNet E (eV)':<15} {'CHGNet F (eV/Å)':<15} {'MACE E (eV)':<15} {'MACE F (eV/Å)':<15}") +print("-" * 80) + +for i, system_name in enumerate(["Al", "C", "Mg"]): + # Get states + single_state = ts.io.atoms_to_state([[al_atoms, c_atoms, mg_atoms][i]], device, dtype) + + # CHGNet results + chgnet_result = model.forward(single_state) + chgnet_energy = chgnet_result['energy'].item() + chgnet_force = torch.norm(chgnet_result['forces'], dim=1).max().item() + + # MACE results + if mace_available: + mace_result = mace_model.forward(single_state) + mace_energy = mace_result['energy'].item() + mace_force = torch.norm(mace_result['forces'], dim=1).max().item() + print(f"{system_name:<10} {chgnet_energy:<15.6f} {chgnet_force:<15.6f} {mace_energy:<15.6f} {mace_force:<15.6f}") + else: + print(f"{system_name:<10} {chgnet_energy:<15.6f} {chgnet_force:<15.6f} {'N/A':<15} {'N/A':<15}") + +# In this table we compare CHGNet and MACE on the displaced and optimized structures +print("\nTABLE 2: Displaced and Optimized Structures") +print("=" * 100) +print(f"{'System':<10} {'CHGNet Init E':<15} {'CHGNet Fin E':<15} {'CHGNet Fin F':<15} {'MACE Init E':<15} {'MACE Fin E':<15} {'MACE Fin F':<15}") +print("-" * 120) + +for i, (atoms, system_name) in enumerate(zip([al_atoms, c_atoms, ca_tio3_atoms], ["Al", "C", "CaTiO3"])): + # Create displaced state + single_state = ts.io.atoms_to_state([atoms], device, dtype) + displacement = torch.randn_like(single_state.positions) * 0.1 + displaced_state = single_state.clone() + displaced_state.positions = single_state.positions + displacement + + # CHGNet optimization + chgnet_initial = model.forward(displaced_state) + chgnet_initial_energy = chgnet_initial['energy'].item() + + chgnet_optimized = ts.optimize( + displaced_state, + model, + optimizer=ts.optimizers.Optimizer.fire, + max_steps=100, + ) + + chgnet_final = model.forward(chgnet_optimized) + chgnet_final_energy = chgnet_final['energy'].item() + chgnet_final_force = torch.norm(chgnet_final['forces'], dim=1).max().item() + + # MACE optimization + if mace_available: + mace_initial = mace_model.forward(displaced_state) + mace_initial_energy = mace_initial['energy'].item() + + mace_optimized = ts.optimize( + displaced_state, + mace_model, + optimizer=ts.optimizers.Optimizer.fire, + max_steps=100, + ) + + mace_final = mace_model.forward(mace_optimized) + mace_final_energy = mace_final['energy'].item() + mace_final_force = torch.norm(mace_final['forces'], dim=1).max().item() + + print(f"{system_name:<10} {chgnet_initial_energy:<15.6f} {chgnet_final_energy:<15.6f} {chgnet_final_force:<15.6f} {mace_initial_energy:<15.6f} {mace_final_energy:<15.6f} {mace_final_force:<15.6f}") + else: + print(f"{system_name:<10} {chgnet_initial_energy:<15.6f} {chgnet_final_energy:<15.6f} {chgnet_final_force:<15.6f} {'N/A':<15} {'N/A':<15} {'N/A':<15}") + +print("\n" + "="*100) +print("CHGNet example completed successfully!") +print("="*100) + diff --git a/pyproject.toml b/pyproject.toml index 1fa82d12..b46705a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ sevenn = ["sevenn>=0.11.0"] graphpes = ["graph-pes>=0.0.34,<=0.2.0", "mace-torch>=0.3.12"] nequip = ["nequip>=0.12.0"] fairchem = ["fairchem-core>=2.7"] +chgnet = ["chgnet>=0.4.2"] docs = [ "autodoc_pydantic==2.2.0", "furo==2024.8.6", diff --git a/tests/models/test_chgnet.py b/tests/models/test_chgnet.py new file mode 100644 index 00000000..394a4972 --- /dev/null +++ b/tests/models/test_chgnet.py @@ -0,0 +1,222 @@ +import traceback + +import pytest +import torch +from ase.atoms import Atoms + +import torch_sim as ts +from tests.conftest import DEVICE +from tests.models.conftest import ( + make_validate_model_outputs_test, +) + + +try: + from chgnet.model.model import CHGNet + from torch_sim.models.chgnet import CHGNetModel +except (ImportError, ValueError): + pytest.skip(f"CHGNet not installed: {traceback.format_exc()}", allow_module_level=True) + + +DTYPE = torch.float32 + + +@pytest.fixture +def ts_chgnet_model() -> CHGNetModel: + """Create a TorchSim CHGNet model for testing.""" + return CHGNetModel( + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_chgnet_dtype_working(si_atoms: Atoms, dtype: torch.dtype) -> None: + """Test that CHGNet works with different dtypes.""" + model = CHGNetModel( + device=DEVICE, + dtype=dtype, + compute_forces=True, + ) + + state = ts.io.atoms_to_state([si_atoms], DEVICE, dtype) + result = model.forward(state) + + # Check that results have correct dtype + assert result["energy"].dtype == dtype + assert result["forces"].dtype == dtype + assert result["stress"].dtype == dtype + + +def test_chgnet_batched_calculations() -> None: + """Test that CHGNet handles batched calculations correctly.""" + from ase.build import bulk + + # Create multiple systems + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) + cu_atoms = bulk("Cu", "fcc", a=3.6, cubic=True) + + model = CHGNetModel( + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + # Test batched calculation + state = ts.io.atoms_to_state([si_atoms, cu_atoms], DEVICE, DTYPE) + result = model.forward(state) + + # Check output shapes + assert result["energy"].shape == (2,) + assert result["forces"].shape == (si_atoms.get_global_number_of_atoms() + cu_atoms.get_global_number_of_atoms(), 3) + assert result["stress"].shape == (2, 3, 3) + + # Check that energies are different (different materials) + assert not torch.allclose(result["energy"][0], result["energy"][1], atol=1e-3) + + +def test_chgnet_single_vs_batched_consistency() -> None: + """Test that single and batched calculations give consistent results.""" + from ase.build import bulk + + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) + + model = CHGNetModel( + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + # Single system calculation + single_state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) + single_result = model.forward(single_state) + + # Batched calculation (same system twice) + batched_state = ts.io.atoms_to_state([si_atoms, si_atoms], DEVICE, DTYPE) + batched_result = model.forward(batched_state) + + # Check consistency + assert torch.allclose(single_result["energy"][0], batched_result["energy"][0], atol=1e-5) + assert torch.allclose(single_result["energy"][0], batched_result["energy"][1], atol=1e-5) + assert torch.allclose(single_result["forces"], batched_result["forces"][:single_state.n_atoms], atol=1e-5) + assert torch.allclose(single_result["forces"], batched_result["forces"][single_state.n_atoms:], atol=1e-5) + assert torch.allclose(single_result["stress"][0], batched_result["stress"][0], atol=1e-5) + assert torch.allclose(single_result["stress"][0], batched_result["stress"][1], atol=1e-5) + + +def test_chgnet_missing_atomic_numbers() -> None: + """Test that CHGNet raises appropriate error when atomic numbers are missing.""" + model = CHGNetModel( + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + ) + + # Create state without atomic numbers by using a state dict + state_dict = { + "positions": torch.randn(8, 3, device=DEVICE, dtype=DTYPE), + "cell": torch.eye(3, device=DEVICE, dtype=DTYPE).unsqueeze(0), + "pbc": True, + "atomic_numbers": None, # Missing atomic numbers + "system_idx": torch.zeros(8, dtype=torch.long, device=DEVICE), + } + + with pytest.raises(ValueError, match="Atomic numbers must be provided"): + model.forward(state_dict) + + +def test_chgnet_custom_model() -> None: + """Test that CHGNet can be initialized with a custom model.""" + # Load a custom model instance + custom_model = CHGNet.load() + + model = CHGNetModel( + model=custom_model, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + ) + + from ase.build import bulk + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) + state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) + + result = model.forward(state) + assert "energy" in result + assert "forces" in result + assert "stress" in result + + +def test_chgnet_compute_forces_false() -> None: + """Test CHGNet with compute_forces=False.""" + model = CHGNetModel( + device=DEVICE, + dtype=DTYPE, + compute_forces=False, + compute_stress=True, + ) + + from ase.build import bulk + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) + state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) + + result = model.forward(state) + assert "energy" in result + assert "forces" not in result + assert "stress" in result + + +def test_chgnet_compute_stress_false() -> None: + """Test CHGNet with compute_stress=False.""" + model = CHGNetModel( + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=False, + ) + + from ase.build import bulk + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) + state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) + + result = model.forward(state) + assert "energy" in result + assert "forces" in result + assert "stress" not in result + + +def test_chgnet_compute_both_false() -> None: + """Test CHGNet with both compute_forces=False and compute_stress=False.""" + model = CHGNetModel( + device=DEVICE, + dtype=DTYPE, + compute_forces=False, + compute_stress=False, + ) + + from ase.build import bulk + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) + state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) + + result = model.forward(state) + assert "energy" in result + assert "forces" not in result + assert "stress" not in result + + +test_chgnet_model_outputs = make_validate_model_outputs_test( + model_fixture_name="ts_chgnet_model", dtype=DTYPE +) + + +def test_chgnet_import_error() -> None: + """Test that CHGNetModel raises ImportError when CHGNet is not available.""" + from torch_sim.models.chgnet import CHGNetModel + + # Should not raise an error when CHGNet is available + model = CHGNetModel(device=DEVICE, dtype=DTYPE) + assert isinstance(model, CHGNetModel) diff --git a/torch_sim/models/chgnet.py b/torch_sim/models/chgnet.py new file mode 100644 index 00000000..fb688353 --- /dev/null +++ b/torch_sim/models/chgnet.py @@ -0,0 +1,193 @@ +"""Wrapper for CHGNet model in TorchSim. + +This module provides a TorchSim wrapper of the CHGNet model for computing +energies, forces, and stresses for atomistic systems. It integrates the CHGNet model +with TorchSim's simulation framework, handling batched computations for multiple +systems simultaneously. + +The implementation supports various features including: + +* Computing energies, forces, and stresses +* Handling periodic boundary conditions (PBC) +* Native batching support for multiple systems +* Magnetic moment calculations + +Notes: + This module depends on the CHGNet package and implements the ModelInterface + for compatibility with the broader TorchSim framework. +""" + +import traceback +import warnings +from typing import Any + +import torch + +import torch_sim as ts +from torch_sim.models.interface import ModelInterface +from torch_sim.typing import StateDict + + +try: + from chgnet.model.model import CHGNet +except (ImportError, ModuleNotFoundError) as exc: + warnings.warn(f"CHGNet import failed: {traceback.format_exc()}", stacklevel=2) + + class CHGNetModel(ModelInterface): + """CHGNet model wrapper for torch-sim. + + This class is a placeholder for the CHGNetModel class. + It raises an ImportError if CHGNet is not installed. + """ + + def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: + """Dummy init for type checking.""" + raise err + + +class CHGNetModel(ModelInterface): + """Computes energies for multiple systems using a CHGNet model. + + This class wraps a CHGNet model to compute energies, forces, and stresses for + atomic systems within the TorchSim framework. It supports batched calculations + for multiple systems and handles the necessary transformations between + TorchSim's data structures and CHGNet's expected inputs. + + Attributes: + model (CHGNet): The underlying CHGNet neural network model. + _memory_scales_with (str): Memory scaling metric, set to "n_atoms_x_density". + """ + + def __init__( + self, + model: CHGNet | None = None, + *, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + compute_forces: bool = True, + compute_stress: bool = True, + ) -> None: + """Initialize the CHGNet model for energy and force calculations. + + Sets up the CHGNet model for energy, force, and stress calculations within + the TorchSim framework. The model can be initialized with a pre-loaded CHGNet + instance or will load the default pre-trained model. + + Args: + model (CHGNet | None): The CHGNet neural network model instance. + If None, loads the default pre-trained model. + device (torch.device | None): The device to run computations on. + Defaults to CUDA if available, otherwise CPU. + dtype (torch.dtype): The data type for tensor operations. + Defaults to torch.float64. + compute_forces (bool): Whether to compute forces. Defaults to True. + compute_stress (bool): Whether to compute stress. Defaults to True. + + Raises: + ImportError: If CHGNet is not installed. + """ + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + self._memory_scales_with = "n_atoms_x_density" + + # Load model + if model is None: + self.model = CHGNet.load() + else: + self.model = model + + # Move model to device + self.model = self.model.to(self._device) + if hasattr(self.model, 'eval'): + self.model = self.model.eval() + + def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: + """Compute energies, forces, and stresses for the given atomic systems. + + Processes the provided state information and computes energies, forces, and + stresses using the underlying CHGNet model. Handles batched calculations for + multiple systems and constructs the necessary data structures. + + Args: + state (SimState | StateDict): State object containing positions, cell, + and other system information. Can be either a SimState object or a + dictionary with the relevant fields. + + Returns: + dict[str, torch.Tensor]: Computed properties: + - 'energy': System energies with shape [n_systems] + - 'forces': Atomic forces with shape [n_atoms, 3] if compute_forces=True + - 'stress': System stresses with shape [n_systems, 3, 3] if + compute_stress=True + + Raises: + ValueError: If atomic numbers are not provided in the state. + """ + # Handle state dict + if isinstance(state, dict) and state.get("atomic_numbers") is None: + raise ValueError("Atomic numbers must be provided in the state for CHGNet.") + + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) + + # Validate that atomic numbers + if sim_state.atomic_numbers is None: + raise ValueError("Atomic numbers must be provided in the state for CHGNet.") + + # Convert SimState to list of pymatgen Structures + structures = sim_state.to_structures() + + # Use CHGNet's batching support + chgnet_results = self.model.predict_structure(structures) + + # Handle both single and multiple structures + if len(structures) == 1: + # Single structure returns a single dict + chgnet_results = [chgnet_results] + + # Convert results to TorchSim format + results: dict[str, torch.Tensor] = {} + + # Process energy (CHGNet returns energy per atom, multiply by number of atoms) + energies = [] + for i, result in enumerate(chgnet_results): + chgnet_energy_per_atom = result['e'].item() if hasattr(result['e'], 'item') else result['e'] + + # Get number of atoms in this structure + structure = structures[i] + total_atoms = len(structure) + + # Multiply by number of atoms to get total energy + total_energy = chgnet_energy_per_atom * total_atoms + + energies.append(torch.tensor(total_energy, device=self.device, dtype=self.dtype)) + + results["energy"] = torch.stack(energies) + + # Process forces + if self.compute_forces: + forces_list = [] + for result in chgnet_results: + forces_list.append( + torch.tensor(result['f'], device=self.device, dtype=self.dtype) + ) + forces = torch.cat(forces_list, dim=0) + results["forces"] = forces + + # Process stress + if self.compute_stress: + stresses = torch.stack([ + torch.tensor(result['s'], device=self.device, dtype=self.dtype) + for result in chgnet_results + ]) + results["stress"] = stresses + + return results From 05691fdf5d3598549ecd7b08b03bed7ba0398783 Mon Sep 17 00:00:00 2001 From: Bruno Cucco Date: Wed, 15 Oct 2025 17:04:30 -0500 Subject: [PATCH 2/5] Apply ruff formatting to CHGNet files --- examples/scripts/1_Introduction/1.4_CHGNet.py | 113 ++++++++++++------ tests/models/test_chgnet.py | 92 ++++++++------ torch_sim/models/chgnet.py | 41 ++++--- 3 files changed, 155 insertions(+), 91 deletions(-) diff --git a/examples/scripts/1_Introduction/1.4_CHGNet.py b/examples/scripts/1_Introduction/1.4_CHGNet.py index 8f4481db..95f44fa1 100644 --- a/examples/scripts/1_Introduction/1.4_CHGNet.py +++ b/examples/scripts/1_Introduction/1.4_CHGNet.py @@ -4,14 +4,17 @@ # dependencies = ["chgnet>=0.4.2"] # /// -import warnings import os +import warnings + import torch -import torch_sim as ts -from ase.build import bulk from ase import Atoms +from ase.build import bulk + +import torch_sim as ts from torch_sim.models.chgnet import CHGNetModel + # Silence warnings warnings.filterwarnings("ignore") os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" @@ -35,20 +38,29 @@ al_atoms = bulk("Al", "fcc", a=4.05, cubic=True) c_atoms = bulk("C", "diamond", a=3.57, cubic=True) mg_atoms = bulk("Mg", "hcp", a=3.21, c=5.21) -a_perovskite = 3.84 -ca_tio3_atoms = Atoms(['Ca', 'Ti', 'O', 'O', 'O'], - positions=[[0, 0, 0], [a_perovskite/2, a_perovskite/2, a_perovskite/2], - [a_perovskite/2, 0, 0], [0, a_perovskite/2, 0], [0, 0, a_perovskite/2]], - cell=[a_perovskite, a_perovskite, a_perovskite], pbc=True) +a_perovskite = 3.84 +ca_tio3_atoms = Atoms( + ["Ca", "Ti", "O", "O", "O"], + positions=[ + [0, 0, 0], + [a_perovskite / 2, a_perovskite / 2, a_perovskite / 2], + [a_perovskite / 2, 0, 0], + [0, a_perovskite / 2, 0], + [0, 0, a_perovskite / 2], + ], + cell=[a_perovskite, a_perovskite, a_perovskite], + pbc=True, +) # Convert to TorchSim state state = ts.io.atoms_to_state([al_atoms, c_atoms, mg_atoms], device, dtype) # Load MACE model for comparison try: - from torch_sim.models.mace import MaceModel, MaceUrls from mace.calculators.foundations_models import mace_mp - + + from torch_sim.models.mace import MaceModel, MaceUrls + raw_mace_mp = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) mace_model = MaceModel( model=raw_mace_mp, @@ -65,76 +77,99 @@ # In this table we compare CHGNet and MACE on the equilibrium structures print("\nTABLE 1: Equilibrium Structures") print("=" * 80) -print(f"{'System':<10} {'CHGNet E (eV)':<15} {'CHGNet F (eV/Å)':<15} {'MACE E (eV)':<15} {'MACE F (eV/Å)':<15}") +print( + f"{'System':<10} {'CHGNet E (eV)':<15} {'CHGNet F (eV/Å)':<15} " + f"{'MACE E (eV)':<15} {'MACE F (eV/Å)':<15}" +) print("-" * 80) for i, system_name in enumerate(["Al", "C", "Mg"]): # Get states single_state = ts.io.atoms_to_state([[al_atoms, c_atoms, mg_atoms][i]], device, dtype) - + # CHGNet results chgnet_result = model.forward(single_state) - chgnet_energy = chgnet_result['energy'].item() - chgnet_force = torch.norm(chgnet_result['forces'], dim=1).max().item() - + chgnet_energy = chgnet_result["energy"].item() + chgnet_force = torch.norm(chgnet_result["forces"], dim=1).max().item() + # MACE results if mace_available: mace_result = mace_model.forward(single_state) - mace_energy = mace_result['energy'].item() - mace_force = torch.norm(mace_result['forces'], dim=1).max().item() - print(f"{system_name:<10} {chgnet_energy:<15.6f} {chgnet_force:<15.6f} {mace_energy:<15.6f} {mace_force:<15.6f}") + mace_energy = mace_result["energy"].item() + mace_force = torch.norm(mace_result["forces"], dim=1).max().item() + print( + f"{system_name:<10} {chgnet_energy:<15.6f} {chgnet_force:<15.6f} " + f"{mace_energy:<15.6f} {mace_force:<15.6f}" + ) else: - print(f"{system_name:<10} {chgnet_energy:<15.6f} {chgnet_force:<15.6f} {'N/A':<15} {'N/A':<15}") + print( + f"{system_name:<10} {chgnet_energy:<15.6f} {chgnet_force:<15.6f} " + f"{'N/A':<15} {'N/A':<15}" + ) # In this table we compare CHGNet and MACE on the displaced and optimized structures print("\nTABLE 2: Displaced and Optimized Structures") print("=" * 100) -print(f"{'System':<10} {'CHGNet Init E':<15} {'CHGNet Fin E':<15} {'CHGNet Fin F':<15} {'MACE Init E':<15} {'MACE Fin E':<15} {'MACE Fin F':<15}") +print( + f"{'System':<10} {'CHGNet Init E':<15} {'CHGNet Fin E':<15} " + f"{'CHGNet Fin F':<15} {'MACE Init E':<15} {'MACE Fin E':<15} " + f"{'MACE Fin F':<15}" +) print("-" * 120) -for i, (atoms, system_name) in enumerate(zip([al_atoms, c_atoms, ca_tio3_atoms], ["Al", "C", "CaTiO3"])): +for atoms, system_name in zip( + [al_atoms, c_atoms, ca_tio3_atoms], ["Al", "C", "CaTiO3"], strict=False +): # Create displaced state single_state = ts.io.atoms_to_state([atoms], device, dtype) displacement = torch.randn_like(single_state.positions) * 0.1 displaced_state = single_state.clone() displaced_state.positions = single_state.positions + displacement - + # CHGNet optimization chgnet_initial = model.forward(displaced_state) - chgnet_initial_energy = chgnet_initial['energy'].item() - + chgnet_initial_energy = chgnet_initial["energy"].item() + chgnet_optimized = ts.optimize( displaced_state, model, optimizer=ts.optimizers.Optimizer.fire, max_steps=100, ) - + chgnet_final = model.forward(chgnet_optimized) - chgnet_final_energy = chgnet_final['energy'].item() - chgnet_final_force = torch.norm(chgnet_final['forces'], dim=1).max().item() - + chgnet_final_energy = chgnet_final["energy"].item() + chgnet_final_force = torch.norm(chgnet_final["forces"], dim=1).max().item() + # MACE optimization if mace_available: mace_initial = mace_model.forward(displaced_state) - mace_initial_energy = mace_initial['energy'].item() - + mace_initial_energy = mace_initial["energy"].item() + mace_optimized = ts.optimize( displaced_state, mace_model, optimizer=ts.optimizers.Optimizer.fire, max_steps=100, ) - + mace_final = mace_model.forward(mace_optimized) - mace_final_energy = mace_final['energy'].item() - mace_final_force = torch.norm(mace_final['forces'], dim=1).max().item() - - print(f"{system_name:<10} {chgnet_initial_energy:<15.6f} {chgnet_final_energy:<15.6f} {chgnet_final_force:<15.6f} {mace_initial_energy:<15.6f} {mace_final_energy:<15.6f} {mace_final_force:<15.6f}") + mace_final_energy = mace_final["energy"].item() + mace_final_force = torch.norm(mace_final["forces"], dim=1).max().item() + + print( + f"{system_name:<10} {chgnet_initial_energy:<15.6f} " + f"{chgnet_final_energy:<15.6f} {chgnet_final_force:<15.6f} " + f"{mace_initial_energy:<15.6f} {mace_final_energy:<15.6f} " + f"{mace_final_force:<15.6f}" + ) else: - print(f"{system_name:<10} {chgnet_initial_energy:<15.6f} {chgnet_final_energy:<15.6f} {chgnet_final_force:<15.6f} {'N/A':<15} {'N/A':<15} {'N/A':<15}") + print( + f"{system_name:<10} {chgnet_initial_energy:<15.6f} " + f"{chgnet_final_energy:<15.6f} {chgnet_final_force:<15.6f} " + f"{'N/A':<15} {'N/A':<15} {'N/A':<15}" + ) -print("\n" + "="*100) +print("\n" + "=" * 100) print("CHGNet example completed successfully!") -print("="*100) - +print("=" * 100) diff --git a/tests/models/test_chgnet.py b/tests/models/test_chgnet.py index 394a4972..2ffa23d0 100644 --- a/tests/models/test_chgnet.py +++ b/tests/models/test_chgnet.py @@ -6,16 +6,17 @@ import torch_sim as ts from tests.conftest import DEVICE -from tests.models.conftest import ( - make_validate_model_outputs_test, -) +from tests.models.conftest import make_validate_model_outputs_test try: from chgnet.model.model import CHGNet + from torch_sim.models.chgnet import CHGNetModel except (ImportError, ValueError): - pytest.skip(f"CHGNet not installed: {traceback.format_exc()}", allow_module_level=True) + pytest.skip( + f"CHGNet not installed: {traceback.format_exc()}", allow_module_level=True + ) DTYPE = torch.float32 @@ -43,7 +44,7 @@ def test_chgnet_dtype_working(si_atoms: Atoms, dtype: torch.dtype) -> None: state = ts.io.atoms_to_state([si_atoms], DEVICE, dtype) result = model.forward(state) - + # Check that results have correct dtype assert result["energy"].dtype == dtype assert result["forces"].dtype == dtype @@ -53,27 +54,30 @@ def test_chgnet_dtype_working(si_atoms: Atoms, dtype: torch.dtype) -> None: def test_chgnet_batched_calculations() -> None: """Test that CHGNet handles batched calculations correctly.""" from ase.build import bulk - + # Create multiple systems si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) cu_atoms = bulk("Cu", "fcc", a=3.6, cubic=True) - + model = CHGNetModel( device=DEVICE, dtype=DTYPE, compute_forces=True, compute_stress=True, ) - + # Test batched calculation state = ts.io.atoms_to_state([si_atoms, cu_atoms], DEVICE, DTYPE) result = model.forward(state) - + # Check output shapes assert result["energy"].shape == (2,) - assert result["forces"].shape == (si_atoms.get_global_number_of_atoms() + cu_atoms.get_global_number_of_atoms(), 3) + assert result["forces"].shape == ( + si_atoms.get_global_number_of_atoms() + cu_atoms.get_global_number_of_atoms(), + 3, + ) assert result["stress"].shape == (2, 3, 3) - + # Check that energies are different (different materials) assert not torch.allclose(result["energy"][0], result["energy"][1], atol=1e-3) @@ -81,31 +85,47 @@ def test_chgnet_batched_calculations() -> None: def test_chgnet_single_vs_batched_consistency() -> None: """Test that single and batched calculations give consistent results.""" from ase.build import bulk - + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - + model = CHGNetModel( device=DEVICE, dtype=DTYPE, compute_forces=True, compute_stress=True, ) - + # Single system calculation single_state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) single_result = model.forward(single_state) - + # Batched calculation (same system twice) batched_state = ts.io.atoms_to_state([si_atoms, si_atoms], DEVICE, DTYPE) batched_result = model.forward(batched_state) - + # Check consistency - assert torch.allclose(single_result["energy"][0], batched_result["energy"][0], atol=1e-5) - assert torch.allclose(single_result["energy"][0], batched_result["energy"][1], atol=1e-5) - assert torch.allclose(single_result["forces"], batched_result["forces"][:single_state.n_atoms], atol=1e-5) - assert torch.allclose(single_result["forces"], batched_result["forces"][single_state.n_atoms:], atol=1e-5) - assert torch.allclose(single_result["stress"][0], batched_result["stress"][0], atol=1e-5) - assert torch.allclose(single_result["stress"][0], batched_result["stress"][1], atol=1e-5) + assert torch.allclose( + single_result["energy"][0], batched_result["energy"][0], atol=1e-5 + ) + assert torch.allclose( + single_result["energy"][0], batched_result["energy"][1], atol=1e-5 + ) + assert torch.allclose( + single_result["forces"], + batched_result["forces"][: single_state.n_atoms], + atol=1e-5, + ) + assert torch.allclose( + single_result["forces"], + batched_result["forces"][single_state.n_atoms :], + atol=1e-5, + ) + assert torch.allclose( + single_result["stress"][0], batched_result["stress"][0], atol=1e-5 + ) + assert torch.allclose( + single_result["stress"][0], batched_result["stress"][1], atol=1e-5 + ) def test_chgnet_missing_atomic_numbers() -> None: @@ -115,7 +135,7 @@ def test_chgnet_missing_atomic_numbers() -> None: dtype=DTYPE, compute_forces=True, ) - + # Create state without atomic numbers by using a state dict state_dict = { "positions": torch.randn(8, 3, device=DEVICE, dtype=DTYPE), @@ -124,7 +144,7 @@ def test_chgnet_missing_atomic_numbers() -> None: "atomic_numbers": None, # Missing atomic numbers "system_idx": torch.zeros(8, dtype=torch.long, device=DEVICE), } - + with pytest.raises(ValueError, match="Atomic numbers must be provided"): model.forward(state_dict) @@ -133,18 +153,19 @@ def test_chgnet_custom_model() -> None: """Test that CHGNet can be initialized with a custom model.""" # Load a custom model instance custom_model = CHGNet.load() - + model = CHGNetModel( model=custom_model, device=DEVICE, dtype=DTYPE, compute_forces=True, ) - + from ase.build import bulk + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) - + result = model.forward(state) assert "energy" in result assert "forces" in result @@ -159,11 +180,12 @@ def test_chgnet_compute_forces_false() -> None: compute_forces=False, compute_stress=True, ) - + from ase.build import bulk + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) - + result = model.forward(state) assert "energy" in result assert "forces" not in result @@ -178,11 +200,12 @@ def test_chgnet_compute_stress_false() -> None: compute_forces=True, compute_stress=False, ) - + from ase.build import bulk + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) - + result = model.forward(state) assert "energy" in result assert "forces" in result @@ -197,11 +220,12 @@ def test_chgnet_compute_both_false() -> None: compute_forces=False, compute_stress=False, ) - + from ase.build import bulk + si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) - + result = model.forward(state) assert "energy" in result assert "forces" not in result @@ -216,7 +240,7 @@ def test_chgnet_compute_both_false() -> None: def test_chgnet_import_error() -> None: """Test that CHGNetModel raises ImportError when CHGNet is not available.""" from torch_sim.models.chgnet import CHGNetModel - + # Should not raise an error when CHGNet is available model = CHGNetModel(device=DEVICE, dtype=DTYPE) assert isinstance(model, CHGNetModel) diff --git a/torch_sim/models/chgnet.py b/torch_sim/models/chgnet.py index fb688353..2deb2a10 100644 --- a/torch_sim/models/chgnet.py +++ b/torch_sim/models/chgnet.py @@ -103,7 +103,7 @@ def __init__( # Move model to device self.model = self.model.to(self._device) - if hasattr(self.model, 'eval'): + if hasattr(self.model, "eval"): self.model = self.model.eval() def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: @@ -159,35 +159,40 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # Process energy (CHGNet returns energy per atom, multiply by number of atoms) energies = [] for i, result in enumerate(chgnet_results): - chgnet_energy_per_atom = result['e'].item() if hasattr(result['e'], 'item') else result['e'] - + chgnet_energy_per_atom = ( + result["e"].item() if hasattr(result["e"], "item") else result["e"] + ) + # Get number of atoms in this structure structure = structures[i] total_atoms = len(structure) - + # Multiply by number of atoms to get total energy total_energy = chgnet_energy_per_atom * total_atoms - - energies.append(torch.tensor(total_energy, device=self.device, dtype=self.dtype)) - + + energies.append( + torch.tensor(total_energy, device=self.device, dtype=self.dtype) + ) + results["energy"] = torch.stack(energies) - # Process forces + # Process forces if self.compute_forces: - forces_list = [] - for result in chgnet_results: - forces_list.append( - torch.tensor(result['f'], device=self.device, dtype=self.dtype) - ) + forces_list = [ + torch.tensor(result["f"], device=self.device, dtype=self.dtype) + for result in chgnet_results + ] forces = torch.cat(forces_list, dim=0) results["forces"] = forces - # Process stress + # Process stress if self.compute_stress: - stresses = torch.stack([ - torch.tensor(result['s'], device=self.device, dtype=self.dtype) - for result in chgnet_results - ]) + stresses = torch.stack( + [ + torch.tensor(result["s"], device=self.device, dtype=self.dtype) + for result in chgnet_results + ] + ) results["stress"] = stresses return results From b60a0c00a7cb9babe1b99bdbd17157e66bea6d2b Mon Sep 17 00:00:00 2001 From: Bruno Cucco Date: Wed, 15 Oct 2025 17:04:30 -0500 Subject: [PATCH 3/5] Apply ruff formatting to CHGNet files --- .github/workflows/test.yml | 1 + docs/conf.py | 1 + torch_sim/models/chgnet.py | 12 ++++++++++++ 3 files changed, 14 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cfbb9635..541527cb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -58,6 +58,7 @@ jobs: - { python: '3.12', resolution: highest } - { python: '3.13', resolution: lowest-direct } model: + - { name: chgnet, test_path: "tests/models/test_chgnet.py" } - { name: fairchem, test_path: "tests/models/test_fairchem.py" } - { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" } - { name: graphpes, test_path: "tests/models/test_graphpes.py" } diff --git a/docs/conf.py b/docs/conf.py index b42066ac..43b8fc5e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -62,6 +62,7 @@ ] autodoc_mock_imports = [ + "chgnet", "fairchem", "mace", "mattersim", diff --git a/torch_sim/models/chgnet.py b/torch_sim/models/chgnet.py index 2deb2a10..c31719a2 100644 --- a/torch_sim/models/chgnet.py +++ b/torch_sim/models/chgnet.py @@ -124,6 +124,8 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: - 'forces': Atomic forces with shape [n_atoms, 3] if compute_forces=True - 'stress': System stresses with shape [n_systems, 3, 3] if compute_stress=True + - 'magnetic_moments': Magnetic moments with shape [n_atoms, 3] if + available in CHGNet output Raises: ValueError: If atomic numbers are not provided in the state. @@ -195,4 +197,14 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: ) results["stress"] = stresses + # Process magnetic moments (if available) + if "m" in chgnet_results[0]: + magnetic_moments_list = [ + torch.tensor(result["m"], device=self.device, dtype=self.dtype) + for result in chgnet_results + ] + # Concatenate along atom dimension, similar to forces + magnetic_moments = torch.cat(magnetic_moments_list, dim=0) + results["magnetic_moments"] = magnetic_moments + return results From 95da3280850af892e6033fc4d475305f51801e56 Mon Sep 17 00:00:00 2001 From: Bruno Cucco Date: Thu, 23 Oct 2025 13:26:19 -0500 Subject: [PATCH 4/5] Addressed comments and suggestions given by the reviewers - Expanded CHGNet model wrapper to include stress and magnetic moments. - Add custom ASE calculator for consistency testing. - Integrate CHGNet into CI pipeline and documentation. - Enhanced example script with stress/magnetic moment output and reduced printing with a single clean table. - Rewrote tests to follow MatterSim, removing redundant tests. --- examples/scripts/1_Introduction/1.4_CHGNet.py | 150 +++++------ tests/models/test_chgnet.py | 239 +++++------------- 2 files changed, 121 insertions(+), 268 deletions(-) diff --git a/examples/scripts/1_Introduction/1.4_CHGNet.py b/examples/scripts/1_Introduction/1.4_CHGNet.py index 95f44fa1..a2d92e23 100644 --- a/examples/scripts/1_Introduction/1.4_CHGNet.py +++ b/examples/scripts/1_Introduction/1.4_CHGNet.py @@ -1,7 +1,7 @@ """CHGNet model example for TorchSim.""" # /// script -# dependencies = ["chgnet>=0.4.2"] +# dependencies = ["chgnet>=0.4.2", "mace-torch>=0.3.12"] # /// import os @@ -10,9 +10,11 @@ import torch from ase import Atoms from ase.build import bulk +from mace.calculators.foundations_models import mace_mp import torch_sim as ts from torch_sim.models.chgnet import CHGNetModel +from torch_sim.models.mace import MaceModel, MaceUrls # Silence warnings @@ -56,120 +58,92 @@ state = ts.io.atoms_to_state([al_atoms, c_atoms, mg_atoms], device, dtype) # Load MACE model for comparison -try: - from mace.calculators.foundations_models import mace_mp - - from torch_sim.models.mace import MaceModel, MaceUrls - - raw_mace_mp = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) - mace_model = MaceModel( - model=raw_mace_mp, - device=device, - dtype=dtype, - compute_forces=True, - compute_stress=True, - ) - mace_available = True -except ImportError: - mace_available = False - print("MACE not available for comparison") - -# In this table we compare CHGNet and MACE on the equilibrium structures -print("\nTABLE 1: Equilibrium Structures") -print("=" * 80) +raw_mace_mp = mace_mp(model=MaceUrls.mace_mp_small, return_raw_model=True) +mace_model = MaceModel( + model=raw_mace_mp, + device=device, + dtype=dtype, + compute_forces=True, + compute_stress=True, +) +mace_available = True + +# Single comprehensive results table print( - f"{'System':<10} {'CHGNet E (eV)':<15} {'CHGNet F (eV/Å)':<15} " - f"{'MACE E (eV)':<15} {'MACE F (eV/Å)':<15}" + "\nCHGNet vs MACE Results " + "(E: Total Energy, F: Maximum Force, S: Maximum Stress, M: Maximum Magnetic Moment)" ) -print("-" * 80) +print("=" * 87) +print( + f"{'System':<10} {'CHGNet E':<12} {'CHGNet F':<12} {'CHGNet S':<12} " + f"{'CHGNet M':<12} {'MACE E':<12} {'MACE F':<12}" +) +print("-" * 87) +# Test equilibrium structures for i, system_name in enumerate(["Al", "C", "Mg"]): - # Get states single_state = ts.io.atoms_to_state([[al_atoms, c_atoms, mg_atoms][i]], device, dtype) # CHGNet results chgnet_result = model.forward(single_state) chgnet_energy = chgnet_result["energy"].item() chgnet_force = torch.norm(chgnet_result["forces"], dim=1).max().item() + chgnet_stress = torch.norm(chgnet_result["stress"], dim=(1, 2)).max().item() + chgnet_magmom = ( + torch.norm(chgnet_result.get("magnetic_moments", torch.zeros(1, 3)), dim=-1) + .max() + .item() + ) # MACE results - if mace_available: - mace_result = mace_model.forward(single_state) - mace_energy = mace_result["energy"].item() - mace_force = torch.norm(mace_result["forces"], dim=1).max().item() - print( - f"{system_name:<10} {chgnet_energy:<15.6f} {chgnet_force:<15.6f} " - f"{mace_energy:<15.6f} {mace_force:<15.6f}" - ) - else: - print( - f"{system_name:<10} {chgnet_energy:<15.6f} {chgnet_force:<15.6f} " - f"{'N/A':<15} {'N/A':<15}" - ) - -# In this table we compare CHGNet and MACE on the displaced and optimized structures -print("\nTABLE 2: Displaced and Optimized Structures") -print("=" * 100) -print( - f"{'System':<10} {'CHGNet Init E':<15} {'CHGNet Fin E':<15} " - f"{'CHGNet Fin F':<15} {'MACE Init E':<15} {'MACE Fin E':<15} " - f"{'MACE Fin F':<15}" -) -print("-" * 120) + mace_result = mace_model.forward(single_state) + mace_energy = mace_result["energy"].item() + mace_force = torch.norm(mace_result["forces"], dim=1).max().item() + print( + f"{system_name:<10} {chgnet_energy:<12.3f} {chgnet_force:<12.3f} " + f"{chgnet_stress:<12.3f} {chgnet_magmom:<12.3f} {mace_energy:<12.3f} " + f"{mace_force:<12.3f}" + ) +# Test optimization on displaced structures for atoms, system_name in zip( [al_atoms, c_atoms, ca_tio3_atoms], ["Al", "C", "CaTiO3"], strict=False ): - # Create displaced state single_state = ts.io.atoms_to_state([atoms], device, dtype) displacement = torch.randn_like(single_state.positions) * 0.1 displaced_state = single_state.clone() displaced_state.positions = single_state.positions + displacement # CHGNet optimization - chgnet_initial = model.forward(displaced_state) - chgnet_initial_energy = chgnet_initial["energy"].item() - chgnet_optimized = ts.optimize( - displaced_state, - model, - optimizer=ts.optimizers.Optimizer.fire, - max_steps=100, + displaced_state, model, optimizer=ts.optimizers.Optimizer.fire, max_steps=100 ) - chgnet_final = model.forward(chgnet_optimized) chgnet_final_energy = chgnet_final["energy"].item() chgnet_final_force = torch.norm(chgnet_final["forces"], dim=1).max().item() + chgnet_final_stress = torch.norm(chgnet_final["stress"], dim=(1, 2)).max().item() + chgnet_final_magmom = ( + torch.norm(chgnet_final.get("magnetic_moments", torch.zeros(1, 3)), dim=-1) + .max() + .item() + ) # MACE optimization - if mace_available: - mace_initial = mace_model.forward(displaced_state) - mace_initial_energy = mace_initial["energy"].item() - - mace_optimized = ts.optimize( - displaced_state, - mace_model, - optimizer=ts.optimizers.Optimizer.fire, - max_steps=100, - ) - - mace_final = mace_model.forward(mace_optimized) - mace_final_energy = mace_final["energy"].item() - mace_final_force = torch.norm(mace_final["forces"], dim=1).max().item() - - print( - f"{system_name:<10} {chgnet_initial_energy:<15.6f} " - f"{chgnet_final_energy:<15.6f} {chgnet_final_force:<15.6f} " - f"{mace_initial_energy:<15.6f} {mace_final_energy:<15.6f} " - f"{mace_final_force:<15.6f}" - ) - else: - print( - f"{system_name:<10} {chgnet_initial_energy:<15.6f} " - f"{chgnet_final_energy:<15.6f} {chgnet_final_force:<15.6f} " - f"{'N/A':<15} {'N/A':<15} {'N/A':<15}" - ) - -print("\n" + "=" * 100) + mace_optimized = ts.optimize( + displaced_state, + mace_model, + optimizer=ts.optimizers.Optimizer.fire, + max_steps=100, + ) + mace_final = mace_model.forward(mace_optimized) + mace_final_energy = mace_final["energy"].item() + mace_final_force = torch.norm(mace_final["forces"], dim=1).max().item() + print( + f"{system_name + '_opt':<10} {chgnet_final_energy:<12.3f} " + f"{chgnet_final_force:<12.3f} {chgnet_final_stress:<12.3f} " + f"{chgnet_final_magmom:<12.3f} {mace_final_energy:<12.3f} " + f"{mace_final_force:<12.3f}" + ) + +print("=" * 87) print("CHGNet example completed successfully!") -print("=" * 100) diff --git a/tests/models/test_chgnet.py b/tests/models/test_chgnet.py index 2ffa23d0..d7dfe254 100644 --- a/tests/models/test_chgnet.py +++ b/tests/models/test_chgnet.py @@ -1,12 +1,16 @@ import traceback +from typing import Any, ClassVar import pytest import torch from ase.atoms import Atoms +from ase.calculators.calculator import Calculator, all_changes -import torch_sim as ts from tests.conftest import DEVICE -from tests.models.conftest import make_validate_model_outputs_test +from tests.models.conftest import ( + make_model_calculator_consistency_test, + make_validate_model_outputs_test, +) try: @@ -19,113 +23,65 @@ ) -DTYPE = torch.float32 - +class CHGNetCalculator(Calculator): + """ASE Calculator wrapper for CHGNet.""" -@pytest.fixture -def ts_chgnet_model() -> CHGNetModel: - """Create a TorchSim CHGNet model for testing.""" - return CHGNetModel( - device=DEVICE, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) + implemented_properties: ClassVar[list[str]] = ["energy", "forces", "stress"] + def __init__(self, model: CHGNet | None = None, **kwargs) -> None: + Calculator.__init__(self, **kwargs) + self.model = model or CHGNet.load() -@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) -def test_chgnet_dtype_working(si_atoms: Atoms, dtype: torch.dtype) -> None: - """Test that CHGNet works with different dtypes.""" - model = CHGNetModel( - device=DEVICE, - dtype=dtype, - compute_forces=True, - ) + def calculate( + self, + atoms: Atoms | None = None, + properties: list[str] | None = None, + system_changes: Any = all_changes, + ): + if properties is None: + properties = ["energy"] + Calculator.calculate(self, atoms, properties, system_changes) - state = ts.io.atoms_to_state([si_atoms], DEVICE, dtype) - result = model.forward(state) + # Convert ASE atoms to pymatgen Structure + from pymatgen.io.ase import AseAtomsAdaptor - # Check that results have correct dtype - assert result["energy"].dtype == dtype - assert result["forces"].dtype == dtype - assert result["stress"].dtype == dtype + structure = AseAtomsAdaptor.get_structure(atoms) + # Get CHGNet predictions + result = self.model.predict_structure(structure) -def test_chgnet_batched_calculations() -> None: - """Test that CHGNet handles batched calculations correctly.""" - from ase.build import bulk + # Convert to ASE format + self.results = {} + if "energy" in properties: + # CHGNet returns energy per atom, convert to total energy + self.results["energy"] = result["e"] * len(structure) - # Create multiple systems - si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - cu_atoms = bulk("Cu", "fcc", a=3.6, cubic=True) + if "forces" in properties: + self.results["forces"] = result["f"] - model = CHGNetModel( - device=DEVICE, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) + if "stress" in properties: + self.results["stress"] = result["s"] - # Test batched calculation - state = ts.io.atoms_to_state([si_atoms, cu_atoms], DEVICE, DTYPE) - result = model.forward(state) - # Check output shapes - assert result["energy"].shape == (2,) - assert result["forces"].shape == ( - si_atoms.get_global_number_of_atoms() + cu_atoms.get_global_number_of_atoms(), - 3, - ) - assert result["stress"].shape == (2, 3, 3) - - # Check that energies are different (different materials) - assert not torch.allclose(result["energy"][0], result["energy"][1], atol=1e-3) - - -def test_chgnet_single_vs_batched_consistency() -> None: - """Test that single and batched calculations give consistent results.""" - from ase.build import bulk +DTYPE = torch.float32 - si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - model = CHGNetModel( +@pytest.fixture +def ts_chgnet_model() -> CHGNetModel: + """Create a TorchSim CHGNet model for testing.""" + return CHGNetModel( device=DEVICE, dtype=DTYPE, compute_forces=True, compute_stress=True, ) - # Single system calculation - single_state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) - single_result = model.forward(single_state) - - # Batched calculation (same system twice) - batched_state = ts.io.atoms_to_state([si_atoms, si_atoms], DEVICE, DTYPE) - batched_result = model.forward(batched_state) - # Check consistency - assert torch.allclose( - single_result["energy"][0], batched_result["energy"][0], atol=1e-5 - ) - assert torch.allclose( - single_result["energy"][0], batched_result["energy"][1], atol=1e-5 - ) - assert torch.allclose( - single_result["forces"], - batched_result["forces"][: single_state.n_atoms], - atol=1e-5, - ) - assert torch.allclose( - single_result["forces"], - batched_result["forces"][single_state.n_atoms :], - atol=1e-5, - ) - assert torch.allclose( - single_result["stress"][0], batched_result["stress"][0], atol=1e-5 - ) - assert torch.allclose( - single_result["stress"][0], batched_result["stress"][1], atol=1e-5 - ) +@pytest.fixture +def ase_chgnet_calculator(ts_chgnet_model: CHGNetModel) -> CHGNetCalculator: + """Create an ASE CHGNet calculator for testing.""" + # Use the same model instance to ensure consistency + return CHGNetCalculator(model=ts_chgnet_model.model) def test_chgnet_missing_atomic_numbers() -> None: @@ -134,6 +90,7 @@ def test_chgnet_missing_atomic_numbers() -> None: device=DEVICE, dtype=DTYPE, compute_forces=True, + compute_stress=True, ) # Create state without atomic numbers by using a state dict @@ -149,98 +106,20 @@ def test_chgnet_missing_atomic_numbers() -> None: model.forward(state_dict) -def test_chgnet_custom_model() -> None: - """Test that CHGNet can be initialized with a custom model.""" - # Load a custom model instance - custom_model = CHGNet.load() - - model = CHGNetModel( - model=custom_model, - device=DEVICE, - dtype=DTYPE, - compute_forces=True, - ) - - from ase.build import bulk - - si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) - - result = model.forward(state) - assert "energy" in result - assert "forces" in result - assert "stress" in result - - -def test_chgnet_compute_forces_false() -> None: - """Test CHGNet with compute_forces=False.""" - model = CHGNetModel( - device=DEVICE, - dtype=DTYPE, - compute_forces=False, - compute_stress=True, - ) - - from ase.build import bulk - - si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) - - result = model.forward(state) - assert "energy" in result - assert "forces" not in result - assert "stress" in result - - -def test_chgnet_compute_stress_false() -> None: - """Test CHGNet with compute_stress=False.""" - model = CHGNetModel( - device=DEVICE, - dtype=DTYPE, - compute_forces=True, - compute_stress=False, - ) - - from ase.build import bulk - - si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) - - result = model.forward(state) - assert "energy" in result - assert "forces" in result - assert "stress" not in result - - -def test_chgnet_compute_both_false() -> None: - """Test CHGNet with both compute_forces=False and compute_stress=False.""" - model = CHGNetModel( - device=DEVICE, - dtype=DTYPE, - compute_forces=False, - compute_stress=False, - ) - - from ase.build import bulk - - si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - state = ts.io.atoms_to_state([si_atoms], DEVICE, DTYPE) - - result = model.forward(state) - assert "energy" in result - assert "forces" not in result - assert "stress" not in result - - test_chgnet_model_outputs = make_validate_model_outputs_test( model_fixture_name="ts_chgnet_model", dtype=DTYPE ) - -def test_chgnet_import_error() -> None: - """Test that CHGNetModel raises ImportError when CHGNet is not available.""" - from torch_sim.models.chgnet import CHGNetModel - - # Should not raise an error when CHGNet is available - model = CHGNetModel(device=DEVICE, dtype=DTYPE) - assert isinstance(model, CHGNetModel) +test_chgnet_consistency = make_model_calculator_consistency_test( + test_name="chgnet", + model_fixture_name="ts_chgnet_model", + calculator_fixture_name="ase_chgnet_calculator", + sim_state_names=("si_sim_state", "cu_sim_state", "mg_sim_state", "ti_sim_state"), + dtype=DTYPE, + energy_rtol=1e-4, + energy_atol=1e-4, + force_rtol=1e-4, + force_atol=1e-4, + stress_rtol=1e-3, + stress_atol=1e-3, +) From 35d1b3df20bc7780caa8bc91e9a534de3bd7e9b4 Mon Sep 17 00:00:00 2001 From: Bruno Cucco Date: Tue, 28 Oct 2025 16:26:45 -0500 Subject: [PATCH 5/5] Fixes for chgnet: - Replace custom CHGNetCalculator with official one from chgnet.model.dynamics - Add stress unit conversion (GPa to eV/A^3) to match ASE calculator - Remove redundant atomic numbers validation - Remove redundant test test_chgnet_missing_atomic_numbers - Fixed failing tests --- tests/models/test_chgnet.py | 73 ++----------------------------------- torch_sim/models/chgnet.py | 14 ++----- 2 files changed, 8 insertions(+), 79 deletions(-) diff --git a/tests/models/test_chgnet.py b/tests/models/test_chgnet.py index d7dfe254..d4214671 100644 --- a/tests/models/test_chgnet.py +++ b/tests/models/test_chgnet.py @@ -1,10 +1,7 @@ import traceback -from typing import Any, ClassVar import pytest import torch -from ase.atoms import Atoms -from ase.calculators.calculator import Calculator, all_changes from tests.conftest import DEVICE from tests.models.conftest import ( @@ -14,7 +11,7 @@ try: - from chgnet.model.model import CHGNet + from chgnet.model.dynamics import CHGNetCalculator as CHGNetAseCalculator from torch_sim.models.chgnet import CHGNetModel except (ImportError, ValueError): @@ -23,46 +20,6 @@ ) -class CHGNetCalculator(Calculator): - """ASE Calculator wrapper for CHGNet.""" - - implemented_properties: ClassVar[list[str]] = ["energy", "forces", "stress"] - - def __init__(self, model: CHGNet | None = None, **kwargs) -> None: - Calculator.__init__(self, **kwargs) - self.model = model or CHGNet.load() - - def calculate( - self, - atoms: Atoms | None = None, - properties: list[str] | None = None, - system_changes: Any = all_changes, - ): - if properties is None: - properties = ["energy"] - Calculator.calculate(self, atoms, properties, system_changes) - - # Convert ASE atoms to pymatgen Structure - from pymatgen.io.ase import AseAtomsAdaptor - - structure = AseAtomsAdaptor.get_structure(atoms) - - # Get CHGNet predictions - result = self.model.predict_structure(structure) - - # Convert to ASE format - self.results = {} - if "energy" in properties: - # CHGNet returns energy per atom, convert to total energy - self.results["energy"] = result["e"] * len(structure) - - if "forces" in properties: - self.results["forces"] = result["f"] - - if "stress" in properties: - self.results["stress"] = result["s"] - - DTYPE = torch.float32 @@ -78,32 +35,10 @@ def ts_chgnet_model() -> CHGNetModel: @pytest.fixture -def ase_chgnet_calculator(ts_chgnet_model: CHGNetModel) -> CHGNetCalculator: +def ase_chgnet_calculator() -> CHGNetAseCalculator: """Create an ASE CHGNet calculator for testing.""" - # Use the same model instance to ensure consistency - return CHGNetCalculator(model=ts_chgnet_model.model) - - -def test_chgnet_missing_atomic_numbers() -> None: - """Test that CHGNet raises appropriate error when atomic numbers are missing.""" - model = CHGNetModel( - device=DEVICE, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) - - # Create state without atomic numbers by using a state dict - state_dict = { - "positions": torch.randn(8, 3, device=DEVICE, dtype=DTYPE), - "cell": torch.eye(3, device=DEVICE, dtype=DTYPE).unsqueeze(0), - "pbc": True, - "atomic_numbers": None, # Missing atomic numbers - "system_idx": torch.zeros(8, dtype=torch.long, device=DEVICE), - } - - with pytest.raises(ValueError, match="Atomic numbers must be provided"): - model.forward(state_dict) + # Use the official CHGNet calculator + return CHGNetAseCalculator() test_chgnet_model_outputs = make_validate_model_outputs_test( diff --git a/torch_sim/models/chgnet.py b/torch_sim/models/chgnet.py index c31719a2..ecace2a3 100644 --- a/torch_sim/models/chgnet.py +++ b/torch_sim/models/chgnet.py @@ -127,23 +127,14 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: - 'magnetic_moments': Magnetic moments with shape [n_atoms, 3] if available in CHGNet output - Raises: - ValueError: If atomic numbers are not provided in the state. """ # Handle state dict - if isinstance(state, dict) and state.get("atomic_numbers") is None: - raise ValueError("Atomic numbers must be provided in the state for CHGNet.") - sim_state = ( state if isinstance(state, ts.SimState) else ts.SimState(**state, masses=torch.ones_like(state["positions"])) ) - # Validate that atomic numbers - if sim_state.atomic_numbers is None: - raise ValueError("Atomic numbers must be provided in the state for CHGNet.") - # Convert SimState to list of pymatgen Structures structures = sim_state.to_structures() @@ -195,7 +186,10 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: for result in chgnet_results ] ) - results["stress"] = stresses + # Convert from GPa to eV/A^3 to match ASE calculator convention + # stress_weight converts GPa to eV/A^3 (approximately 1/160.21) + stress_weight = 0.006241509125883258 + results["stress"] = stresses * stress_weight # Process magnetic moments (if available) if "m" in chgnet_results[0]: