msmjax.utils.general
General utilities
- msmjax.utils.general._divide_zero_safe(numerator, denominator)
Divide numerator by denominator, forcing the result of any division by 0 to be 0.0, in a way that is compatible with jit and autodiff.
- Parameters:
numerator (Array) – Values in the numerator.
denominator (Array) – Values in the denominator; may contain zeros.
- Return type:
Array
- Returns:
numerator / denominator, with result == 0.0 where denominator == 0.0.
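A common way to get this behavior in JAX is the "double `where`" pattern: a plain `jnp.where(denominator == 0, 0.0, numerator / denominator)` returns the right values, but its gradient with respect to the denominator is NaN at zero, because the unused `inf` branch still enters the backward pass. The sketch below is a hypothetical stand-in named `divide_zero_safe`; it illustrates the technique and is not necessarily how msmjax implements `_divide_zero_safe`.

```python
import jax
import jax.numpy as jnp


def divide_zero_safe(numerator, denominator):
    """Hypothetical sketch: numerator / denominator, with 0.0 where denominator == 0."""
    is_zero = denominator == 0.0
    # First replace zero denominators by a harmless non-zero value, so the
    # division itself never produces inf/nan that could leak into gradients.
    safe_denominator = jnp.where(is_zero, 1.0, denominator)
    # Then mask the result to 0.0 at exactly those positions.
    return jnp.where(is_zero, 0.0, numerator / safe_denominator)


values = divide_zero_safe(jnp.array([1.0, 2.0, 3.0]), jnp.array([2.0, 0.0, 4.0]))
# Gradient w.r.t. the denominator stays finite at zero.
grad_at_zero = jax.grad(lambda d: divide_zero_safe(1.0, d))(0.0)
```

Because only `jnp.where` and arithmetic are used, the function can be wrapped in `jax.jit` and differentiated without special handling.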
- msmjax.utils.general._sqrt_jvp(primals, tangents)
JVP rule for a custom square-root implementation that avoids NaN derivatives at zero.
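The signature `(primals, tangents)` matches the rule functions registered with `jax.custom_jvp(...).defjvp`. The sketch below shows one such rule under an assumption the source does not state: the derivative at exactly zero is defined as 0.0 (the true derivative diverges there). `safe_sqrt` and `_safe_sqrt_jvp` are hypothetical names, not the msmjax API.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def safe_sqrt(x):
    """Hypothetical sketch: sqrt whose derivative at x == 0 is taken to be 0.0."""
    return jnp.sqrt(x)


@safe_sqrt.defjvp
def _safe_sqrt_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    y = jnp.sqrt(x)
    is_zero = x == 0.0
    # Use a dummy value of 1.0 where y == 0 so that 0.5 / y is never inf ...
    safe_y = jnp.where(is_zero, 1.0, y)
    # ... and set the derivative coefficient to 0.0 at those points (assumed choice).
    coeff = jnp.where(is_zero, 0.0, 0.5 / safe_y)
    # The tangent output is linear in x_dot, so reverse mode (jax.grad) also works.
    return y, coeff * x_dot


# Typical use case: the gradient of a Euclidean norm at the zero vector.
grad_norm_at_zero = jax.grad(lambda r: safe_sqrt(jnp.sum(r**2)))(jnp.zeros(3))
```

With plain `jnp.sqrt`, the same gradient evaluates `inf * 0` inside the chain rule and yields NaN, which is the failure mode this kind of rule is meant to prevent, e.g. for distances between coinciding points.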
- msmjax.utils.general.get_max_cutoff_for_mic(cell)
Get the maximum cutoff radius that fits into a given cell, i.e., the largest cutoff within which distances computed using the minimum-image convention (MIC) are guaranteed to be correct.
- Parameters:
cell (jax.typing.ArrayLike) – Array representing the unit cell, shape (n_dim, n_dim).
- Raises:
ValueError – If cell has an invalid spatial dimension.
- Return type:
Array
- Returns:
Cutoff radius.
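A standard bound for this quantity is half the smallest perpendicular width of the cell, i.e. half the smallest distance between opposite faces. If the true minimum-image displacement is shorter than that, all of its fractional coordinates lie in (-0.5, 0.5), so rounding the fractional coordinates recovers it. The sketch below computes this bound, assuming the rows of `cell` are the lattice vectors. Both that convention and the exact validity check are assumptions, and `max_mic_cutoff` is a hypothetical name rather than the msmjax implementation.

```python
import jax.numpy as jnp


def max_mic_cutoff(cell):
    """Hypothetical sketch: half the smallest perpendicular width of the cell.

    Assumes the rows of `cell` are the lattice vectors a_i.
    """
    cell = jnp.asarray(cell)
    if cell.ndim != 2 or cell.shape[0] != cell.shape[1]:
        # Assumed check; the documented function only says it raises
        # ValueError for an invalid spatial dimension.
        raise ValueError(f"cell must have shape (n_dim, n_dim), got {cell.shape}")
    # Columns of inv(cell) are the reciprocal vectors b_j with a_i . b_j = delta_ij.
    reciprocal_norms = jnp.linalg.norm(jnp.linalg.inv(cell), axis=0)
    # The distance between the pair of cell faces normal to b_j is 1 / |b_j|.
    widths = 1.0 / reciprocal_norms
    return 0.5 * jnp.min(widths)


cubic_cutoff = max_mic_cutoff(10.0 * jnp.eye(3))
skewed_cutoff = max_mic_cutoff(jnp.array([[1.0, 0.0], [1.0, 1.0]]))
```

For a cubic cell of edge 10 this gives 5.0. For the skewed 2-d cell the narrowest width is 1/sqrt(2), so the bound is about 0.354, noticeably smaller than half of either edge length.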
- msmjax.utils.general.find_covering_grid_extents(grid_axes, spacings, cutoff)
Find the number of grid cells needed, along each grid axis, to cover a sphere of a given radius.
- Parameters:
grid_axes (jax.typing.ArrayLike) – Array of shape (n_dim, n_dim) representing the directions of the grid axes (rows are different grid axes, columns are Cartesian coordinates). Only the directions matter: the vectors are normalized internally, and the grid spacings are supplied separately via the spacings argument.
spacings (jax.typing.ArrayLike) – 1-d array of grid spacings, one along each grid axis.
cutoff (float) – Cutoff radius.
- Return type:
tuple[int, ...]
- Returns:
Tuple with the number of grid cells along each axis direction needed to contain the cutoff sphere, counted from the center of the sphere. For example, for a sphere of radius 3.5 on a three-dimensional grid spanned by orthogonal axes with a grid spacing of 1.0, the result is (4, 4, 4).
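One way to obtain these extents, assumed here rather than taken from msmjax, uses the reciprocal vectors of the grid. Scale the normalized axes by the spacings to get the grid step vectors g_i, so that a point is x = sum_i s_i g_i. Over a sphere of radius r, the largest |s_i| is r times the norm of the i-th column of the inverse step matrix, and rounding that up gives the cell count. `covering_grid_extents` is a hypothetical name for this sketch.

```python
import math

import jax.numpy as jnp


def covering_grid_extents(grid_axes, spacings, cutoff):
    """Hypothetical sketch: grid cells per axis needed to contain a sphere."""
    grid_axes = jnp.asarray(grid_axes)
    spacings = jnp.asarray(spacings)
    # Only directions matter: normalize each axis, then scale by its spacing,
    # so the rows of `grid_vectors` are the steps between neighboring grid points.
    unit_axes = grid_axes / jnp.linalg.norm(grid_axes, axis=1, keepdims=True)
    grid_vectors = unit_axes * spacings[:, None]
    # A Cartesian point x has grid coordinates s = x @ inv(grid_vectors). Over the
    # ball |x| <= cutoff, |s_i| is maximal at cutoff * |column i of the inverse|.
    max_grid_coords = cutoff * jnp.linalg.norm(jnp.linalg.inv(grid_vectors), axis=0)
    # Round up to whole grid cells and return plain Python ints.
    return tuple(math.ceil(float(c)) for c in max_grid_coords)


orthogonal = covering_grid_extents(jnp.eye(3), jnp.ones(3), 3.5)
skewed = covering_grid_extents(jnp.array([[1.0, 0.0], [1.0, 1.0]]), jnp.array([1.0, 2**0.5]), 1.5)
```

The orthogonal case reproduces the documented example (4, 4, 4). For the skewed 2-d grid the extents differ between axes, (3, 2), because the sphere reaches further in grid coordinates along the direction where the axes are sheared. Since the result is a tuple of Python ints, this sketch needs concrete inputs and cannot be traced under `jax.jit`.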