From 74e7252ffcd70f47e8f91eb993be377eed53b8ae Mon Sep 17 00:00:00 2001 From: KuangYu Date: Tue, 7 Apr 2026 08:50:16 +0800 Subject: [PATCH 1/2] Add wrapper for NeighborlistNNPOps --- AGENTS.md | 232 ++++++++++++ dmff/common/constants.py | 2 + dmff/common/nblist.py | 164 ++++++++- dmff/mbar.py | 778 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 1156 insertions(+), 20 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..f886d3ef6 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,232 @@ +## DMFF Codebase High-Level Architecture Overview + +This document targets humans/agents who need to develop, refactor, test, or debug in the DMFF codebase. It summarizes the project purpose, core architecture, build and test flows, and configuration/security information that is explicitly present in the repository. + +> All descriptions here are derived from this repository itself (for example `README.md`, `docs/`, `dmff/`, `backend/`), without introducing external or generic development practices. + +--- + +### 1. Project Overview + +- **Goal and Scope** + - DMFF (Differentiable Molecular Force Field) is a JAX-based Python package that provides a fully differentiable implementation of molecular force-field models, enabling parameter optimization and efficient energy/force evaluation for systems such as water, biomacromolecules, organic polymers, and small organic molecules `README.md:5`, `docs/index.md:5`. + - It supports conventional point-charge models (OPLS/AMBER-like) and multipolar polarizable models (AMOEBA/MPID-like), and is designed to integrate modern machine-learning optimization techniques for automated parameterization and trajectory-based optimization `README.md:7-10`, `docs/dev_guide/introduction.md:11-19`. 
+ +- **Core Layered Architecture (Python)** + - **Top-level package entry `dmff`** + - `dmff/__init__.py:1-6` re-exports key objects: + - Global settings: `dmff.settings` (numeric precision, JIT flag, debug flag); + - Neighbor-list utilities: `dmff.common.nblist.NeighborList` / `NeighborListFreud`; + - Force-field generators: `dmff.generators`; + - System Hamiltonian: `dmff.api.Hamiltonian`; + - Topology operators and MD tools: `dmff.operators`, `dmff.mdtools`. + - **API layer: Hamiltonian & topology** + - `dmff/api/__init__.py:1-2` exposes two core classes: `Hamiltonian` and `DMFFTopology`. + - `Hamiltonian` encapsulates total energy/force, etc.; + - `DMFFTopology` represents topology and parameter data for a system. + - **Generators and Calculators** + - `dmff/generators/__init__.py:1-4` aggregates the submodules `classical`, `admp`, `ml`, and `qeq`, each corresponding to a particular potential form. + - `docs/dev_guide/introduction.md:11-18` describes the division of responsibilities: + - `Generator` loads and organizes parameters from force-field XML files; + - **Calculators** are pure, heavy-duty functions that take atomic positions and force-field parameters as input and return energies; they are JAX-differentiable and JIT-compilable. + - **Runtime settings** + - `dmff/settings.py:3-7` defines: + - `PRECISION` (e.g. `'double'`) controlling JAX 64-bit precision; + - `DO_JIT` controlling whether to JIT-compile core computations; + - `DEBUG` controlling debug behavior; + - `update_jax_precision()` updates the global JAX `jax_enable_x64` flag at import time `dmff/settings.py:10-19`. + - **Operators pipeline** + - `dmff/operators/base.py:4-12` defines `BaseOperator`: + - `__call__` accepts a `DMFFTopology` and delegates to `operate`; + - subclasses in `dmff/operators/` implement topology/parameter transformations (e.g. typing, virtual sites, AM1 charges) in a pipeline-like fashion. 
+ - **Neighbor list and backend acceleration** + - Python-level neighbor list: `dmff/common/nblist.py` (not expanded here, but exported in `dmff/__init__.py:2`). + - C++/CUDA backend: `dmff/dpnblist/` provides a high-performance neighbor-list library with CPU/GPU scheduling algorithms and doctest-based tests `dmff/dpnblist/tests/CMakeLists.txt:1-38`. + +- **ADMP and Classical force-field architecture (from docs)** + - `docs/assets/DMFF_arch.md:1-26` outlines a three-part architecture: + - **Parser & typification**: + - Input: force-field XML file; + - `parseElement` parses XML and builds **Generators**; + - `createPotential` produces an intermediate representation containing atomic/topological parameters. + - **Calculators layer**: + - ADMP: General Pairwise Calculator, Multipole PME Calculator, Dispersion PME Calculator; + - Classical: Intramolecular and Intermolecular calculators; + - All calculators expose a unified energy API `potential(pos, box, pairs, params)` which is JAX-differentiable. + - **Neighbor list and parameter coupling**: + - Neighbor pairs `pairs` come from a jax-md-style neighbor list (edge `J -> I` in the diagram); + - Generator outputs (differentiable parameters) feed into the calculators. + +- **OpenMM plugin backend** + - `backend/openmm_dmff_plugin/README.md:1-4` describes an OpenMM plugin that embeds a trained DMFF JAX model as an OpenMM `Force` for molecular dynamics. + - Installation requires `libtensorflow_cc`, `cppflow`, and CMake; the plugin builds `DMFFForce`-related kernels under `backend/openmm_dmff_plugin/openmmapi/` and `backend/openmm_dmff_plugin/platforms/` (file-level details are omitted here). + +- **Documentation and examples** + - Documentation: MkDocs-based site with navigation defined in `mkdocs.yml:1-27` and `docs/index.md:17-39`. + - Examples: `examples/` provides runnable examples for Classical, ADMP, MLForce, OpenMM plugin, DiffTraj, etc. `README.md:48-57`, `docs/user_guide/3.usage.md`. 
+ +Currently there is no existing `AGENT.md` / `AGENTS.md`, and no `.cursor/rules/`, `.trae/rules/`, or `.github/copilot-instructions.md` files have been detected in this repository. + +--- + +### 2. Build & Commands + +This section lists only commands and tools that appear explicitly in the repository. + +- **Install from source (Python package)** + - Installing DMFF from source `docs/user_guide/2.installation.md:35-40`: + - `git clone https://github.com/deepmodeling/DMFF.git` + - `cd DMFF` + - `pip install . --user` + - Dependency installation (partial): + - Conda environment creation and installation of JAX, mdtraj, optax, jaxopt, pymbar, OpenMM, RDKit, etc. See `docs/user_guide/2.installation.md:3-33` for exact commands. + +- **Python tests (pytest)** + - The root `Makefile:1-31` defines per-module pytest targets, all using `pytest --disable-warnings`: + - `make test_admp` → `pytest --disable-warnings tests/test_admp`; + - `make test_classical` → `pytest --disable-warnings tests/test_classical`; + - `make test_common` → `pytest --disable-warnings tests/test_common`; + - `make test_difftraj` → `pytest --disable-warnings tests/test_difftraj`; + - `make test_dimer` → `pytest --disable-warnings tests/test_dimer`; + - `make test_frontend` → `pytest --disable-warnings tests/test_frontend`; + - `make test_mbar` → `pytest --disable-warnings tests/test_mbar`; + - `make test_sgnn` → `pytest --disable-warnings tests/test_sgnn`; + - `make test_energy` → `pytest --disable-warnings tests/test_energy.py`; + - `make test_utils` → `pytest --disable-warnings tests/test_utils.py`. + +- **C++ neighbor-list backend (dpnblist) build & tests** + - `dmff/dpnblist/CMakeLists.txt` (not expanded here) defines how to build the C++/CUDA neighbor-list library. 
+ - Test executable `dpnblist_test` is defined in `dmff/dpnblist/tests/CMakeLists.txt:1-38`: + - `add_executable(dpnblist_test ...)` with multiple `test_*.cpp` sources; + - `find_package(doctest)` or a bundled `doctest.cmake` is used to enable doctest-based unit tests. + +- **OpenMM DMFF plugin build & tests** + - Environment setup and build commands are documented in `backend/openmm_dmff_plugin/README.md:9-56`: + - Install `python`, `openmm`, `cudatoolkit`, and `libtensorflow_cc` via conda; + - Download TensorFlow sources and copy `tensorflow/c` headers into the conda environment to satisfy `cppflow` requirements; + - Set `OPENMM_INSTALLED_DIR`, `CPPFLOW_INSTALLED_DIR`, `LIBTENSORFLOW_INSTALLED_DIR`; + - In `backend/openmm_dmff_plugin/build`, run `cmake .. -DOPENMM_DIR=... -DCPPFLOW_DIR=... -DTENSORFLOW_DIR=...` and then `make && make install && make PythonInstall`. + - Python-level plugin tests `backend/openmm_dmff_plugin/README.md:58-62`: + - `python -m OpenMMDMFFPlugin.tests.test_dmff_plugin_nve -n 100` + - `python -m OpenMMDMFFPlugin.tests.test_dmff_plugin_nvt -n 100 --platform CUDA` + +- **Docs development & preview (MkDocs)** + - Documentation framework: MkDocs `docs/dev_guide/write_docs.md:5`. + - Preview command `docs/dev_guide/write_docs.md:27-31`: + - In the directory containing `mkdocs.yml`, run `mkdocs serve` to start a local dev server with auto-reload. + +--- + +### 3. Code Style + +This section collects only style guidelines that are explicitly stated in the repository. + +- **Code organization** + - Root layout `docs/dev_guide/convention.md:10-15`: + - `dmff/`: project source code; + - `docs/`: Markdown documentation; + - `examples/`: standalone examples; + - `tests/`: unit and integration tests. + - Within `dmff/` `docs/dev_guide/convention.md:17-22`: + - `api.py`: API (frontend modules); + - `settings.py`: global settings; + - `utils.py`: basic utilities; + - each subdirectory corresponds to a potential form (e.g. `admp`, `classical`). 
+ +- **Docstrings and comments** + - DMFF adopts **NumPy-style docstrings**: + - `docs/dev_guide/convention.md:24-27` states: + - methods and classes should use NumPy-style docstrings, combined with `typing` annotations, to support API documentation generation; + - an extended example is provided via the Napoleon NumPy-style sample `docs/dev_guide/convention.md:30-387`. + - Documentation system: MkDocs; authoring guidelines in `docs/dev_guide/write_docs.md`: + - new docs are added as Markdown files in the appropriate directories; + - images should be placed under `docs/assets/` and referenced via relative paths `docs/dev_guide/write_docs.md:21-23`. + +- **Language and dependencies** + - Python version and runtime dependencies: + - `setup.py:47-53` requires Python `~=3.8`, and `setup.py:21-30` lists core dependencies such as `numpy>=1.18`, `jax>=0.4.1`, `openmm>=7.6.0`, `freud-analysis`, `networkx>=3.0`, `optax>=0.1.4`, `jaxopt>=0.8.0`, `pymbar>=4.0.0`, and `tqdm`. + - Docs-related dependencies: `requirements.txt:1-15` lists documentation tooling (`mkdocs`, `mkdocs-autorefs`, `mkdocs-gen-files`, `mkdocs-literate-nav`, `mkdocstrings`, `mkdocstrings-python`, `pygments`) and runtime libraries (`jax`, `jaxlib`, `pymbar`, `rdkit`, `ase`). + +--- + +### 4. Testing + +DMFF uses both Python-level unit/integration tests and C++-level backend tests. + +- **Python test layout** + - All Python tests live under `tests/`, covering: frontend API, classical force fields (`tests/test_classical/`), ADMP module (`tests/test_admp/`), neighbor-list utilities (`tests/test_common/`), DiffTraj, MBAR, SGNN, EANN, and others (see the directory tree under `tests/`). + - The `Makefile:1-31` provides per-module pytest entry points, making it easy to run only a subset of tests. + - `tests/conftest.py:1` is currently empty; there are no repository-wide pytest fixtures or hooks defined. 
+ +- **Installation sanity checks (Python)** + - User guide installation check `docs/user_guide/2.installation.md:42-52`: + - In an interactive Python session, import `dmff` and `dmff.admp` to ensure the package is available; + - run `examples/water_fullpol/run.py` to confirm example scripts execute successfully. + +- **C++ backend tests (dpnblist)** + - `dmff/dpnblist/tests/CMakeLists.txt:1-38`: + - defines the `dpnblist_test` executable combining multiple `test_*.cpp` files; + - uses doctest for unit tests; when an external doctest installation is not found, it pulls in `external/doctest-2.4.11` and uses `doctest.cmake`; + - `doctest_discover_tests` registers tests with CTest. + +- **OpenMM plugin tests** + - Python-level tests `backend/openmm_dmff_plugin/README.md:58-62`: + - `python -m OpenMMDMFFPlugin.tests.test_dmff_plugin_nve -n 100`; + - `python -m OpenMMDMFFPlugin.tests.test_dmff_plugin_nvt -n 100 --platform CUDA`. + - C++-level tests include `TestDMFFPlugin4CUDA.cpp` and `TestDMFFPlugin4Reference.cpp` with corresponding `CMakeLists.txt` under `backend/openmm_dmff_plugin/platforms/*/tests/`, used to validate force and energy consistency across platforms. + +--- + +### 5. Security + +The repository does not contain a dedicated security design document or explicit security policies. This section lists only direct, observable facts related to data and dependencies, without adding generic recommendations. + +- **Data and file types** + - Force fields and topologies: many XML and PDB files under `tests/data/` and `examples/` provide input for topology and parameter construction (`tests/data/*.xml`, `examples/*/*.xml`, `*.pdb`, etc.). + - Trained models and parameters: + - ML force-field parameters are stored in files such as `*.pickle` and `*.pt`, e.g. 
`examples/eann/eann_model.pickle`, `examples/sgnn/test_backend/model1.pth`, `tests/data/water_eann.pickle`; + - the OpenMM plugin uses `backend/save_dmff2tf.py` to export JAX models into a TensorFlow-compatible format consumed by the plugin `backend/openmm_dmff_plugin/README.md:4-5`. + +- **Dependencies and runtime environment** + - Core numerical dependencies include `jax`, `jaxlib`, `numpy`, `openmm`, `rdkit`, etc., with version constraints specified in `setup.py:21-30`, `requirements.txt:1-15`, and `docs/user_guide/2.installation.md:3-33`. + - The OpenMM plugin depends on `libtensorflow_cc` and `cppflow`; TensorFlow headers under `tensorflow/c` must be copied from upstream sources into the conda environment `backend/openmm_dmff_plugin/README.md:18-27`. + +- **Access control and encryption** + - No additional access-control, authentication, or encryption mechanisms are defined in this repository; + - Configuration example `config/freud.ini:2-41` is for a third-party tool (layout, key bindings, DB filename, color scheme) and is not coupled to DMFF’s internal numerical logic. + +More fine-grained security policies (data isolation, permission control, etc.) need to be handled by the systems that integrate DMFF; this repository itself does not impose additional constraints. + +--- + +### 6. Configuration & Environment + +- **Python environment and dependencies** + - Recommended setup in `docs/user_guide/2.installation.md:3-33`: + - create a conda environment named `dmff` (example uses Python 3.9); + - install specific versions of JAX (CPU or CUDA builds), mdtraj, optax, jaxopt, pymbar, OpenMM, RDKit, etc. + - Package-level dependencies are centralized in `setup.py:21-30` and `requirements.txt:1-15` to aid environment reproduction. 
+ +- **Global numerical and debug settings** + - `dmff/settings.py:3-19` defines runtime configuration: + - `PRECISION`: controls whether JAX double precision is enabled (via `update_jax_precision` updating `jax_enable_x64`); + - `DO_JIT`: controls JIT compilation of core computations; + - `DEBUG`: toggles debug behavior; + - these are exported via `dmff/__init__.py:1` and can be imported and modified by user code. + +- **Documentation system configuration** + - `mkdocs.yml:1-47`: + - defines the site name (`DMFF`) and navigation (User Guide, Developer Guide, module docs, etc.); + - uses the `readthedocs` theme and `pymdownx.arithmatex` for math rendering; + - enables `gen-files`, `literate-nav`, and `mkdocstrings` plugins to generate API references and SUMMARY-based navigation. + +- **Docker environments (packaging & development)** + - `package/docker/develop_cpu.dockerfile` and `package/docker/develop_gpu.dockerfile` describe CPU/GPU development images (system dependencies plus Python environment) for reproducible development inside containers. + +- **External tool configuration example** + - `config/freud.ini:2-41` configures a third-party tool (likely an HTTP/requests inspector) with layout, key bindings, database filename, and style settings; it is independent of DMFF’s core simulation logic and can be treated as optional tooling. + +--- + +The above content is intended to let developers or agents grasp DMFF’s overall design, build process, and testing/configuration strategy without fully reading the code. For concrete implementation or refactoring tasks, combine this overview with the corresponding module-specific docs (for example `docs/dev_guide/` and `docs/user_guide/4.*.md`) and nearby tests to drive detailed understanding and validation. 
+ diff --git a/dmff/common/constants.py b/dmff/common/constants.py index 286861a1f..538f63b8e 100644 --- a/dmff/common/constants.py +++ b/dmff/common/constants.py @@ -6,3 +6,5 @@ # units EV2KJ = 96.48530749925791 +A2NM = 0.1 + diff --git a/dmff/common/nblist.py b/dmff/common/nblist.py index ef52b5912..deff8b4e3 100644 --- a/dmff/common/nblist.py +++ b/dmff/common/nblist.py @@ -1,20 +1,32 @@ import numpy as np import jax.numpy as jnp from itertools import permutations +import warnings try: import freud import freud.box import freud.locality except ImportError: freud = None - import warnings warnings.warn("WARNING: freud not installed, users need to create neighbor list by themselves.") try: import dpnblist except ImportError: dpnblist = None - import warnings warnings.warn("WARNING: dpdpnblist not installed, users need to create neighbor list by themselves.") +try: + import torch +except ImportError: + torch = None +try: + from torch2jax import j2t, t2j +except ImportError: + j2t = None + t2j = None +try: + from NNPOps.neighbors import getNeighborPairs as nnpops_get_neighbor_pairs +except ImportError: + nnpops_get_neighbor_pairs = None class NeighborListDp: def __init__(self, alg_type, box, rcut, cov_map, padding=True): @@ -138,6 +150,154 @@ def scaled_pairs(self): def positions(self): return self._positions + +class NeighborListNNPOps: + def __init__(self, box, rcut, cov_map, padding=True, capacity_multiplier=1.3, sort=False): + if torch is None: + raise ImportError("torch not installed.") + if j2t is None or t2j is None: + raise ImportError("torch2jax not installed.") + if nnpops_get_neighbor_pairs is None: + raise ImportError("NNPOps not installed.") + self.box = box + self.rcut = rcut + self.capacity_multiplier = None + self.capacity_multiplier_scale = capacity_multiplier + self.padding = padding + self.cov_map = cov_map + self.sort = sort + + def _search_pairs(self, coords, box=None, max_num_pairs=1, check_errors=False): + coords_t = coords if 
torch.is_tensor(coords) else j2t(jnp.asarray(coords)) + if box is None: + box_t = self.box if torch.is_tensor(self.box) else j2t(jnp.asarray(self.box)) + else: + box_t = box if torch.is_tensor(box) else j2t(jnp.asarray(box)) + pairs, _, _, _ = nnpops_get_neighbor_pairs( + positions=coords_t, + cutoff=float(self.rcut), + max_num_pairs=int(max_num_pairs), + box_vectors=box_t, + check_errors=check_errors, + ) + return pairs.transpose(0, 1) + + def _canonicalize_pairs(self, pairs, natoms): + if pairs.numel() == 0: + return torch.empty((0, 2), dtype=torch.int32) + valid_mask = torch.logical_and(pairs[:, 0] >= 0, pairs[:, 1] >= 0) + pairs = pairs[valid_mask] + if pairs.numel() == 0: + return torch.empty((0, 2), dtype=torch.int32, device=valid_mask.device) + pairs = torch.stack( + [torch.minimum(pairs[:, 0], pairs[:, 1]), torch.maximum(pairs[:, 0], pairs[:, 1])], + dim=1, + ) + if self.sort: + keys = pairs[:, 0].to(torch.int64) * int(natoms) + pairs[:, 1].to(torch.int64) + order = torch.argsort(keys) + pairs = pairs[order] + unique_mask = torch.ones(pairs.shape[0], dtype=torch.bool, device=pairs.device) + unique_mask[1:] = torch.any(pairs[1:] != pairs[:-1], dim=1) + return pairs[unique_mask].to(torch.int32) + + def _build_pairs(self, real_pairs, natoms): + if not self.padding: + pairs_jax = t2j(real_pairs) + nbond = self.cov_map[pairs_jax[:, 0], pairs_jax[:, 1]] + return jnp.concatenate([pairs_jax, nbond[:, None]], axis=1) + + if self.capacity_multiplier is None: + raise RuntimeError("Neighbor list capacity is not initialized. Call allocate first.") + if real_pairs.shape[0] > self.capacity_multiplier: + raise ValueError( + f"NeighborListNNPOps capacity exceeded: found {real_pairs.shape[0]} pairs, " + f"but allocated capacity is {self.capacity_multiplier}." 
+ ) + + pairs_jax = t2j(real_pairs) + if real_pairs.shape[0] > 0: + cov_jax = self.cov_map[pairs_jax[:, 0], pairs_jax[:, 1]] + pairs_jax = jnp.concatenate([pairs_jax, cov_jax[:, None]], axis=1) + else: + pairs_jax = jnp.zeros((0, 3), dtype=jnp.int32) + + padding_width = self.capacity_multiplier - real_pairs.shape[0] + if padding_width == 0: + return pairs_jax + + padding = jnp.zeros((padding_width, 3), dtype=jnp.int32) + padding = padding.at[:, 0].set(natoms) + padding = padding.at[:, 1].set(natoms) + return jnp.concatenate([pairs_jax, padding], axis=0) + + def _compute_capacity(self, npairs): + capacity = int(np.ceil(npairs * self.capacity_multiplier_scale)) + if npairs > 0: + return max(capacity, npairs) + return max(capacity, 1) + + def allocate(self, coords, box=None): + self._positions = coords + natoms = int(coords.shape[0]) + search_box = box if box is not None else self.box + total_possible_pairs = max(natoms * (natoms - 1) // 2, 1) + probe_capacity = max(1, natoms) + + while True: + try: + raw_pairs = self._search_pairs( + coords, + box=search_box, + max_num_pairs=min(probe_capacity, total_possible_pairs), + check_errors=True, + ) + break + except RuntimeError as exc: + if probe_capacity >= total_possible_pairs: + raise RuntimeError("Unable to allocate NNPOps neighbor list within the full pair bound.") from exc + probe_capacity = min(probe_capacity * 2, total_possible_pairs) + + real_pairs = self._canonicalize_pairs(raw_pairs, natoms) + if self.capacity_multiplier is None: + self.capacity_multiplier = self._compute_capacity(real_pairs.shape[0]) + elif self.padding: + self.capacity_multiplier = max(int(self.capacity_multiplier), real_pairs.shape[0]) + + self._pairs = self._build_pairs(real_pairs, natoms) + return self._pairs + + def update(self, positions, box=None): + self._positions = positions + natoms = int(positions.shape[0]) + search_box = box if box is not None else self.box + + if self.capacity_multiplier is None: + return self.allocate(positions, 
box=search_box) + + raw_pairs = self._search_pairs( + positions, + box=search_box, + max_num_pairs=self.capacity_multiplier, + check_errors=True, + ) + real_pairs = self._canonicalize_pairs(raw_pairs, natoms) + self._pairs = self._build_pairs(real_pairs, natoms) + return self._pairs + + @property + def pairs(self): + return self._pairs + + @property + def scaled_pairs(self): + return self._pairs + + @property + def positions(self): + return self._positions + + class NeighborList(NeighborListFreud): ... diff --git a/dmff/mbar.py b/dmff/mbar.py index 4e55e7957..e300220f8 100644 --- a/dmff/mbar.py +++ b/dmff/mbar.py @@ -1,17 +1,35 @@ import numpy as np +import warnings try: import mdtraj as md except ImportError: - import warnings warnings.warn("MDTraj not installed. MBAREstimator is not available.") try: from pymbar import MBAR except ImportError: MBAR = None - import warnings warnings.warn("MBAR not installed, MBAREstimator for multiple states is not available.") +try: + import base + from base.inference.calculator import BaseCalculator +except ImportError: + warnings.warn("base not installed, base-related functions are not available") + +try: + import ase + from ase import Atoms + from ase.io.lammpsdata import write_lammps_data +except ImportError: + warnings.warn("ASE not installed, related functions are not available") + +try: + import torch +except ImportError: + warnings.warn("torch not installed, related functions are not available") + + from .settings import update_jax_precision, PRECISION update_jax_precision(PRECISION) import jax @@ -22,6 +40,11 @@ import openmm.app as app import openmm.unit as unit from .common.nblist import NeighborListFreud +from .common.constants import EV2KJ, A2NM +from collections import defaultdict +import subprocess +from pathlib import Path +from .torch_tools import j2t_pytree def buildTrajEnergyFunction( @@ -78,22 +101,183 @@ def energy_function(traj, parameters): return energy_function +def buildFrameEnergyFunction( + 
potential_func,
+    cov_map=None,
+    cutoff=None,
+    builtin_nbl: bool = False,
+):
+    """Wrap a DMFF potential into a frame-level energy function.
+
+    This is the frame-based counterpart of :func:`buildTrajEnergyFunction`.
+
+    Parameters
+    ----------
+    potential_func : callable
+        Low-level DMFF potential with signature
+        ``potential_func(positions, box, pairs, parameters)``.
+    cov_map : array-like, optional
+        Covalent map used to annotate neighbor-list pairs; required when
+        ``builtin_nbl=False``.
+    cutoff : float, optional
+        Cutoff radius in nanometers for building the neighbor list; required
+        when ``builtin_nbl=False``.
+    builtin_nbl : bool, optional
+        If ``False``, build a neighbor list using :class:`NeighborListFreud`
+        from ``cov_map``, ``box``, and ``positions`` (typical for classical
+        force fields). If ``True``, pass ``pairs=None`` to the potential,
+        which is the usual case for ML potentials that manage neighbor lists
+        internally.
+
+    Returns
+    -------
+    energy_function : callable
+        A function ``energy_function(frame, parameters)`` that takes a
+        single mdtraj frame (``frame``) and a parameter dict, returning the
+        potential energy in kJ/mol.
+    """
+
+    if not builtin_nbl:
+        if cov_map is None or cutoff is None:
+            raise ValueError(
+                "cov_map and cutoff must be provided when builtin_nbl=False in buildFrameEnergyFunction."
+ ) + + def energy_function(frame, parameters): + # frame is an mdtraj.Trajectory with a single frame + box = jnp.array(frame.unitcell_vectors[0]) # (3, 3) in nm + positions = jnp.array(frame.xyz[0, :, :]) # (n_atoms, 3) in nm + nbobj = NeighborListFreud(box, cutoff, cov_map) + nbobj.capacity_multiplier = 1 + pairs = nbobj.allocate(positions) + return potential_func(positions, box, pairs, parameters) + + else: + + def energy_function(frame, parameters): + box = jnp.array(frame.unitcell_vectors[0]) + positions = jnp.array(frame.xyz[0, :, :]) + pairs = None + return potential_func(positions, box, pairs, parameters) + + return energy_function + + class TargetState: - def __init__(self, temperature, energy_function): + def __init__( + self, + temperature, + energy_function, + pressure: float = 0.0, + mu_dict=None, + legacy: bool = True, + ): + """State describing a target ensemble for reweighting. + + Parameters + ---------- + temperature : float + Temperature in Kelvin. + energy_function : callable + If ``legacy=True`` (default), this is a trajectory-level function + ``energy_function(traj, parameters)`` that returns a list or array + of potential energies (in kJ/mol) for each frame in a trajectory. + This is the behavior relied on by existing code and tests. + + If ``legacy=False``, this is a frame-level function + ``energy_function(frame, parameters)`` that operates on individual + mdtraj frames. In that mode, PV and μ·N contributions are handled + in :meth:`calc_energy`. + pressure : float, optional + Pressure in bar. Used only when ``legacy=False``. + mu_dict : dict, optional + Mapping atom names to chemical potentials. Used only when + ``legacy=False``. + legacy : bool, optional + When True, preserve the original trajectory-level behavior. When + False, use the frame-level API compatible with finetune/ft.py. 
+ """ + self._temperature = temperature self._efunc = energy_function + self._pressure = pressure + self._mu_dict = mu_dict if isinstance(mu_dict, dict) else defaultdict(float) + self._legacy = legacy def calc_energy(self, trajectory, parameters): beta = 1.0 / self._temperature / 8.314 * 1000.0 - eners = self._efunc(trajectory, parameters) - ulist = jnp.concatenate([beta * e.reshape((1,)) for e in eners]) - return ulist + + if self._legacy: + # Original behavior: energy_function operates on the full + # trajectory and returns per-frame energies. PV and μ·N are + # assumed to be handled inside the energy_function. + eners = self._efunc(trajectory, parameters) + ulist = jnp.concatenate([beta * e.reshape((1,)) for e in eners]) + return ulist + + # New behavior: energy_function operates on individual frames. + # trajectory is assumed to be an mdtraj.Trajectory. + # + # Note: self._efunc may return a scalar or a length-1 array for each + # frame. We therefore flatten to a 1D array of length n_frames to + # maintain compatibility with downstream code and tests. + eners = jnp.array([ + self._efunc(trajectory[i : i + 1], parameters) + for i in range(trajectory.n_frames) + ]).reshape((trajectory.n_frames,)) + + # PV term using unitcell_volumes (nm^3) + if hasattr(trajectory, "unitcell_volumes"): + eners = eners + 0.06023 * self._pressure * jnp.array( + trajectory.unitcell_volumes + ) + + # μ·N term from topology atom names + if hasattr(trajectory, "topology") and hasattr(trajectory.topology, "atoms"): + mu_contrib = jnp.sum(jnp.array([ + self._mu_dict[a.name] for a in trajectory.topology.atoms + ])) + # Grand-canonical reduced potential: u = β(E + PV - μN). + # Therefore the μN contribution enters with a minus sign. 
+ eners = eners - mu_contrib + + eners = eners * beta + return eners class SampleState: - def __init__(self, temperature, name): + def __init__( + self, + temperature, + name, + pressure: float = 0.0, + mu_dict=None, + legacy: bool = True, + ): + """Base class for sampling states. + + Parameters + ---------- + temperature : float + Temperature in Kelvin. + name : str + Name identifying this state (used for MBAR bookkeeping). + pressure : float, optional + Pressure in bar. Used only when ``legacy=False``. + mu_dict : dict, optional + Mapping atom names to chemical potentials. Used only when + ``legacy=False``. + legacy : bool, optional + When True, preserve the original behavior where PV and μ·N are + handled in :meth:`calc_energy_frame` or elsewhere. When False, use + a frame-level API that mirrors finetune/ft.py semantics. + """ + self._temperature = temperature self.name = name + self._pressure = pressure + self._mu_dict = mu_dict if isinstance(mu_dict, dict) else defaultdict(float) + self._legacy = legacy def calc_energy_frame(self, frame): return 0.0 @@ -101,11 +285,55 @@ def calc_energy_frame(self, frame): def calc_energy(self, trajectory): # return beta * u beta = 1.0 / self._temperature / 8.314 * 1000.0 - eners = [] - for frame in tqdm(trajectory): - e = self.calc_energy_frame(frame) - eners.append(e * beta) - return jnp.array(eners) + + if self._legacy: + # Original behavior: subclasses (e.g. OpenMMSampleState) are + # expected to include PV contributions in calc_energy_frame. + eners = [] + for frame in tqdm(trajectory): + e = self.calc_energy_frame(frame) + eners.append(e * beta) + return jnp.array(eners) + + # New behavior: frame-level potential + PV and μ·N handled here. 
+ eners = jnp.array([ + self.calc_energy_frame(frame) for frame in tqdm(trajectory) + ]) + + if hasattr(trajectory, "unitcell_volumes"): + eners = eners + 0.06023 * self._pressure * jnp.array( + trajectory.unitcell_volumes + ) + + if hasattr(trajectory, "topology") and hasattr(trajectory.topology, "atoms"): + mu_contrib = jnp.sum(jnp.array([ + self._mu_dict[a.name] for a in trajectory.topology.atoms + ])) + # Grand-canonical reduced potential: u = β(E + PV - μN). + # Therefore the μN contribution enters with a minus sign. + eners = eners - mu_contrib + + eners = eners * beta + return eners + + def sample(self, *args, **kwargs): + """Run MD and generate a :class:`Sample`. + + Subclasses representing concrete simulators (for example + :class:`OpenMMSampleState`, :class:`ASENNPNPTSampleState`, or + :class:`LammpsNNPNPTSampleState`) should override this method. + """ + + raise NotImplementedError + + def update_parameters(self, *args, **kwargs): + """Update internal model parameters of the state. + + Concrete subclasses that wrap ML potentials are expected to override + this (for example, to reload a checkpoint). + """ + + raise NotImplementedError class OpenMMSampleState(SampleState): @@ -120,9 +348,20 @@ def __init__( useSwitchingFunction=False, platform="CPU", properties={}, + legacy: bool = True, + mu_dict=None, **args ): - super(OpenMMSampleState, self).__init__(temperature, name) + # ``legacy`` retains the original behavior used in the tests by + # default. When ``legacy=False``, PV and μ·N are handled in + # :meth:`SampleState.calc_energy` instead. 
class ASENNPNPTSampleState(SampleState):
    """NPT sampling state driven by an ASE calculator wrapping a DMFF model.

    This class mirrors :class:`SampleState` and ``ASENNPNPTSampleState`` in
    ``finetune/ft.py``, but is integrated into the core MBAR module.
    """

    def __init__(
        self,
        temperature,
        name,
        init_model_path,
        config_path,
        ffname,
        e_eval_loader,
        cutoff=0.5,
        pressure=0.0,
    ):
        # Use non-legacy mode so that PV and μ·N are handled in
        # SampleState.calc_energy, matching finetune semantics.
        super().__init__(
            temperature=temperature,
            name=name,
            pressure=pressure,
            mu_dict=None,
            legacy=False,
        )
        self.ffname = ffname
        self.e_eval_loader = e_eval_loader
        self.config_path = config_path
        # ``cutoff`` is given in nm and converted to Angstrom here: A2NM is
        # the Angstrom->nm factor, so dividing by it converts nm -> A (ASE
        # and the evaluator work in Angstrom; cf. LmpsBaseSampleState).
        self.cutoff = cutoff / A2NM
        self.e_eval = self.e_eval_loader(init_model_path, self.config_path, self.cutoff)

    def calc_energy_frame(self, frame):
        """Return the potential energy (kJ/mol) of one mdtraj frame."""
        from ase import Atoms

        # mdtraj stores nm; ASE expects Angstrom, hence the division by A2NM.
        atoms = Atoms(
            numbers=[a.element.atomic_number for a in frame.topology.atoms],
            cell=frame.unitcell_vectors.reshape(3, 3) / A2NM,
            positions=frame.xyz.reshape(-1, 3) / A2NM,
            pbc=[1, 1, 1],
        )
        atoms.calc = self.e_eval
        energy = atoms.get_potential_energy()
        # ASE reports energies in eV; convert to kJ/mol.
        return energy * EV2KJ

    def load_init_struct(self, init_atoms, top_file_name):
        """Attach the initial ASE atoms object and the matching topology file."""
        self.init_atoms = init_atoms
        self.top_file_name = top_file_name

    def sample(self, nsteps, interval, skip, timestep, ttime, pfactor, file_name):
        """Run ASE NPT MD and return a :class:`Sample`.

        The interface matches ``ASENNPNPTSampleState.sample`` in
        ``finetune/ft.py``.
        """

        from copy import deepcopy
        from collections.abc import Sequence
        from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
        from ase.md.npt import NPT
        from ase import units

        class TrajectoryObserver(Sequence):
            # Collects positions and cell parameters each time it is called
            # by the ASE dynamics object.
            def __init__(self, atoms):
                self.atoms = atoms
                self.xyz = []
                self.unitcell_lengths = []
                self.unitcell_angles = []

            def __call__(self):
                cell = self.atoms.get_cell()
                self.xyz.append(self.atoms.get_positions())
                self.unitcell_lengths.append(cell.lengths())
                self.unitcell_angles.append(cell.angles())

            def __getitem__(self, item):
                return (
                    self.xyz[item],
                    self.unitcell_lengths[item],
                    self.unitcell_angles[item],
                )

            def __len__(self):
                return len(self.xyz)

            def save(self, top_file_name, filename, skip):
                # Drop the equilibration frames before writing the DCD.
                traj = md.Trajectory(
                    xyz=self.xyz[skip:],
                    topology=md.load_topology(top_file_name),
                    unitcell_lengths=self.unitcell_lengths[skip:],
                    unitcell_angles=self.unitcell_angles[skip:],
                )
                traj.save_dcd(filename)

        init_atoms = deepcopy(self.init_atoms)
        init_atoms.calc = self.e_eval
        MaxwellBoltzmannDistribution(init_atoms, temperature_K=self._temperature)
        dyn = NPT(
            init_atoms,
            timestep=timestep * units.fs,
            temperature_K=self._temperature,
            externalstress=self._pressure * units.bar,
            ttime=ttime * units.fs,
            pfactor=pfactor * units.GPa * (units.fs**2),
            logfile=None,
            loginterval=1,
        )
        obs = TrajectoryObserver(init_atoms)
        dyn.attach(obs, interval=interval)
        dyn.run(nsteps)
        obs.save(self.top_file_name, file_name, skip=int(skip / interval))
        # NOTE(review): positions recorded by the observer are in Angstrom;
        # the * A2NM below assumes md.load_dcd returns them unscaled — confirm
        # against mdtraj's DCD unit handling.
        traj = md.load_dcd(file_name, top=self.top_file_name)
        traj.xyz = jnp.array(traj.xyz).astype(jnp.float32) * A2NM
        traj.unitcell_lengths = jnp.array(traj.unitcell_lengths).astype(jnp.float32) * A2NM
        traj.unitcell_angles = jnp.array(traj.unitcell_angles).astype(jnp.float32)
        return Sample(traj, from_state=self.name)

    def output_parameters(self, H, params, state_dict_name):
        """Write the parameters of the matching generator to a state-dict file."""
        for g in H.getGenerators():
            if g.getName() == self.ffname:
                generator = g
                break
        else:
            raise ValueError(f"Generator with name {self.ffname} not found in Hamiltonian.")
        generator.write_to(params=params, state_dict_file=state_dict_name)

    def update_parameters(self, ckpt_name):
        """Reload the ASE evaluator from a new checkpoint."""
        self.e_eval = self.e_eval_loader(ckpt_name, self.config_path, self.cutoff)
class LammpsNNPNPTSampleState(SampleState):
    """NPT sampling state driven by a LAMMPS runner and DMFF ML potential."""

    def __init__(
        self,
        temperature,
        name,
        ckpt_path,
        config_path,
        ffname,
        e_eval_loader,
        cutoff=0.5,
        pressure=0.0,
    ):
        # Use non-legacy mode for PV and μ·N handling.
        super().__init__(
            temperature=temperature,
            name=name,
            pressure=pressure,
            mu_dict=None,
            legacy=False,
        )
        self.ffname = ffname
        self.e_eval_loader = e_eval_loader
        self.config_path = config_path
        # nm -> Angstrom: A2NM is the Angstrom->nm factor, so dividing by it
        # yields Angstrom (LAMMPS "metal" units and ASE use Angstrom).
        self.cutoff = cutoff / A2NM
        self.ckpt_path = ckpt_path
        self.e_eval = self.e_eval_loader(ckpt_path, self.config_path, self.cutoff)

    def calc_energy_frame(self, frame):
        """Return the potential energy (kJ/mol) of one mdtraj frame."""
        from ase import Atoms

        atoms = Atoms(
            numbers=[a.element.atomic_number for a in frame.topology.atoms],
            cell=frame.unitcell_vectors.reshape(3, 3) / A2NM,
            positions=frame.xyz.reshape(-1, 3) / A2NM,
            pbc=[1, 1, 1],
        )
        atoms.calc = self.e_eval
        energy = atoms.get_potential_energy()
        # ASE reports energies in eV; convert to kJ/mol.
        return energy * EV2KJ

    def load_init_struct(self, init_atoms_lammps_data, ele_list, top_file_name):
        """Attach the LAMMPS data file, element list, and topology file."""
        self.init_atoms_lammps_data = init_atoms_lammps_data
        self.ele_list = ele_list
        self.top_file_name = top_file_name

    def load_md_runner(self, run_md):
        """Attach a callable that runs LAMMPS given an input script path."""

        self.run_md = run_md

    def sample(self, nsteps, interval, skip, timestep, ttime, ptime, file_name):
        """Run LAMMPS NPT MD via an external runner and return a Sample."""

        from pathlib import Path
        import MDAnalysis as mda

        # Fail early with a clear message instead of an AttributeError after
        # the input script has already been written.
        if not hasattr(self, "run_md"):
            raise RuntimeError(
                "No LAMMPS runner attached; call load_md_runner() before sample()."
            )

        lines = [
            "units metal",
            "dimension 3",
            "boundary p p p",
            "atom_style atomic",
            "",
            f"read_data {self.init_atoms_lammps_data}",
            "",
            f"pair_style base {self.cutoff}",
            f"pair_coeff * * {self.ckpt_path} " + " ".join(self.ele_list),
            "",
            f"timestep {timestep * 1e-3}",  # metal uses ps
            f"velocity all create {self._temperature} 12345 loop all mom yes rot no dist gaussian",
            f"fix npt_fix all npt temp {self._temperature} {self._temperature} {ttime * 1e-3} tri {self._pressure} {self._pressure} {ptime * 1e-3}",
            "",
            "thermo 100",
            "thermo_style custom step temp pe ke etotal press vol",
            f"run {skip}",
            f"dump dcd_dump all dcd {interval} {file_name}",
            "dump_modify dcd_dump sort id",
            f"run {nsteps - skip}",
        ]
        in_lammps_dir = str(Path(self.ckpt_path).parent)
        in_lammps = in_lammps_dir + "/in.lammps"
        with open(in_lammps, "w") as f:
            for line in lines:
                f.write(line + "\n")

        # Delegate actual MD execution to the provided runner
        self.run_md(in_lammps)

        # Use MDAnalysis to read .dcd and build mdtraj.Trajectory to prevent
        # automatic unit changes when reading directly with mdtraj
        u = mda.Universe(self.top_file_name, file_name)
        xyz, unitcell_lengths, unitcell_angles = [], [], []
        for frame in u.trajectory:
            xyz.append(frame.positions)
            unitcell_lengths.append(frame.dimensions[:3])
            unitcell_angles.append(frame.dimensions[3:])
        traj = md.Trajectory(
            xyz=jnp.array(xyz).astype(jnp.float32),
            topology=md.load_topology(self.top_file_name),
            unitcell_lengths=jnp.array(unitcell_lengths).astype(jnp.float32),
            unitcell_angles=jnp.array(unitcell_angles).astype(jnp.float32),
        )
        traj.save_dcd(file_name)
        # NOTE(review): the save/reload round-trip followed by * A2NM assumes
        # positions come back from md.load_dcd in Angstrom — confirm against
        # mdtraj's DCD unit handling.
        traj = md.load_dcd(file_name, top=self.top_file_name)
        traj.xyz = jnp.array(traj.xyz).astype(jnp.float32) * A2NM
        traj.unitcell_lengths = jnp.array(traj.unitcell_lengths).astype(jnp.float32) * A2NM
        traj.unitcell_angles = jnp.array(traj.unitcell_angles).astype(jnp.float32)
        return Sample(traj, from_state=self.name)

    def output_parameters(self, H, params, state_dict_name):
        """Write the parameters of the matching generator to a state-dict file."""
        for g in H.getGenerators():
            if g.getName() == self.ffname:
                generator = g
                break
        else:
            raise ValueError(f"Generator with name {self.ffname} not found in Hamiltonian.")
        generator.write_to(params=params, state_dict_file=state_dict_name)

    def update_parameters(self, ckpt_name):
        """Reload the evaluator from a new checkpoint."""
        self.e_eval = self.e_eval_loader(ckpt_name, self.config_path, self.cutoff)
class LmpsBaseSampleState(SampleState):
    """NPT sampling state that runs LAMMPS with the BAMBOO/BASE potential."""

    def __init__(
        self,
        temperature,
        name,
        ckpt_path,
        top_file,
        config_path=None,
        cutoff=0.5,
        pressure=0.0,
    ):
        super().__init__(
            temperature=temperature,
            name=name,
            pressure=pressure,
            mu_dict=None,
            legacy=False,
        )
        self.ffname = 'BASEForce'
        self.cutoff = cutoff / A2NM  # self.cutoff is in A
        self.ckpt_path = ckpt_path
        self.config_path = config_path
        self.ase_calculator = BaseCalculator(self.ckpt_path)
        self.top_file = top_file
        self.init_frame = md.load(top_file)
        # initialize ASE atoms objects, used for frame energy calculation
        self.atoms = Atoms(
            numbers=[a.element.atomic_number for a in self.init_frame.topology.atoms],
            cell=self.init_frame.unitcell_vectors.reshape(3, 3) / A2NM,
            positions=self.init_frame.xyz.reshape(-1, 3) / A2NM,
            pbc=[1, 1, 1],
        )
        self.rundir = str(Path(self.ckpt_path).parent)
        self.atoms.calc = self.ase_calculator
        self.init_lmpsdata_file = self.rundir + '/init_struct.data'
        # write out the init lammps data file
        self.update_init_lmpsdata(self.init_frame)
        self.elements = sorted(set(self.atoms.get_chemical_symbols()))

    def update_atoms_geom(self, frame):
        """Update the geometry in the ``self.atoms`` ASE object."""
        # trajectory should be in nm, convert to A to be compatible with ase.
        # frame.xyz carries a leading frame axis (cf. unitcell_vectors[0]
        # below); flatten it to (n_atoms, 3) as done in __init__, otherwise
        # the ASE positions assignment fails on the shape mismatch.
        self.atoms.positions = frame.xyz.reshape(-1, 3) / A2NM
        self.atoms.set_cell(frame.unitcell_vectors[0] / A2NM)

    def update_init_lmpsdata(self, frame):
        """Regenerate the LAMMPS initial data file from ``frame``."""
        self.update_atoms_geom(frame)
        write_lammps_data(self.init_lmpsdata_file, atoms=self.atoms, masses=True, atom_style='atomic', units='metal')

    def calc_energy_frame(self, frame):
        """Return the potential energy (kJ/mol) of one mdtraj frame."""
        self.update_atoms_geom(frame)
        energy = self.atoms.get_potential_energy()
        # ASE reports energies in eV; convert to kJ/mol.
        return energy * EV2KJ

    def sample(self, nsteps, interval, skip, timestep, ttime, ptime, traj_file):
        """Run LAMMPS NPT MD and return a :class:`Sample`."""

        lines = [
            "units metal",
            "dimension 3",
            "boundary p p p",
            "atom_style atomic",
            "",
            f"read_data {self.init_lmpsdata_file}",
            "",
            f"pair_style base {self.cutoff}",
            f"pair_coeff * * {self.ckpt_path} " + " ".join(self.elements),
            "",
            f"timestep {timestep * 1e-3}",  # metal uses ps
            f"velocity all create {self._temperature} 12345 loop all mom yes rot no dist gaussian",
            f"fix npt_fix all npt temp {self._temperature} {self._temperature} {ttime * 1e-3} tri {self._pressure} {self._pressure} {ptime * 1e-3}",
            "",
            "thermo 100",
            "thermo_style custom step temp pe ke etotal press vol",
            f"run {skip}",
            f"dump dcd_dump all dcd {interval} {traj_file}",
            "dump_modify dcd_dump sort id",
            f"run {nsteps - skip}",
        ]
        in_lammps = self.rundir + "/in.lammps"
        with open(in_lammps, "w") as f:
            for line in lines:
                f.write(line + "\n")

        # run md. A shell string (not a list) must be passed with shell=True;
        # check=True makes a failed LAMMPS run raise instead of silently
        # loading a stale/missing trajectory afterwards.
        subprocess.run(f'lmp_bamboo_v100 -k on g 1 -sf kk -in {in_lammps}', shell=True, check=True)

        # collect trajectory, md_traj load everything in nm
        traj = md.load_dcd(traj_file, top=self.top_file)
        return Sample(traj, from_state=self.name)

    def update_parameters(self, params, ckpt_path=None):
        """Update model parameters from a jax pytree (converted to torch)."""
        # params is in jax, convert to torch
        if self.ffname in params:
            params_t = j2t_pytree(params[self.ffname])
        else:
            params_t = j2t_pytree(params)
        self.ase_calculator.model.load_state_dict(params_t, strict=False)

        # update the ckpt file used by lammps
        if ckpt_path is not None:
            self.ckpt_path = ckpt_path
        torch.jit.save(self.ase_calculator.model, self.ckpt_path)

    def write_state_dict(self, params, state_dict_file):
        """Merge jax params into the model state dict and save it to disk."""
        # update the state of the internal model
        if self.ffname in params:
            params_t = j2t_pytree(params[self.ffname])
        else:
            params_t = j2t_pytree(params)
        sd = self.ase_calculator.model.state_dict()
        for p in params_t.keys():
            sd[p] = params_t[p]
        torch.save(sd, state_dict_file)
class ReweightEstimator:
    """Single-sample, single-state reweighting helper.

    This implementation mirrors :class:`ReweightEstimator` in
    ``finetune/ft.py``, with two intentional differences:

    - ``base_energies`` is provided at :meth:`estimate_weight` time instead of
      construction time;
    - ``base_energies`` is treated as an energy correction in kJ/mol applied
      to the *target* state before β-scaling (i.e. we add ``beta *
      base_energies`` to the reduced target energies).
    """

    def __init__(self):
        self.sample = None
        self.state = None

    def set_sample_and_state(self, sample, state):
        """Attach the reference sample and its generating state.

        Parameters
        ----------
        sample : Sample
            A :class:`Sample` instance holding the reference trajectory and
            energy cache.
        state : SampleState
            The sampling :class:`SampleState` that generated ``sample``.
        """
        # Sanity check: the sample must originate from the provided state.
        if getattr(sample, "from_state", None) != getattr(state, "name", None):
            raise ValueError(
                "Sample.from_state must match state.name in ReweightEstimator. "
                f"Got from_state={getattr(sample, 'from_state', None)!r}, "
                f"state.name={getattr(state, 'name', None)!r}."
            )
        self.sample = sample
        self.state = state
        # Ensure reference energies are available
        self.compute_energy_matrix()

    def remove_sample_and_state(self):
        """Detach any currently stored sample/state pair."""

        self.sample = None
        self.state = None

    def compute_energy_matrix(self):
        """Populate ``sample.energy_data`` for the current reference state."""

        if self.sample is None or self.state is None:
            raise ValueError("Sample and state must be set before computing energies.")
        self.sample.generate_energy([self.state])

    def estimate_weight(self, target_state, params, base_energies=0.0, calc_uref=False):
        """Estimate reweighting factors for a new target state.

        Parameters
        ----------
        target_state : TargetState
            Target ensemble description; its :meth:`calc_energy` method is
            expected to return reduced potentials ``u_new = beta * (E + PV - μN)``.
        params : PyTree
            Parameter tree passed through to ``target_state.calc_energy``.
        base_energies : float or array-like, optional
            Energy correction(s) in kJ/mol applied to the target state before
            β-scaling. Conceptually, if ``E_new`` is the uncorrected target
            energy, we are using ``E_new + base_energies`` for reweighting.
        calc_uref : bool, optional
            When True, recompute the reference energies before reweighting.

        Returns
        -------
        weight : jax.numpy.ndarray
            Normalized reweighting factors for each frame in the sample
            trajectory.
        kappa : jax.numpy.ndarray
            Effective-sample-size ratio ``exp(S) / n_frames`` where ``S`` is
            the entropy of the weights; 1 means perfectly uniform weights.
        """

        if self.sample is None or self.state is None:
            raise ValueError("Sample and state must be set before estimating weights.")

        # Sanity check: enforce identical temperatures between sampling and
        # target states to avoid mixing ensembles.
        if not hasattr(target_state, "_temperature") or not hasattr(
            self.state, "_temperature"
        ):
            raise ValueError(
                "Both target_state and sample state must expose '_temperature' "
                "for ReweightEstimator."
            )
        if abs(target_state._temperature - self.state._temperature) > 1e-6:
            raise ValueError(
                "TargetState temperature must match SampleState temperature: "
                f"target={target_state._temperature}, sample={self.state._temperature}."
            )

        # Ensure reference energies are available
        if calc_uref:
            self.compute_energy_matrix()

        # Target reduced potential u_new (shape: n_frames,)
        unew = target_state.calc_energy(self.sample.trajectory, params)

        # Recover beta from the target state's temperature and convert the
        # base_energies correction (kJ/mol) into reduced units.
        beta = 1.0 / target_state._temperature / 8.314 * 1000.0
        base = jnp.array(base_energies)
        unew = unew + beta * base

        # Reference reduced potential u_ref from the sampling state
        uref = self.sample.energy_data[self.state.name].flatten()

        deltaU = unew - uref
        # Shift by the minimum so that every exponent in exp(-deltaU) is <= 0.
        # (Subtracting the maximum would make the exponents non-negative and
        # exp() could overflow for widely spread deltaU.) The shift cancels in
        # the normalization below, so the weights are unchanged.
        deltaU = deltaU - deltaU.min()
        weight = jnp.exp(-deltaU)
        weight = weight / jnp.sum(weight)
        # Effective-sample-size ratio from the weight entropy. Guard the
        # w * log(w) terms at w == 0 (their limit is 0) so that underflowed
        # weights do not produce NaN.
        safe_w = jnp.where(weight > 0.0, weight, 1.0)
        wlogw = jnp.where(weight > 0.0, weight * jnp.log(safe_w), 0.0)
        kappa = jnp.exp(-jnp.sum(wlogw)) / len(weight)
        return weight, kappa