msmjax.utils.general

General utilities

msmjax.utils.general._divide_zero_safe(numerator, denominator)

Divide numerator by denominator elementwise, forcing the result to 0.0 wherever the denominator is 0, in a jit- and autodiff-compatible way.

Parameters:
  • numerator (Array) – Values in the numerator

  • denominator (Array) – Values in the denominator, may contain zeros

Return type:

Array

Returns:

numerator / denominator, with the result set to 0.0 wherever denominator == 0.0
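A common way to obtain a zero-safe division that is also safe to differentiate is the "double where" trick. A single `jnp.where(d == 0, 0.0, n / d)` returns the right values, but the unselected branch `n / d` is still differentiated, and its infinite derivative multiplied by a zero cotangent gives NaN. The sketch below assumes this approach; the name `divide_zero_safe_sketch` is hypothetical and it is not necessarily how msmjax implements `_divide_zero_safe`.

```python
import jax
import jax.numpy as jnp


def divide_zero_safe_sketch(numerator, denominator):
    """Elementwise numerator / denominator, with 0.0 wherever denominator == 0."""
    numerator = jnp.asarray(numerator)
    denominator = jnp.asarray(denominator)
    is_zero = denominator == 0.0
    # Inner where: replace zeros by 1.0 so the division (and its derivative)
    # stays finite in every element, including the ones that get discarded.
    safe_denominator = jnp.where(is_zero, 1.0, denominator)
    # Outer where: select 0.0 for the zero-denominator elements.
    return jnp.where(is_zero, 0.0, numerator / safe_denominator)


# Usage: both the forward value and the gradient are finite at denominator == 0.
values = jax.jit(divide_zero_safe_sketch)(jnp.array([1.0, 2.0]), jnp.array([2.0, 0.0]))
grad_at_zero = jax.grad(lambda d: divide_zero_safe_sketch(1.0, d))(0.0)
```

Because only `jnp.where` and plain arithmetic are used, the function traces cleanly under `jax.jit` and has no data-dependent Python control flow.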

msmjax.utils.general._sqrt_jvp(primals, tangents)

Custom JVP rule of a square-root implementation that avoids NaN derivatives at zero.
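The `(primals, tangents)` signature matches the rule format expected by `jax.custom_jvp`. The derivative of `jnp.sqrt` diverges at 0, so the gradient of a distance such as `sqrt(sum(r**2))` at `r == 0` becomes `inf * 0 = NaN`. The sketch below shows one way to write such a rule; the names `safe_sqrt`, `_safe_sqrt_jvp` and `norm` are hypothetical, and returning a derivative of 0.0 at zero is an assumed convention, not confirmed to be what msmjax does.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def safe_sqrt(x):
    return jnp.sqrt(x)


@safe_sqrt.defjvp
def _safe_sqrt_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    y = jnp.sqrt(x)
    # d sqrt(x) / dx = 1 / (2 sqrt(x)) diverges at x == 0; use 0.0 there
    # (assumed convention). The double where keeps every branch finite.
    is_zero = y == 0.0
    safe_y = jnp.where(is_zero, 1.0, y)
    derivative = jnp.where(is_zero, 0.0, 0.5 / safe_y)
    # The tangent output must be linear in x_dot for reverse mode to work.
    return y, derivative * x_dot


def norm(v):
    """Euclidean norm built on safe_sqrt; its gradient at v == 0 is finite."""
    return safe_sqrt(jnp.sum(v**2))
```

Since the rule is linear in `x_dot`, `jax.grad` can transpose it, so the same rule covers both forward and reverse mode.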

msmjax.utils.general.get_max_cutoff_for_mic(cell)

Get the maximum cutoff value that fits into a given cell.

That is, the largest cutoff radius within which distances computed with the minimum-image convention (MIC) are guaranteed to be correct.

Parameters:

cell (jax.typing.ArrayLike) – Array representing the unit cell, shape (n_dim, n_dim).

Raises:

ValueError – If cell has an invalid spatial dimension.

Return type:

Array

Returns:

Cutoff radius
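A standard result is that the MIC is exact for all distances below half the smallest perpendicular width of the cell, i.e. half the smallest spacing between opposite faces, because a sphere of that radius fits inside the cell centered on the origin. The sketch below computes this via the reciprocal vectors. It assumes that the rows of `cell` are the lattice vectors (the convention is not stated above), uses the hypothetical name `max_mic_cutoff_sketch`, and does not reproduce msmjax's own dimension check.

```python
import jax.numpy as jnp


def max_mic_cutoff_sketch(cell):
    """Half of the smallest perpendicular width of the cell.

    Assumes the rows of ``cell`` are the lattice vectors.
    """
    cell = jnp.asarray(cell)
    if cell.ndim != 2 or cell.shape[0] != cell.shape[1]:
        raise ValueError(f"cell must have shape (n_dim, n_dim), got {cell.shape}")
    # Columns of inv(cell) are reciprocal vectors b_j with a_i . b_j = delta_ij.
    # The spacing between lattice planes normal to b_j is 1 / |b_j|.
    widths = 1.0 / jnp.linalg.norm(jnp.linalg.inv(cell), axis=0)
    return 0.5 * jnp.min(widths)
```

For a cubic cell of side 10 this gives 5.0. For a skewed 2D cell with rows (1, 0) and (1, 1) the perpendicular widths are 1/sqrt(2) and 1, so the cutoff is 1/(2 sqrt(2)), smaller than half of either vector length.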

msmjax.utils.general.find_covering_grid_extents(grid_axes, spacings, cutoff)

Find the number of grid cells, along each grid axis, needed to cover a sphere of a given radius.

Parameters:
  • grid_axes (jax.typing.ArrayLike) – Array of shape (n_dim, n_dim), representing the directions of the grid axes (rows are different grid axes, columns are Cartesian coordinates). Only the directions matter, as the vectors are internally normalized, and the grid spacings are supplied separately via the spacings argument.

  • spacings (jax.typing.ArrayLike) – 1-d array of grid spacings, one per grid axis.

  • cutoff (float) – Cutoff radius.

Return type:

tuple[int, ...]

Returns:

Tuple containing, for each axis direction, the number of grid cells required to contain the cutoff sphere. The cells are counted from the center of the sphere: e.g., in three dimensions, for a sphere of radius 3.5 and a grid spanned by orthogonal axes with a grid spacing of 1.0, the result is (4, 4, 4).
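For non-orthogonal grid axes, the extent along axis i is not simply cutoff / spacing. If the rows of a matrix S are the grid step vectors (unit direction times spacing), a Cartesian point x has grid coordinates c = x S^-1, and the maximum of c_i over a sphere of radius r is r times the norm of column i of S^-1. The sketch below follows this reasoning; the name `covering_grid_extents_sketch` is hypothetical, and rounding up with `math.ceil` is an assumption that reproduces the documented (4, 4, 4) example but may treat radii landing exactly on a cell boundary differently from msmjax.

```python
import math

import jax.numpy as jnp


def covering_grid_extents_sketch(grid_axes, spacings, cutoff):
    """Cells needed along each grid axis, counted from the sphere center."""
    axes = jnp.asarray(grid_axes)
    # Only the directions matter: normalize each row to unit length.
    unit_axes = axes / jnp.linalg.norm(axes, axis=1, keepdims=True)
    # Rows of `steps` are the grid step vectors (unit direction * spacing).
    steps = unit_axes * jnp.asarray(spacings)[:, None]
    # A Cartesian point x has grid coordinates c = x @ inv(steps), so
    # c_i = x . inv(steps)[:, i], whose maximum over |x| <= cutoff is
    # cutoff * |inv(steps)[:, i]| (attained for x parallel to that column).
    reciprocal_norms = jnp.linalg.norm(jnp.linalg.inv(steps), axis=0)
    # Round up to whole cells; the result is a tuple of concrete Python ints.
    return tuple(math.ceil(float(cutoff * n)) for n in reciprocal_norms)
```

Returning Python ints (as the documented `tuple[int, ...]` return type suggests) lets the result be used for static array shapes, at the cost that the function needs concrete inputs and cannot itself be traced under `jax.jit`. For skewed axes (1, 0) and (1, 1) with unit spacing and radius 1.0, each extent is ceil(sqrt(2)) = 2, larger than the orthogonal answer of 1.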