msmjax.utils.benchmarking
Utilities for benchmarking
- msmjax.utils.benchmarking.dir_context(dir_name)
Create a context to run code in a different directory.
- Parameters:
dir_name (Path) – Path to the directory.
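A context manager like this is commonly built from os.chdir plus a try/finally block that restores the previous working directory. The sketch below is written under that assumption and is not the msmjax source. dir_context_sketch is a hypothetical name, and the sketch does not create dir_name if it is missing. On Python 3.11+, the standard library's contextlib.chdir provides the same behavior.

```python
import contextlib
import os
from pathlib import Path


@contextlib.contextmanager
def dir_context_sketch(dir_name: Path):
    """Hypothetical stand-in for dir_context: temporarily chdir into dir_name."""
    previous = Path.cwd()
    os.chdir(dir_name)  # raises if the directory does not exist
    try:
        yield
    finally:
        # Always return to the original directory, even if the body raised.
        os.chdir(previous)
```

Running, e.g., LAMMPS inside `with dir_context_sketch(tmp_dir):` keeps its input, log and dump files out of the caller's working directory.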
- msmjax.utils.benchmarking.write_lammps_data(filename, cell, positions, charges)
Write a structure to a LAMMPS data file.
- Return type:
None
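For orientation, a LAMMPS data file for charged point particles (atom_style charge) consists of a header with counts and box bounds, followed by Masses and Atoms sections, where each atom line reads atom-ID atom-type q x y z. The sketch below writes such a file. It assumes a single atom type with a placeholder mass of 1.0 and an orthorhombic (diagonal) cell, and the name write_lammps_data_sketch is hypothetical; the actual msmjax function may support triclinic cells or make other choices.

```python
import numpy as np


def write_lammps_data_sketch(filename, cell, positions, charges):
    """Hypothetical sketch: write an atom_style charge data file (orthorhombic cell only)."""
    cell = np.asarray(cell, dtype=float)
    positions = np.asarray(positions, dtype=float)
    charges = np.asarray(charges, dtype=float)
    if not np.allclose(cell, np.diag(np.diag(cell))):
        raise ValueError("this sketch only handles orthorhombic (diagonal) cells")
    lx, ly, lz = np.diag(cell)
    lines = [
        "LAMMPS data file (sketch)",  # first line is a comment and is skipped by LAMMPS
        "",
        f"{len(positions)} atoms",
        "1 atom types",
        "",
        f"0.0 {lx} xlo xhi",
        f"0.0 {ly} ylo yhi",
        f"0.0 {lz} zlo zhi",
        "",
        "Masses",
        "",
        "1 1.0",  # placeholder mass; LAMMPS requires masses to be set before a run
        "",
        "Atoms # charge",
        "",
    ]
    # atom_style charge layout: atom-ID atom-type q x y z (IDs start at 1)
    for idx, (q, (x, y, z)) in enumerate(zip(charges, positions), start=1):
        lines.append(f"{idx} 1 {q} {x} {y} {z}")
    with open(filename, "w") as f:
        f.write("\n".join(lines) + "\n")
```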
- msmjax.utils.benchmarking.parse_energy_from_lammps_log(filename)
Get the energy from a LAMMPS log file.
- Return type:
float
- msmjax.utils.benchmarking.parse_lammps_log(filename)
Get the energy and other results from a LAMMPS log file.
- Return type:
tuple[float,ndarray]
- msmjax.utils.benchmarking.make_lammps_input_text_pppm(filename_data, filename_dump, accuracy, max_neighbors_one_atom)
Write a LAMMPS input script.
- Parameters:
accuracy (float)
max_neighbors_one_atom (int | None)
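The kind of script such a function produces can be sketched as follows. Every command used (units, atom_style, boundary, read_data, pair_style coul/long, pair_coeff, kspace_style pppm, neigh_modify one, thermo_style, dump custom, dump_modify, run 0) is a standard LAMMPS command, and passing accuracy to kspace_style pppm and max_neighbors_one_atom to neigh_modify one follows the parameter descriptions of eval_lammps_pppm. Everything else is an assumption for illustration: metal units, the extra cutoff parameter for the real-space Coulomb part, the dump layout and the function name are hypothetical, not what msmjax writes.

```python
def make_lammps_input_text_pppm_sketch(filename_data, filename_dump, accuracy,
                                       max_neighbors_one_atom, cutoff=10.0):
    """Hypothetical sketch; `cutoff` (real-space Coulomb cutoff) is an assumed extra parameter."""
    lines = [
        "units metal",                     # assumed unit system
        "atom_style charge",
        "boundary p p p",                  # fully periodic cell
        f"read_data {filename_data}",
        f"pair_style coul/long {cutoff}",  # real-space part of the Ewald-type split
        "pair_coeff * *",
        f"kspace_style pppm {accuracy}",   # reciprocal-space part via PPPM
    ]
    if max_neighbors_one_atom is not None:
        # Raise the per-atom neighbor limit only when requested.
        lines.append(f"neigh_modify one {max_neighbors_one_atom}")
    lines += [
        "thermo_style custom step pe",     # print the potential energy to the log
        f"dump forces all custom 1 {filename_dump} id fx fy fz",
        "dump_modify forces sort id",      # keep atom order stable in the dump
        "run 0",                           # single-point evaluation, no dynamics
    ]
    return "\n".join(lines) + "\n"
```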
- msmjax.utils.benchmarking.eval_lammps_pppm(positions, charges, cell, lammps_executable='lmp', accuracy=1e-05, max_neighbors_one_atom=None, show_stdout=False)
Wrapper to compute the periodic electrostatic energy and forces in LAMMPS with the PPPM (P3M) method.
- Parameters:
positions (numpy.typing.ArrayLike) – Array of particle positions, shape (n_particles, 3).
charges (numpy.typing.ArrayLike) – Array of particle charges, shape (n_particles,).
cell (numpy.typing.ArrayLike) – Unit cell.
lammps_executable (str) – Path to the LAMMPS executable.
accuracy (float) – Accuracy setting passed to LAMMPS via kspace_style pppm {accuracy}.
max_neighbors_one_atom (int | None) – Maximum number of neighbors of a single atom. Supplied to LAMMPS via neigh_modify one if given. You may need to increase this value if LAMMPS aborts with a neighbor-list overflow error.
show_stdout – Whether to show the stdout from the subprocess call to LAMMPS.
- Return type:
tuple[float,ndarray,ndarray,ndarray]
- Returns:
energy, forces
- msmjax.utils.benchmarking.calc_rmse(y_pred, y_true)
Calculate the root-mean-square error (RMSE).
- Return type:
Array
- Parameters:
y_pred (Array | ndarray | bool_ | number | bool | int | float | complex)
y_true (Array | ndarray | bool_ | number | bool | int | float | complex)
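The RMSE is sqrt(mean((y_pred - y_true)**2)). The NumPy sketch below computes it; the hypothetical calc_rmse_sketch returns a NumPy scalar rather than a JAX Array, and averaging over all elements of multi-dimensional inputs (e.g., force arrays) is an assumption.

```python
import numpy as np


def calc_rmse_sketch(y_pred, y_true):
    """Root-mean-square error, averaged over all elements (assumption for N-d inputs)."""
    diff = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    return np.sqrt(np.mean(diff ** 2))
```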
- msmjax.utils.benchmarking.calc_relative_rmse_percent(y_pred, y_true)
Calculate the RMSE normalized by the standard deviation (STD) of the reference results, as a percentage.
- Return type:
Array
- Parameters:
y_pred (Array | ndarray | bool_ | number | bool | int | float | complex)
y_true (Array | ndarray | bool_ | number | bool | int | float | complex)
- msmjax.utils.benchmarking.calc_relative_rmse(y_pred, y_true)
Calculate the RMSE normalized by the standard deviation (STD) of the reference results.
- Return type:
Array
- Parameters:
y_pred (Array | ndarray | bool_ | number | bool | int | float | complex)
y_true (Array | ndarray | bool_ | number | bool | int | float | complex)
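Both relative variants divide the RMSE by the standard deviation of the reference values y_true, so a prediction that always equals the mean of the reference scores exactly 1 (or 100 %). A NumPy sketch follows; using the population standard deviation (ddof=0) is an assumption, and both function names are hypothetical.

```python
import numpy as np


def calc_relative_rmse_sketch(y_pred, y_true):
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    rmse = np.sqrt(np.mean((y_pred - y_true) ** 2))
    return rmse / np.std(y_true)  # population STD (ddof=0), an assumption


def calc_relative_rmse_percent_sketch(y_pred, y_true):
    # Same quantity, expressed as a percentage.
    return 100.0 * calc_relative_rmse_sketch(y_pred, y_true)
```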
- msmjax.utils.benchmarking.make_timed_eval(fn, repeat=10, number=100)
Given a JAX-jittable function, return a function that times its evaluation.
The returned function internally takes care of jitting the function to be timed and of calling .block_until_ready() on its output, so that JAX’s asynchronous dispatch does not produce meaningless timings.
- Parameters:
fn (Callable) – A JAX-jittable function.
repeat (int) – Passed to timeit.repeat.
number (int) – Passed to timeit.repeat.
- Return type:
Callable
- Returns:
A function with the same argument structure as fn that, when called, returns a 2-element tuple containing the measured evaluation time and the original output of fn.
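The pattern described above (jit once, make a warm-up call so compilation is excluded, then time calls that block on the result) can be sketched with the standard timeit module. This is an illustration, not the msmjax implementation: make_timed_eval_sketch is a hypothetical name, reporting the best per-call time (minimum over repeats divided by number) is an assumed choice, and the code falls back to plain Python when JAX is not installed. jax.block_until_ready waits on every array leaf of the output pytree.

```python
import timeit

try:
    import jax
except ImportError:  # let the sketch run without JAX (no jit, no blocking)
    jax = None


def make_timed_eval_sketch(fn, repeat=10, number=100):
    """Hypothetical sketch of the jit / warm-up / block-until-ready timing pattern."""
    if jax is not None:
        fn = jax.jit(fn)
        block = jax.block_until_ready  # waits for all arrays in the output pytree
    else:
        def block(x):
            return x

    def timed(*args, **kwargs):
        # Warm-up call: triggers compilation so it is excluded from the timing.
        out = block(fn(*args, **kwargs))
        times = timeit.repeat(lambda: block(fn(*args, **kwargs)),
                              repeat=repeat, number=number)
        # Best total over the repeats, divided by calls per repeat (assumed statistic).
        return min(times) / number, out

    return timed
```

Without the blocking call, JAX would return control as soon as the work is dispatched, and the loop would mostly measure dispatch overhead.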
- msmjax.utils.benchmarking.remove_duplicates_from_neighborlist(neighborlist, fill_value, size)
Remove duplicate index pairs from a neighbor list.
- Parameters:
neighborlist (Sequence[jax.typing.ArrayLike]) – Original neighbor list that may contain duplicates (e.g., both the index pair ij and ji). Tuple of two 1-d integer arrays of the same size. See msmjax.utils.benchmarking() for more details on the format.
fill_value (int) – A value used to pad the arrays in the neighbor list to length size. Should be a value that is ignored by the evaluation function to which the neighbor list will be passed.
size (int) – The number of index pairs (including placeholder pairs that contain fill_value) in the returned duplicate-free neighbor list. Must be known in advance to satisfy static-shape constraints under JIT.
- Return type:
tuple[Array,Array]
- Returns:
The neighbor list with duplicate index pairs removed (i.e., containing only ij and not ji).
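The core of the duplicate removal can be sketched by keeping only pairs with i < j and padding the result to the fixed length size with fill_value. Keeping the i < j orientation and discarding self-pairs (i == j) are assumptions of this hypothetical NumPy sketch. Under jax.jit, jnp.nonzero accepts size and fill_value arguments for exactly this kind of fixed-size selection, which is what makes a static output shape possible.

```python
import numpy as np


def remove_duplicates_sketch(neighborlist, fill_value, size):
    """Hypothetical sketch: keep each unordered pair once, as (i, j) with i < j."""
    i, j = (np.asarray(a) for a in neighborlist)
    keep = i < j  # drops (j, i) duplicates, self-pairs and (fill, fill) padding
    i, j = i[keep], j[keep]
    if len(i) > size:
        raise ValueError(f"{len(i)} unique pairs do not fit into size={size}")
    pad = size - len(i)
    i = np.concatenate([i, np.full(pad, fill_value, dtype=i.dtype)])
    j = np.concatenate([j, np.full(pad, fill_value, dtype=j.dtype)])
    return i, j
```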
- msmjax.utils.benchmarking.build_duplicate_free_neighborlists(set_of_positions, set_of_cells, cutoff, pbc)
Construct neighbor lists with no duplicate index pairs.
First, matscipy is used to construct the initial neighbor lists, which do contain duplicates; then JAX is used to remove the duplicates. The function processes several structures and pads their neighbor lists to the largest neighbor-list size among these structures, so that the JAX duplicate-removal step does not have to be re-compiled for each structure.
- Parameters:
set_of_positions (Sequence[jax.typing.ArrayLike]) – Sequence of particle positions (arrays of shape (n_particles, n_dim)), each corresponding to one structure.
set_of_cells (Sequence[jax.typing.ArrayLike]) – Sequence of cells (arrays of shape (n_dim, n_dim)), each corresponding to one structure.
cutoff (float) – Cutoff radius for the neighbor search.
pbc (Sequence[bool]) – One boolean per direction signaling periodicity.
- Return type:
list[tuple[Array,Array]]
- Returns:
List of duplicate-free neighbor lists, one per structure.
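The padding strategy can be illustrated without matscipy. In the hypothetical sketch below, the raw neighbor lists (which msmjax obtains from matscipy) are passed in directly, and each is padded to the length of the longest raw list, so every call to the duplicate-removal step sees inputs of the same shape; under jax.jit that means one compilation instead of one per structure. Choosing the output size as half the longest raw list assumes each pair appears exactly twice and there are no self-pairs. All names are hypothetical, and NumPy stands in for JAX.

```python
import numpy as np


def _dedup_fixed_size(i, j, fill_value, size):
    # Keep (i, j) with i < j; (fill, fill) padding pairs are dropped as well.
    keep = i < j
    i, j = i[keep], j[keep]
    pad = size - len(i)
    return (np.concatenate([i, np.full(pad, fill_value, dtype=i.dtype)]),
            np.concatenate([j, np.full(pad, fill_value, dtype=j.dtype)]))


def dedup_many_sketch(raw_neighborlists, fill_value):
    """raw_neighborlists: list of (i, j) index arrays, one pair of arrays per structure."""
    max_raw = max(len(i) for i, _ in raw_neighborlists)
    size = max_raw // 2  # assumes every pair appears exactly twice, no self-pairs
    out = []
    for i, j in raw_neighborlists:
        i, j = np.asarray(i), np.asarray(j)
        pad = np.full(max_raw - len(i), fill_value, dtype=i.dtype)
        # Every call now receives arrays of shape (max_raw,), so a jitted
        # version of the dedup step would only be compiled once.
        out.append(_dedup_fixed_size(np.concatenate([i, pad]),
                                     np.concatenate([j, pad]),
                                     fill_value, size))
    return out
```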
- msmjax.utils.benchmarking.calc_energy_nonperiodic_direct_summation(positions, charges)
Calculate the exact energy of a non-periodic system by direct summation over all pairs.
- Parameters:
positions (jax.typing.ArrayLike) – Array of particle positions, shape (n_particles, n_dim).
charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,).
- Return type:
Array
- Returns:
Energy.
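Direct summation evaluates E = (1/2) Σ_{i≠j} q_i q_j / |r_i - r_j| at O(n_particles²) cost. A NumPy sketch follows, assuming units in which the Coulomb constant equals 1 (the unit convention of msmjax is not stated here); the function name is hypothetical.

```python
import numpy as np


def calc_energy_direct_sketch(positions, charges):
    r = np.asarray(positions, dtype=float)
    q = np.asarray(charges, dtype=float)
    diff = r[:, None, :] - r[None, :, :]   # r_i - r_j, shape (n, n, n_dim)
    dist = np.linalg.norm(diff, axis=-1)
    np.fill_diagonal(dist, np.inf)         # 1/inf = 0 removes the i == j terms
    # Factor 1/2 because every pair is counted twice in the double sum.
    return 0.5 * np.sum(q[:, None] * q[None, :] / dist)
```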
- msmjax.utils.benchmarking.calc_forces_nonperiodic_direct_summation(positions, charges)
Calculate the exact forces in a non-periodic system by direct summation over all pairs.
- Parameters:
positions (jax.typing.ArrayLike) – Array of particle positions, shape (n_particles, n_dim).
charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,).
- Return type:
Array
- Returns:
Forces.
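Taking minus the gradient of the pairwise Coulomb energy E = (1/2) Σ_{i≠j} q_i q_j / |r_i - r_j| gives F_i = Σ_{j≠i} q_i q_j (r_i - r_j) / |r_i - r_j|³. A NumPy sketch, again assuming a Coulomb constant of 1 and using a hypothetical name:

```python
import numpy as np


def calc_forces_direct_sketch(positions, charges):
    r = np.asarray(positions, dtype=float)
    q = np.asarray(charges, dtype=float)
    diff = r[:, None, :] - r[None, :, :]        # r_i - r_j, shape (n, n, n_dim)
    dist = np.linalg.norm(diff, axis=-1)
    np.fill_diagonal(dist, np.inf)              # no self-interaction
    coef = q[:, None] * q[None, :] / dist ** 3  # q_i q_j / |r_i - r_j|^3
    return np.sum(coef[:, :, None] * diff, axis=1)
```

For like charges the force on particle i points away from particle j, and the forces sum to zero by Newton's third law.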
- msmjax.utils.benchmarking.calc_chargegrad_nonperiodic_direct_summation(positions, charges)
Calculate the exact charge gradient in a non-periodic system by direct summation over all pairs.
- Parameters:
positions (jax.typing.ArrayLike) – Array of particle positions, shape (n_particles, n_dim).
charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,).
- Return type:
Array
- Returns:
Gradient of the energy w.r.t. particle charges.
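Because the pairwise Coulomb energy E = (1/2) Σ_{i≠j} q_i q_j / |r_i - r_j| is quadratic in the charges, its derivative with respect to q_i is the electrostatic potential at particle i due to all other particles: ∂E/∂q_i = Σ_{j≠i} q_j / |r_i - r_j|. A NumPy sketch, assuming a Coulomb constant of 1 and using a hypothetical name:

```python
import numpy as np


def calc_chargegrad_direct_sketch(positions, charges):
    r = np.asarray(positions, dtype=float)
    q = np.asarray(charges, dtype=float)
    dist = np.linalg.norm(r[:, None, :] - r[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)        # exclude j == i
    return np.sum(q[None, :] / dist, axis=1)  # potential at each particle
```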
- msmjax.utils.benchmarking.calc_stress_virial(positions, forces, cell)
Calculate the stress tensor from the virial.
- Parameters:
positions (jax.typing.ArrayLike) – Array of particle positions, shape (n_particles, n_dim).
forces (jax.typing.ArrayLike) – Array of forces, shape (n_particles, n_dim).
cell (jax.typing.ArrayLike) – Array representing the cell, shape (n_dim, n_dim).
- Return type:
Array
- Returns:
Stress in 6-component format.
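A virial stress is commonly computed as σ = -(1/V) Σ_i r_i ⊗ F_i, with the cell volume V = |det(cell)|. The sketch below assumes that sign convention, n_dim = 3, a symmetrized tensor, and the 6-component Voigt order xx, yy, zz, yz, xz, xy used by ASE; msmjax may use a different sign or ordering, and the function name is hypothetical.

```python
import numpy as np


def calc_stress_virial_sketch(positions, forces, cell):
    r = np.asarray(positions, dtype=float)
    f = np.asarray(forces, dtype=float)
    volume = abs(np.linalg.det(np.asarray(cell, dtype=float)))
    sigma = -np.einsum("ni,nj->ij", r, f) / volume  # sign convention: assumption
    sigma = 0.5 * (sigma + sigma.T)                 # symmetrize (assumption)
    # 6-component Voigt order xx, yy, zz, yz, xz, xy (assumed, as in ASE)
    return np.array([sigma[0, 0], sigma[1, 1], sigma[2, 2],
                     sigma[1, 2], sigma[0, 2], sigma[0, 1]])
```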
- msmjax.utils.benchmarking.calc_nonperiodic_reference_results(positions, charges, cell)
Calculate various exact reference results for a non-periodic system.
- Parameters:
positions (jax.typing.ArrayLike) – Array of particle positions, shape (n_particles, n_dim).
charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,).
cell (jax.typing.ArrayLike) – Array representing the cell, shape (n_dim, n_dim).
- Return type:
tuple[Array,Array,Array,Array]
- Returns:
Energy, forces, charge gradient, stress tensor.