Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions aimnet/calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@
from .aimnet2pysis import AIMNet2Pysis # noqa: F401

__all__.append("AIMNet2Pysis")

if importlib.util.find_spec("torch_sim") is not None:
from .aimnet2torchsim import AIMNet2TorchSim # noqa: F401

__all__.append("AIMNet2TorchSim")
146 changes: 146 additions & 0 deletions aimnet/calculators/aimnet2torchsim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""torch-sim ModelInterface wrapper for AimNet2.

This module provides a TorchSim wrapper of the AimNet2 model for computing
energies, forces, and stresses for atomistic systems.
"""

from collections.abc import Callable
from pathlib import Path

import torch

from .calculator import AIMNet2Calculator

try:
import torch_sim as ts
from torch_sim.models.interface import ModelInterface, SimState, StateDict
except ImportError:
raise ImportError("torch-sim is not installed. Please install it using `pip install torch-sim-atomistic`.") # noqa: B904


def state_to_aimnet2_data(state: ts.SimState) -> dict[str, torch.Tensor]:
positions = state.positions
cell = state.row_vector_cell
z = state.atomic_numbers.long()
charge = state.charge
spin = state.spin
mol_idx = state.system_idx
data = {
"coord": positions,
"numbers": z,
"charge": charge,
"spin": spin,
"mol_idx": mol_idx,
}
# Handle periodic cells:
# - If cell is all zeros, treat as non-periodic and omit "cell"
# - If all batched cells are identical, use a single (3, 3) cell
# - Otherwise, keep the batched (B, 3, 3) cell so each system can have its own box
# Ensure we are working with a tensor (torch-sim may return None for non-periodic systems)
# Keep the cell tensor rank consistent with the incoming state to avoid
# downstream shape differences (e.g. stress rank changes) across chunks.
if isinstance(cell, torch.Tensor) and not torch.allclose(cell, torch.zeros_like(cell)):
data["cell"] = cell.contiguous()
return data


def state_dict_to_aimnet2_data(state: StateDict) -> dict[str, torch.Tensor]:
data: dict[str, torch.Tensor] = {}
if "positions" in state:
data["coord"] = state["positions"]
if "cell" in state:
data["cell"] = state["cell"]
if "atomic_numbers" in state:
data["numbers"] = state["atomic_numbers"]
if "charge" in state:
data["charge"] = state["charge"]
if "mol_idx" in state:
data["mol_idx"] = state["system_idx"]
cell = state["cell"]
if isinstance(cell, torch.Tensor) and not torch.allclose(cell, torch.zeros_like(cell)):
data["cell"] = cell.contiguous()
return data


class AIMNet2TorchSim(ModelInterface):
"""Computes energies, forces, and stresses for atomistic systems using the AIMNet2 model.

Attributes
----------
model : nn.Module
The loaded AIMNet2 model.
_device : str
Device the model is running on ("cuda" or "cpu").
_dtype: torch.dtype
_compute_stress: bool
implemented_properties: list[str]

"""

def __init__(
self,
base_calc: AIMNet2Calculator,
neighbor_list_fn: Callable | None = None,
*,
model_cache_dir: str | Path | None = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
compute_stress: bool = False,
):
"""Initial the AIMNet2TorchSim model.

Args:
base_calc: AIMNet2Calculator
The AIMNet2 calculator to use.
"""
super().__init__()
self.model = base_calc
self._device = base_calc.device
self._dtype = dtype or torch.float32
self._compute_stress = compute_stress
self._compute_forces = True
self._memory_scales_with = "n_atoms_x_density"
if neighbor_list_fn is not None:
raise NotImplementedError("Custom neighbor list is not supported for the AIMNet2 Model.")
self.predictor = base_calc.eval
self.implemented_properties = ["energy", "forces", "charges"]
if base_calc.is_nse:
self.implemented_properties.append("spin_charges")
if self._compute_stress:
self.implemented_properties.append("stress")

@property
def dtype(self) -> torch.dtype:
return self._dtype

@property
def device(self) -> torch.device:
return self._device

def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tensor]:
"""Compute energies, forces, and other properties.

Args:
state (SimState): State object containing positions, cells, atomic numbers,
and other system information.
**_kwargs: Unused; accepted for interface compatibility.

