# Setting up and running a calculation
import numpy as onp
from msmjax.calculators import set_up_msm_params
# Build the MSM parameter set: a diagonal (cubic, side 10) cell, no periodic
# boundaries along any axis, and a cell that is not treated as dynamic.
msm_params = set_up_msm_params(
    cell=onp.eye(3) * 10.0,
    level_one_spacings=onp.full(3, 1.0),
    level_zero_cutoff=3.0,
    p=4,
    pbc=(False, False, False),
    cell_mode="ortho",
    dynamic_cell=False,
)
import jax

from msmjax.calculators import create_msm

# create_msm returns a mapping of evaluation functions; this example uses the
# combined energy-and-forces entry.
evaluation_fns = create_msm(msm_params)

# jit for performance: compilation happens on the first call, and later calls
# with same-shaped inputs reuse the compiled function.
eval_energy_and_forces = jax.jit(
    evaluation_fns["energy_and_forces"]
)

# Placeholder system so the example runs end to end; replace with your own data.
# NOTE(review): assumes positions are (N, 3) Cartesian coordinates lying inside
# the cell and charges are (N,) -- confirm against the msmjax API docs.
_rng = onp.random.default_rng(0)
n_atoms = 8
positions = _rng.uniform(2.0, 8.0, size=(n_atoms, 3))
charges = onp.tile([1.0, -1.0], n_atoms // 2)  # alternating, net-neutral charges

energy, forces = eval_energy_and_forces(
    positions, charges
)
# Parameters can be written to JSON and later rebuilt from that file.
from msmjax.calculators import MSMParams

params_path = "params.json"
msm_params.save_json(params_path)
msm_params_restored = MSMParams.load_json(params_path)