diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 17a705f1..31901e16 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -1,6 +1,8 @@ import traceback +import numpy as np import pytest +from ase.geometry.cell import cell_to_cellpar as ase_c2p from tests.conftest import DEVICE from tests.models.conftest import ( @@ -8,6 +10,8 @@ make_model_calculator_consistency_test, make_validate_model_outputs_test, ) +from torch_sim import SimState +from torch_sim.models.orb import cell_to_cellpar try: @@ -74,3 +78,16 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: test_validate_direct_model_outputs = make_validate_model_outputs_test( model_fixture_name="orbv3_direct_20_omat_model", ) + + +def test_cell_to_cellpar(ti_sim_state: SimState) -> None: + assert np.allclose( + ase_c2p(ti_sim_state.row_vector_cell.squeeze()), + cell_to_cellpar(ti_sim_state.row_vector_cell.squeeze(0)).cpu().numpy(), + ) + assert np.allclose( + ase_c2p(ti_sim_state.row_vector_cell.squeeze(), radians=True), + cell_to_cellpar(ti_sim_state.row_vector_cell.squeeze(0), radians=True) + .cpu() + .numpy(), + ) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 5ddb1d09..e3de08e0 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -29,7 +29,6 @@ try: - from ase.geometry import cell_to_cellpar from orb_models.forcefield import featurization_utilities as feat_util from orb_models.forcefield.atomic_system import SystemConfig from orb_models.forcefield.base import AtomGraphs, _map_concat @@ -59,6 +58,37 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: from torch_sim.typing import StateDict +def cell_to_cellpar( + cell: torch.Tensor, + radians: bool = False, # noqa: FBT001, FBT002 +) -> torch.Tensor: + """Returns the cell parameters [a, b, c, alpha, beta, gamma]. + torch version of ase's cell_to_cellpar. + + Args: + cell: lattice vector in row vector convention, same as ase + radians: If True, return angles in radians. Otherwise, return degrees (default). + + Returns: + Tensor with [a, b, c, alpha, beta, gamma]. + """ + lengths = torch.linalg.norm(cell, dim=1) + angles = [] + for i in range(3): + j = i - 1 + k = i - 2 + ll = lengths[j] * lengths[k] + if ll.item() > 1e-16: + x = torch.dot(cell[j], cell[k]) / ll + angle = 180.0 / torch.pi * torch.arccos(x) + else: + angle = 90.0 + angles.append(angle) + if radians: + angles = [angle * torch.pi / 180 for angle in angles] + return torch.concat((torch.tensor(lengths), torch.tensor(angles))) + + def state_to_atom_graphs( # noqa: PLR0915 state: ts.SimState, *, @@ -181,9 +211,7 @@ def state_to_atom_graphs( # noqa: PLR0915 num_edges.append(len(edges[0])) # Calculate lattice parameters - lattice_per_system = torch.from_numpy( - cell_to_cellpar(cell_per_system.squeeze(0).cpu().numpy()) - ) + lattice_per_system = cell_to_cellpar(cell_per_system.squeeze(0)) # Create features dictionaries node_feats = {