Returns:
dict: Dictionary of model predictions, which may include:
- energy (torch.Tensor): Energy with shape [batch_size]
- forces (torch.Tensor): Forces with shape [n_atoms, 3]
- stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3]
"""
if isinstance(state, SimState):
if state.device != self._device:
state = state.to(self._device)
data = state_to_aimnet2_data(state)
# Ensure system_idx has integer dtype
if state.system_idx.dtype != torch.int64:
data["mol_idx"] = data["mol_idx"].to(torch.int64)
elif isinstance(state, StateDict):
data = state_dict_to_aimnet2_data(state)

results = self.model(data, forces=True, stress=self._compute_stress)

return results
93 changes: 93 additions & 0 deletions examples/1119028.cif
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#######################################################################
#
# This file contains crystal structure data downloaded from the
# Cambridge Structural Database (CSD) hosted by the Cambridge
# Crystallographic Data Centre (CCDC).
#
# Full information about CCDC data access policies and citation
# guidelines are available at http://www.ccdc.cam.ac.uk/access/V1
#
# Audit and citation data items may have been added by the CCDC.
# Please retain this information to preserve the provenance of
# this file and to allow appropriate attribution of the data.
#
#######################################################################

data_CAFINE
#This CIF has been generated from an entry in the Cambridge Structural Database
_database_code_depnum_ccdc_archive 'CCDC 1119028'
_database_code_CSD CAFINE
loop_
_citation_id
_citation_doi
_citation_year
1 10.1107/S0365110X58001286 1958
_audit_creation_method 'Created from the CSD'
_audit_update_record
;
2026-03-06 downloaded from the CCDC.
;
_database_code_NBS 504758
_chemical_name_common 'Caffeine monohydrate'
_chemical_formula_moiety 'C8 H10 N4 O2,H2 O1'
_chemical_name_systematic '1,3,7-Trimethyl-purine-2,6-dione monohydrate'
_chemical_properties_biological 'stimulant which increases CNS activity'
_chemical_absolute_configuration unk
_diffrn_ambient_temperature 295
_exptl_crystal_density_diffrn 1.447
#These two values have been output from a single CSD field.
_refine_ls_R_factor_gt 0.146
_refine_ls_wR_factor_gt 0.146
_diffrn_radiation_probe x-ray
_symmetry_cell_setting monoclinic
_symmetry_space_group_name_H-M 'P 21/a'
_symmetry_Int_Tables_number 14
_space_group_name_Hall '-P 2yab'
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 x,y,z
2 1/2-x,1/2+y,-z
3 -x,-y,-z
4 1/2+x,1/2-y,z
_cell_length_a 14.8(1)
_cell_length_b 16.7(1)
_cell_length_c 3.97(3)
_cell_angle_alpha 90
_cell_angle_beta 97.0(5)
_cell_angle_gamma 90
_cell_volume 973.911
_cell_formula_units_Z 4
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
C1 C 0.24140 0.22250 -0.09980
C2 C 0.10030 0.25330 0.12950
C3 C 0.08410 0.17590 0.19440
C4 C 0.14630 0.11430 0.11550
C5 C -0.01990 0.25200 0.36380
C6 C 0.28910 0.08320 -0.12100
C7 C 0.19590 0.36380 -0.07910
C8 C -0.04640 0.10530 0.45840
N1 N 0.21960 0.14150 -0.02650
N2 N 0.18010 0.27690 -0.01520
N3 N 0.00200 0.17490 0.33760
N4 N 0.04030 0.30080 0.24400
O1 O 0.30630 0.24000 -0.23860
O2 O 0.13630 0.04040 0.16160
H1 H -0.08700 0.26100 0.47400
H2 H -0.01300 0.06200 0.59900
H3 H -0.06500 0.06300 0.27800
H4 H -0.10500 0.13700 0.51000
H5 H 0.26300 0.36200 -0.14300
H6 H 0.22800 0.39600 0.10500
H7 H 0.14200 0.37700 -0.21700
H8 H 0.34800 0.10000 -0.22800
H9 H 0.30000 0.03300 0.02200
H10 H 0.25700 0.06000 -0.32400
O3 O 0.01840 0.47050 0.27050

#END
2 changes: 1 addition & 1 deletion examples/ase_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def torch_show_device_into():
opt = LBFGS(atoms)

# run optimization
t0 = perf_counter()
t0: int | float = perf_counter()
print(f"Running optimization for {len(atoms)} atoms molecule.")
opt.run(fmax=0.01)
t1 = perf_counter()
Expand Down
41 changes: 41 additions & 0 deletions examples/ts_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
from time import perf_counter

import ase.io
import torch_sim as ts

from aimnet.calculators import AIMNet2Calculator, AIMNet2TorchSim


def torch_show_device_into():
import torch

print(f"Torch version: {torch.__version__}")
if torch.cuda.is_available():
print(f"CUDA available, version {torch.version.cuda}, device: {torch.cuda.get_device_name()}") # type: ignore
else:
print("CUDA not available")


torch_show_device_into()
# 59 conformations of taxol
xyzfile = os.path.join(os.path.dirname(__file__), "taxol.xyz")

# read the first one
atoms = ase.io.read(xyzfile, index=0)

# create the calculator with default model
base_calc = AIMNet2Calculator("aimnet2")
calc = AIMNet2TorchSim(base_calc)

t0: int | float = perf_counter()
n_systems = 500
systems = [atoms] * n_systems
for _i in range(n_systems):
systems[_i].info["charge"] = _i / n_systems

print(f"Running optimization for {len(atoms)} atoms molecule with {n_systems} systems.")
final_state = ts.optimize(system=systems, model=calc, optimizer=ts.Optimizer.fire, autobatcher=True, pbar=True)
final_atoms = final_state.to_atoms()
t1 = perf_counter()
print(f"Completed optimization in {t1 - t0:.1f} s")
49 changes: 49 additions & 0 deletions examples/ts_opt_pbc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os

import ase.io
import torch_sim as ts

from aimnet.calculators import AIMNet2Calculator, AIMNet2TorchSim


def torch_show_device_into():
import torch

print(f"Torch version: {torch.__version__}")
if torch.cuda.is_available():
print(f"CUDA available, version {torch.version.cuda}, device: {torch.cuda.get_device_name()}") # type: ignore
else:
print("CUDA not available")


torch_show_device_into()
ciffile_1 = os.path.join(os.path.dirname(__file__), "2019828.cif")
ciffile_2 = os.path.join(os.path.dirname(__file__), "1119028.cif")

# read the first one
atoms_1 = ase.io.read(ciffile_1)
atoms_2 = ase.io.read(ciffile_2)

# attach the calculator to the atoms object
# create the calculator with default model
base_calc = AIMNet2Calculator("aimnet2")
calc = AIMNet2TorchSim(base_calc, compute_stress=True)

final_state = ts.optimize(
system=[atoms_1],
model=calc,
optimizer=ts.Optimizer.fire,
)

n_steps = 1000
final_state = ts.integrate(
system=final_state, # Input atomic system
model=calc, # Energy/force model
integrator=ts.Integrator.nvt_vrescale, # Integrator to use
n_steps=n_steps, # Number of MD steps
temperature=300, # Target temperature (K)
timestep=0.0005, # Integration timestep (ps)
# external_pressure=0.0, # Target external pressure (GPa)
trajectory_reporter={"filenames": ["2019828.h5"]},
pbar=True,
)
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
]
description = "AIMNet Machine Learned Interatomic Potential"
readme = "README.md"
requires-python = ">=3.11"
requires-python = ">=3.12"
license = { file = "LICENSE" }
keywords = [
"machine learning",
Expand Down Expand Up @@ -54,6 +54,11 @@ pysis = [
"pysisyphus",
]

# TorchSim calculator integration
torchsim = [
"torch-sim-atomistic>=0.5.2,<0.6.0"
]

# Training dependencies
train = [
"omegaconf>=2.3.0",
Expand Down Expand Up @@ -187,6 +192,7 @@ disable_error_code = ["import-untyped"]
[tool.pytest.ini_options]
testpaths = ["tests"]
markers = [
"torch_sim: marks tests that require TorchSim (deselect with: -m 'not torch_sim')",
"ase: marks tests that require ASE (deselect with: -m 'not ase')",
"gpu: marks tests that require GPU/CUDA (deselect with: -m 'not gpu')",
"pysis: marks tests that require PySisyphus (deselect with: -m 'not pysis')",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# All tests in this module require ASE
pytestmark = pytest.mark.ase

MODELS = ("aimnet2", "aimnet2_b973c")
MODELS = ("aimnet2", "aimnet2_2025")
NSE_MODEL = "aimnet2nse"


Expand Down
Loading