msmjax.utils.benchmarking

Utilities for benchmarking against reference results from LAMMPS (PPPM) and from exact direct summation.

msmjax.utils.benchmarking.dir_context(dir_name)

Create a context to run code in a different directory.

Parameters:

dir_name (Path) – Path to the directory in which to run the code.
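A context manager with this behavior can be written in a few lines with contextlib. The following is a minimal sketch, not necessarily how msmjax implements dir_context. The name dir_context_sketch is hypothetical. The try/finally block restores the previous working directory even if the enclosed code raises.

```python
import contextlib
import os
from pathlib import Path
from typing import Iterator


@contextlib.contextmanager
def dir_context_sketch(dir_name: Path) -> Iterator[None]:
    """Temporarily change the working directory to dir_name (hypothetical helper)."""
    previous = Path.cwd()
    os.chdir(dir_name)
    try:
        yield
    finally:
        # Always return to the original directory, even if an exception occurred.
        os.chdir(previous)
```

On Python 3.11 and later, the standard library's contextlib.chdir offers the same behavior.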

msmjax.utils.benchmarking.write_lammps_data(filename, cell, positions, charges)

Write a structure, given by its cell, particle positions, and charges, to a LAMMPS data file.

Return type:

None
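To illustrate the file format involved, the sketch below writes a LAMMPS data file for atom_style charge, where each line of the Atoms section reads atom-ID atom-type q x y z. It is not the msmjax implementation: the name write_lammps_data_sketch is hypothetical, and it assumes an orthorhombic cell with its origin at zero, a single atom type, and a dummy mass of 1.0. Triclinic cells would additionally need an xy xz yz tilt-factor line in the header.

```python
from pathlib import Path

import numpy as np


def write_lammps_data_sketch(filename, cell, positions, charges) -> None:
    """Write an orthorhombic structure as a LAMMPS data file for atom_style charge.

    Hypothetical helper: assumes a diagonal cell with its origin at 0 and a
    single atom type with a dummy mass of 1.0.
    """
    cell = np.asarray(cell, dtype=float)
    positions = np.asarray(positions, dtype=float)
    charges = np.asarray(charges, dtype=float)
    if not np.allclose(cell, np.diag(np.diag(cell))):
        raise ValueError("this sketch only supports orthorhombic (diagonal) cells")
    lx, ly, lz = np.diag(cell)

    lines = [
        "LAMMPS data file written by write_lammps_data_sketch",
        "",
        f"{len(positions)} atoms",
        "1 atom types",
        "",
        f"0.0 {lx:.10f} xlo xhi",
        f"0.0 {ly:.10f} ylo yhi",
        f"0.0 {lz:.10f} zlo zhi",
        "",
        "Masses",
        "",
        "1 1.0",
        "",
        "Atoms # charge",
        "",
    ]
    # atom_style charge expects: atom-ID atom-type q x y z
    for atom_id, (q, (x, y, z)) in enumerate(zip(charges, positions), start=1):
        lines.append(f"{atom_id} 1 {q:.10f} {x:.10f} {y:.10f} {z:.10f}")
    Path(filename).write_text("\n".join(lines) + "\n")
```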

msmjax.utils.benchmarking.parse_energy_from_lammps_log(filename)

Parse the energy from a LAMMPS log file.

Return type:

float

msmjax.utils.benchmarking.parse_lammps_log(filename)

Parse the energy and other results from a LAMMPS log file.

Return type:

tuple[float, ndarray]

msmjax.utils.benchmarking.make_lammps_input_text_pppm(filename_data, filename_dump, accuracy, max_neighbors_one_atom)

Create the text of a LAMMPS input script for a PPPM calculation.

Parameters:
  • accuracy (float)

  • max_neighbors_one_atom (int | None)
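Such a script essentially combines a long-range pair style with kspace_style pppm. Below is a minimal sketch of input text for a single-point evaluation. Only kspace_style pppm {accuracy} and neigh_modify one are mentioned in this documentation (under eval_lammps_pppm). The function name, the real unit system, the coul/long cutoff of 10.0, the dummy mass, and the dump layout are assumptions for illustration.

```python
from typing import Optional


def make_pppm_input_text_sketch(
    filename_data: str,
    filename_dump: str,
    accuracy: float,
    max_neighbors_one_atom: Optional[int] = None,
    cutoff: float = 10.0,
) -> str:
    """Return LAMMPS input text for a single-point PPPM calculation (hypothetical)."""
    lines = [
        "units real",  # assumption: unit system
        "atom_style charge",
        "boundary p p p",
        f"read_data {filename_data}",
        "mass * 1.0",  # dummy mass; irrelevant for a single-point evaluation
        f"pair_style coul/long {cutoff}",  # real-space part of the Ewald split
        "pair_coeff * *",
        f"kspace_style pppm {accuracy}",  # reciprocal-space part via PPPM
    ]
    if max_neighbors_one_atom is not None:
        # LAMMPS requires the neighbor page size to be at least 10x the "one" value.
        page = max(100000, 10 * max_neighbors_one_atom)
        lines.append(f"neigh_modify one {max_neighbors_one_atom} page {page}")
    lines += [
        f"dump forces all custom 1 {filename_dump} id fx fy fz",
        "dump_modify forces sort id",
        "thermo_style custom step pe",
        "run 0",  # evaluate energy and forces without time integration
    ]
    return "\n".join(lines) + "\n"
```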

msmjax.utils.benchmarking.eval_lammps_pppm(positions, charges, cell, lammps_executable='lmp', accuracy=1e-05, max_neighbors_one_atom=None, show_stdout=False)

Wrapper that computes the periodic electrostatic energy and forces with LAMMPS, using the particle-particle particle-mesh (PPPM, also called P3M) method.

Parameters:
  • positions (numpy.typing.ArrayLike) – Array of particle positions, shape (n_particles, 3)

  • charges (numpy.typing.ArrayLike) – Array of particle charges, shape (n_particles,)

  • cell (numpy.typing.ArrayLike) – Unit cell

  • lammps_executable (str) – Path to LAMMPS executable

  • accuracy (float) – Accuracy setting passed to LAMMPS via kspace_style pppm {accuracy}.

  • max_neighbors_one_atom (int | None) – Maximum number of neighbors of a single atom, supplied to LAMMPS via neigh_modify one if given. Increase this value if LAMMPS aborts because the neighbor list overflows.

  • show_stdout (bool) – Whether to show the stdout of the LAMMPS subprocess.

Return type:

tuple[float, ndarray, ndarray, ndarray]

Returns:

energy, forces

msmjax.utils.benchmarking.calc_rmse(y_pred, y_true)

Calculate the root-mean-square error (RMSE) of y_pred with respect to y_true.

Return type:

Array

Parameters:
  • y_pred (Array | ndarray | bool_ | number | bool | int | float | complex)

  • y_true (Array | ndarray | bool_ | number | bool | int | float | complex)
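A minimal sketch of the RMSE, sqrt(mean((y_pred - y_true)^2)), assuming the mean is taken over all array elements (for forces of shape (n_particles, 3), every Cartesian component counts as one sample). The name rmse_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def rmse_sketch(y_pred: ArrayLike, y_true: ArrayLike) -> jax.Array:
    """Root-mean-square error over all elements (hypothetical helper)."""
    diff = jnp.asarray(y_pred) - jnp.asarray(y_true)
    return jnp.sqrt(jnp.mean(diff**2))
```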

msmjax.utils.benchmarking.calc_relative_rmse_percent(y_pred, y_true)

Calculate the RMSE normalized by the standard deviation of the reference results (y_true), as a percentage.

Return type:

Array

Parameters:
  • y_pred (Array | ndarray | bool_ | number | bool | int | float | complex)

  • y_true (Array | ndarray | bool_ | number | bool | int | float | complex)

msmjax.utils.benchmarking.calc_relative_rmse(y_pred, y_true)

Calculate the RMSE normalized by the standard deviation of the reference results (y_true).

Return type:

Array

Parameters:
  • y_pred (Array | ndarray | bool_ | number | bool | int | float | complex)

  • y_true (Array | ndarray | bool_ | number | bool | int | float | complex)
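The two normalized variants above divide the RMSE by the standard deviation of the reference values. A minimal sketch follows; the function names are hypothetical, and using the population standard deviation (jnp.std with its default ddof=0) is an assumption. With this normalization, a relative RMSE of 1 (100 %) equals the error of the trivial predictor that always outputs the mean of y_true.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def relative_rmse_sketch(y_pred: ArrayLike, y_true: ArrayLike) -> jax.Array:
    """RMSE divided by the standard deviation of the reference values (hypothetical)."""
    y_pred = jnp.asarray(y_pred)
    y_true = jnp.asarray(y_true)
    rmse = jnp.sqrt(jnp.mean((y_pred - y_true) ** 2))
    # Assumption: population standard deviation (ddof=0) of the reference results.
    return rmse / jnp.std(y_true)


def relative_rmse_percent_sketch(y_pred: ArrayLike, y_true: ArrayLike) -> jax.Array:
    """Same quantity, expressed as a percentage (hypothetical)."""
    return 100.0 * relative_rmse_sketch(y_pred, y_true)
```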

msmjax.utils.benchmarking.make_timed_eval(fn, repeat=10, number=100)

Given a JAX-jittable function, return a function to time its evaluation.

The returned function jits fn internally and calls .block_until_ready() on its output, so that the measured times are not distorted by JAX’s asynchronous dispatch. Without this, the timer would only measure how long it takes to dispatch the computation, not how long it takes to complete.

Parameters:
  • fn (Callable) – A JAX-jittable function.

  • repeat (int) – Passed to timeit.repeat.

  • number (int) – Passed to timeit.repeat.

Return type:

Callable

Returns:

A function with the same argument structure as fn that, when called, returns a 2-element tuple containing:

  • The measured evaluation time

  • The original output of fn.
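The timing pattern described above can be sketched as follows. This is an illustration under assumptions, not the msmjax source: make_timed_eval_sketch is hypothetical, it accepts positional arguments only, and it reports the best-of-repeat time per call, whereas the library may report a different statistic. A first call compiles the function so that compilation time is excluded from the measurement.

```python
import timeit
from typing import Any, Callable

import jax
import jax.numpy as jnp


def make_timed_eval_sketch(fn: Callable, repeat: int = 10, number: int = 100) -> Callable:
    """Return a function that times jitted evaluations of fn (hypothetical helper)."""
    jitted_fn = jax.jit(fn)

    def timed_eval(*args: Any) -> tuple[float, Any]:
        # First call compiles; block on it so compilation is excluded from timing.
        output = jax.block_until_ready(jitted_fn(*args))
        times = timeit.repeat(
            lambda: jax.block_until_ready(jitted_fn(*args)),
            repeat=repeat,
            number=number,
        )
        # Assumption: report the best-of-`repeat` time per single call.
        return min(times) / number, output

    return timed_eval
```

Taking the minimum over repeats is a common choice because it is least affected by other load on the machine.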

msmjax.utils.benchmarking.remove_duplicates_from_neighborlist(neighborlist, fill_value, size)

Remove duplicate index pairs from a neighbor list.

Parameters:
  • neighborlist (Sequence[jax.typing.ArrayLike]) – Original neighbor list that may contain duplicates (for example, both the index pairs ij and ji). Tuple of two 1-d integer arrays of the same size. See msmjax.utils.benchmarking() for more details on format.

  • fill_value (int) – A value used to pad the arrays in the neighbor list to length size. Should be a value that is ignored by the evaluation function to which the neighbor list will be passed.

  • size (int) – The number of index pairs (including placeholder values that contain fill_value) in the returned duplicate-free neighbor list. Must be known in advance to satisfy static-shape constraints under JIT.

Return type:

tuple[Array, Array]

Returns:

The neighbor list with duplicate index pairs removed (i.e., containing only ij and not ji).
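One way to do this under JIT is to keep only the entries with i < j and compact them with jnp.nonzero, whose size and fill_value arguments give a static output shape. The sketch below is hypothetical (remove_duplicates_sketch is not part of msmjax) and assumes every pair is listed in both orientations. If more than size pairs survive, the extra ones are silently dropped, so size must be large enough. It also drops every entry with i == j, including interactions of a particle with its own periodic images.

```python
from functools import partial
from typing import Sequence

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


@partial(jax.jit, static_argnames=("size",))
def remove_duplicates_sketch(
    neighborlist: Sequence[ArrayLike], fill_value: int, size: int
) -> tuple[jax.Array, jax.Array]:
    """Keep only pairs with i < j, padded to a static length (hypothetical helper)."""
    i, j = (jnp.asarray(a) for a in neighborlist)
    # Each pair is assumed to appear as both (i, j) and (j, i); keep one orientation.
    # Entries with i == j (e.g. padding using the same value in both arrays) are dropped.
    keep = i < j
    n = i.shape[0]
    # Indices of kept entries; unused slots point at index n (an appended fill entry).
    (selected,) = jnp.nonzero(keep, size=size, fill_value=n)
    i_ext = jnp.append(i, fill_value)
    j_ext = jnp.append(j, fill_value)
    return i_ext[selected], j_ext[selected]
```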

msmjax.utils.benchmarking.build_duplicate_free_neighborlists(set_of_positions, set_of_cells, cutoff, pbc)

Construct a neighbor list with no duplicate index pairs.

The function first uses matscipy to construct the initial neighbor list, which contains duplicates, and then uses JAX to remove them. It processes several structures at once and pads all neighbor lists to the largest neighbor-list size among them, which avoids re-compiling the JAX duplicate-removal step for each structure.

Parameters:
  • set_of_positions (Sequence[jax.typing.ArrayLike]) – Sequence of particle positions (arrays of shape (n_particles, n_dim)), each corresponding to one structure.

  • set_of_cells (Sequence[jax.typing.ArrayLike]) – Sequence of cells (arrays of shape (n_dim, n_dim)), each corresponding to one structure.

  • cutoff (float) – Cutoff radius for neighbor search.

  • pbc (Sequence[bool]) – One boolean per direction signaling periodicity.

Return type:

list[tuple[Array, Array]]

Returns:

List of duplicate-free neighbor lists for all structures.

msmjax.utils.benchmarking.coulomb_kernel(r)

msmjax.utils.benchmarking.calc_energy_nonperiodic_direct_summation(positions, charges)

Calculate the exact energy of a non-periodic system by direct summation over all particle pairs.

Parameters:
  • positions (jax.typing.ArrayLike) – Array of particle positions, shape (n_particles, n_dim)

  • charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,)

Return type:

Array

Returns:

Energy

msmjax.utils.benchmarking.calc_forces_nonperiodic_direct_summation(positions, charges)

Calculate the exact forces in a non-periodic system by direct summation over all particle pairs.

Parameters:
  • positions (jax.typing.ArrayLike) – Array of particle positions, shape (n_particles, n_dim)

  • charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,)

Return type:

Array

Returns:

Forces

msmjax.utils.benchmarking.calc_chargegrad_nonperiodic_direct_summation(positions, charges)

Calculate the exact gradient of the energy with respect to the particle charges in a non-periodic system by direct summation over all particle pairs.

Parameters:
  • positions (jax.typing.ArrayLike) – Array of particle positions, shape (n_particles, n_dim)

  • charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,)

Return type:

Array

Returns:

Gradient of the energy w.r.t. particle charges
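The three direct-summation quantities above follow from the pairwise Coulomb energy E = 1/2 Σ_{i≠j} q_i q_j / r_ij: the forces are F_i = -∂E/∂r_i, and the charge gradient ∂E/∂q_i = Σ_{j≠i} q_j / r_ij is the electrostatic potential at particle i. The sketch below computes all three with JAX automatic differentiation. It is hypothetical (the function names are not part of msmjax), assumes a Coulomb constant of 1 and floating-point inputs, and builds the full n × n distance matrix, so it only suits small systems.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def energy_direct_sketch(positions: ArrayLike, charges: ArrayLike) -> jax.Array:
    """Non-periodic Coulomb energy 1/2 * sum_{i != j} q_i q_j / r_ij (hypothetical)."""
    positions = jnp.asarray(positions)
    charges = jnp.asarray(charges)
    n = positions.shape[0]
    self_mask = jnp.eye(n, dtype=bool)
    diff = positions[:, None, :] - positions[None, :, :]
    # Put 1.0 on the diagonal before the sqrt so that neither the value nor the
    # gradient divides by zero; the self terms are masked out below.
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + self_mask)
    pair_terms = jnp.where(self_mask, 0.0, charges[:, None] * charges[None, :] / dist)
    # Each pair appears twice in the double sum, hence the factor 1/2.
    return 0.5 * jnp.sum(pair_terms)


def forces_direct_sketch(positions: ArrayLike, charges: ArrayLike) -> jax.Array:
    """Forces as minus the gradient of the energy w.r.t. positions."""
    positions, charges = jnp.asarray(positions), jnp.asarray(charges)
    return -jax.grad(energy_direct_sketch, argnums=0)(positions, charges)


def chargegrad_direct_sketch(positions: ArrayLike, charges: ArrayLike) -> jax.Array:
    """Gradient of the energy w.r.t. charges (the potential at each particle)."""
    positions, charges = jnp.asarray(positions), jnp.asarray(charges)
    return jax.grad(energy_direct_sketch, argnums=1)(positions, charges)
```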

msmjax.utils.benchmarking.calc_stress_virial(positions, forces, cell)

Calculate the stress tensor from the virial.

Parameters:
  • positions (jax.typing.ArrayLike) – Array of particle positions, shape (n_particles, n_dim)

  • forces (jax.typing.ArrayLike) – Array of forces, shape (n_particles, n_dim)

  • cell (jax.typing.ArrayLike) – Array representing cell, shape (n_dim, n_dim)

Return type:

Array

Returns:

Stress in 6-component format.
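For a non-periodic system, a homogeneous strain ε moves each particle from r_i to (1 + ε) r_i, so ∂E/∂ε_ab = -Σ_i F_ia r_ib, and the stress σ = (1/V) ∂E/∂ε becomes σ_ab = -(1/V) Σ_i F_ia r_ib, with V the cell volume. The sketch below implements this formula for three dimensions. The name is hypothetical, and the sign convention, the symmetrization, and the Voigt ordering (xx, yy, zz, yz, xz, xy, as used by ASE) are assumptions that may differ from msmjax.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def stress_virial_sketch(positions: ArrayLike, forces: ArrayLike, cell: ArrayLike) -> jax.Array:
    """sigma_ab = -(1/V) * sum_i F_ia r_ib in 6-component form (hypothetical)."""
    positions = jnp.asarray(positions)
    forces = jnp.asarray(forces)
    volume = jnp.abs(jnp.linalg.det(jnp.asarray(cell)))
    # sum_i F_ia r_ib, which equals minus the strain derivative of the energy.
    virial = jnp.einsum("ia,ib->ab", forces, positions)
    stress = -virial / volume
    # Symmetrize; for torque-free forces this leaves the tensor unchanged.
    stress = 0.5 * (stress + stress.T)
    # Assumed Voigt ordering: xx, yy, zz, yz, xz, xy.
    return jnp.array(
        [stress[0, 0], stress[1, 1], stress[2, 2], stress[1, 2], stress[0, 2], stress[0, 1]]
    )
```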

msmjax.utils.benchmarking.calc_nonperiodic_reference_results(positions, charges, cell)

Calculate several exact reference results for a non-periodic system.

Parameters:
  • positions (jax.typing.ArrayLike) – Array of particle positions, shape (n_particles, n_dim)

  • charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,)

  • cell (jax.typing.ArrayLike) – Array representing cell, shape (n_dim, n_dim)

Return type:

tuple[Array, Array, Array, Array]

Returns:

Energy, forces, charge gradient, stress tensor