msmjax.bspline.gridops

Generic code for the long-range part, which is evaluated using grids.

References

[1] Hardy, D. J.; Wolff, M. A.; Xia, J.; Schulten, K.; Skeel, R. D. Multilevel Summation with B-Spline Interpolation for Pairwise Interactions in Molecular Dynamics Simulations. J. Chem. Phys. 2016, 144 (11), 114112. https://doi.org/10.1063/1.4943868.

[2] Hardy, D. J. Multilevel Summation for the Fast Evaluation of Forces for the Simulation of Biomolecules (PhD thesis), University of Illinois at Urbana-Champaign, 2006.

msmjax.bspline.gridops._find_n_gridpoints_1d(length, h, p, is_periodic)

Find the number of grid points along one dimension.

Parameters:
  • length (float) – Length of the domain where particles can be located.

  • h (float) – Grid spacing.

  • p (int) – Interpolation order.

  • is_periodic (bool) – Whether there is periodicity along this dimension.

Return type:

int

Returns:

Number of grid points.
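A possible counting rule is sketched below. This is a hypothetical reconstruction, not the msmjax code. It assumes an even interpolation order p, grid point m located at coordinate m*h, and a particle at x touching the p consecutive grid points floor(x/h) - (p/2 - 1), ..., floor(x/h) + p/2. Under these assumptions a nonperiodic grid needs floor(length/h) + p points, because it extends past both ends of the domain. A periodic grid has exactly length/h points, since h must divide the length.

```python
import math

def find_n_gridpoints_1d_sketch(length: float, h: float, p: int, is_periodic: bool) -> int:
    """Hypothetical grid-point count along one axis (even p assumed)."""
    if p % 2 != 0:
        raise ValueError("this sketch only handles even interpolation orders")
    if is_periodic:
        n = round(length / h)
        # Periodicity requires h to divide the side length without remainder.
        if not math.isclose(n * h, length):
            raise ValueError("h must divide length in the periodic case")
        return n
    # Lowest index touched by a particle at x = 0, highest by a particle at x = length.
    lowest = -(p // 2 - 1)
    highest = math.floor(length / h) + p // 2
    return highest - lowest + 1
```

For example, with length 10, h = 1 and p = 4 (cubic), the nonperiodic count is 14 and the periodic count is 10.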

msmjax.bspline.gridops.set_up_grids_all_levels(side_lengths, level_one_spacings, pbc, max_grid_level, p)

Determine the shapes and spacings of the grids at all levels.

Parameters:
  • side_lengths (Sequence[float]) – Side lengths of a cuboid region where particles may be located.

  • level_one_spacings (Sequence[float]) – Level-one grid spacings, one for each direction.

  • pbc (Sequence[bool]) – One boolean per direction signaling periodicity.

  • max_grid_level (int) – The highest grid level.

  • p (int) – Interpolation order.

Returns:

  • A list of tuples indicating grid shapes, with the tuple at index l of the list corresponding to grid level l.

  • A list of arrays indicating grid spacings, with the array at index l of the list corresponding to grid level l.

Return type:

Tuple containing the two lists described above.
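The level-by-level bookkeeping can be sketched as follows, assuming the usual MSM convention that the grid spacing doubles from one level to the next. The per-axis counting rule is passed in as the hypothetical callable n_points_1d (a function like _find_n_gridpoints_1d would fit). Index 0 holds a None placeholder, which is our assumption: level zero is the short-range part and has no grid, and the placeholder keeps list index l aligned with level l.

```python
from typing import Callable, Optional, Sequence
import numpy as np

def set_up_grids_sketch(
    side_lengths: Sequence[float],
    level_one_spacings: Sequence[float],
    pbc: Sequence[bool],
    max_grid_level: int,
    p: int,
    n_points_1d: Callable[[float, float, int, bool], int],
):
    # Hypothetical placeholder at index 0 so that list index l matches level l.
    shapes: list[Optional[tuple[int, ...]]] = [None]
    spacings: list[Optional[np.ndarray]] = [None]
    h = np.asarray(level_one_spacings, dtype=float)
    for _ in range(1, max_grid_level + 1):
        spacings.append(h.copy())
        shapes.append(tuple(
            n_points_1d(length, float(h_i), p, periodic)
            for length, h_i, periodic in zip(side_lengths, h, pbc)
        ))
        h = 2.0 * h  # each coarser level doubles the spacing
    return shapes, spacings
```

With a toy periodic rule round(L / h), side lengths (8, 16) and level-one spacings (1, 2), three levels give shapes (8, 8), (4, 4) and (2, 2).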

msmjax.bspline.gridops._arbitrary_dim_outer(*xi)

Compute the outer product of an arbitrary number of arrays.

Return type:

Array

Parameters:

xi (Array)
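One way to compute such an outer product is to fold over the inputs, giving each partial result trailing singleton axes so that broadcasting forms the product. This is a sketch, not the msmjax implementation; it relies only on reshape and broadcasting, so replacing np with jax.numpy should work unchanged.

```python
import functools
import numpy as np

def arbitrary_dim_outer(*xi):
    """Outer product: result[i, j, ..., k] = xi[0][i] * xi[1][j] * ... * xi[-1][k]."""
    def outer2(a, b):
        # Append b.ndim singleton axes to a, then broadcast against b.
        return a.reshape(a.shape + (1,) * b.ndim) * b
    return functools.reduce(outer2, [np.asarray(x) for x in xi])
```

For three 1-d inputs this matches np.einsum("i,j,k->ijk", ...).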

msmjax.bspline.gridops._multiindex_outer(*inds_individual_axes)

Construct a mesh grid from the index arrays of the individual axes and return it in multi-index form (a tuple of index arrays, one per axis).

Return type:

tuple[Array, ...]

Parameters:

inds_individual_axes (Array)
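A sketch of the likely behavior, assuming that "multi-index" means one flattened index array per axis in C (row-major) order, which is the tuple format accepted by numpy.ravel_multi_index and jax.numpy.ravel_multi_index:

```python
import numpy as np

def multiindex_outer(*inds_individual_axes):
    # "ij" indexing keeps axis 0 slowest-varying, matching C order.
    mesh = np.meshgrid(*inds_individual_axes, indexing="ij")
    return tuple(m.ravel() for m in mesh)
```

For example, multiindex_outer(np.array([0, 1]), np.array([5, 6, 7])) returns (array([0, 0, 0, 1, 1, 1]), array([5, 6, 7, 5, 6, 7])).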

msmjax.bspline.gridops._ravel_multi_index_with_invalidation(multi_index, dims, pbc)

Like jax.numpy.ravel_multi_index(), but with different out-of-bounds handling.

The out-of-bounds handling of this function depends on the boundary conditions, specified via pbc. Along periodic directions, it is identical to calling jax.numpy.ravel_multi_index() with mode="wrap". Along nonperiodic directions, any out-of-bounds value in the input multi-index also yields an out-of-bounds value in the output flat index.

Parameters:
  • multi_index (tuple[Array, ...]) – Like in jax.numpy.ravel_multi_index().

  • dims (Sequence[int]) – Like in jax.numpy.ravel_multi_index().

  • pbc (Sequence[bool]) – One boolean per direction signaling periodicity.

Returns:

Flat indices.
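The described behavior can be sketched in NumPy as below. Using prod(dims) as the out-of-bounds sentinel is our assumption; any flat index outside [0, prod(dims)) satisfies the description. A positive sentinel of this kind could, for instance, be discarded later by a JAX scatter such as x.at[idx].add(vals, mode="drop").

```python
import numpy as np

def ravel_multi_index_with_invalidation_sketch(multi_index, dims, pbc):
    """Periodic axes wrap; an out-of-bounds index on a nonperiodic axis invalidates the entry."""
    size = int(np.prod(dims))
    invalid = np.zeros(np.shape(multi_index[0]), dtype=bool)
    fixed = []
    for idx, n, periodic in zip(multi_index, dims, pbc):
        idx = np.asarray(idx)
        if periodic:
            fixed.append(np.mod(idx, n))  # same as mode="wrap" on this axis
        else:
            invalid |= (idx < 0) | (idx >= n)
            fixed.append(np.clip(idx, 0, n - 1))  # in range; overwritten below if invalid
    flat = np.ravel_multi_index(tuple(fixed), dims)
    # Hypothetical sentinel: prod(dims) lies just past the last valid flat index.
    return np.where(invalid, size, flat)
```

With dims (4, 5) and pbc (True, False), the multi-index ([-1, 2, 3], [0, 5, -1]) maps to [15, 20, 20]: the first entry wraps along the periodic axis, while the other two are invalidated by the nonperiodic axis.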

msmjax.bspline.gridops.make_basis_evaluation_fn(grid_shape, p, pbc)

Make a function that evaluates all contributing basis functions for a single particle.

Parameters:
  • grid_shape (tuple[int, ...]) – Tuple of integers representing shape of the grid on which the basis functions are defined.

  • p (int) – Interpolation order.

  • pbc (Sequence[bool]) – One boolean per direction signaling periodicity.

Return type:

Callable[[Array, Array], tuple[Array, Array]]

Returns:

A function that evaluates the grid-centered basis functions for a single particle. It has two inputs and two outputs:

  • The inputs are two 1-d arrays, both of shape (n_dim,), representing the coordinates of a single particle and the grid spacings along each direction, respectively.

  • The outputs are two 1-d arrays, representing the values of all basis functions that contain the particle in their support, and the flat indices of the corresponding grid points. They are always flat arrays, regardless of the spatial dimension of the system.
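A minimal sketch of such a function for the cubic case is shown below. It assumes the centered cubic B-spline as the basis for p = 4, grid point m located at m*h along each axis (grid origin at 0), and periodicity along every axis; the actual msmjax basis, grid offsets and nonperiodic handling may differ. The values come from a tensor product of 1-d B-spline values, and the indices are flattened with wrapping.

```python
import numpy as np

def cubic_bspline(t):
    """Centered cubic B-spline (order p = 4) with support [-2, 2]."""
    a = np.abs(t)
    return np.where(
        a < 1.0,
        2.0 / 3.0 - a**2 + 0.5 * a**3,
        np.where(a < 2.0, (2.0 - a) ** 3 / 6.0, 0.0),
    )

def make_basis_evaluation_fn_sketch(grid_shape):
    """Hypothetical version for p = 4 and periodicity along every axis."""
    grid_shape = tuple(grid_shape)

    def evaluate(position, spacings):
        inds_per_axis, vals_per_axis = [], []
        for x, h in zip(position, spacings):
            t = x / h
            # The 4 grid points whose basis functions contain the particle.
            inds = np.floor(t).astype(int) + np.arange(-1, 3)
            inds_per_axis.append(inds)
            vals_per_axis.append(cubic_bspline(t - inds))
        # Tensor-product basis: outer product of the 1-d values, then flatten.
        vals = vals_per_axis[0]
        for v in vals_per_axis[1:]:
            vals = vals[..., None] * v
        mesh = np.meshgrid(*inds_per_axis, indexing="ij")
        flat = np.ravel_multi_index(
            tuple(m.ravel() for m in mesh), grid_shape, mode="wrap"
        )
        return vals.ravel(), flat

    return evaluate
```

In 2-d each particle touches 4 x 4 = 16 grid points, and the values sum to 1 because integer shifts of the cubic B-spline form a partition of unity.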

msmjax.bspline.gridops._make_restrict_1d(n_points_in, n_points_out, p, is_periodic)

Make a function that performs restriction along one dimension.

Return type:

Callable[[Array], Array]

Parameters:
  • n_points_in (int)

  • n_points_out (int)

  • p (int)

  • is_periodic (bool)

msmjax.bspline.gridops.make_restriction_operator(grid_shape_in, grid_shape_out, p, pbc)

Make a function that performs restriction.

Parameters:
  • grid_shape_in (Sequence[int]) – Tuple of integers indicating the shape of the input (lower-level, finer) grid.

  • grid_shape_out (Sequence[int]) – Tuple of integers indicating the shape of the target (higher-level, coarser) grid.

  • p (int) – Interpolation order.

  • pbc (Sequence[bool]) – One boolean per direction signaling periodicity.

Return type:

Callable[[Array], Array]

Returns:

A function that takes in an array of shape grid_shape_in and restricts it to a coarser grid, returning an array of shape grid_shape_out.
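Because tensor-product B-spline bases are separable, an n-d restriction can plausibly be assembled from one 1-d operator per axis, such as those produced by _make_restrict_1d. The sketch below shows only this composition step, with arbitrary callables standing in for the 1-d operators; the same composition would apply to prolongation.

```python
import numpy as np

def make_separable_operator(ops_1d):
    """Apply ops_1d[axis] to every 1-d slice along each axis in turn."""
    def apply(grid):
        out = np.asarray(grid)
        for axis, op in enumerate(ops_1d):
            # Each op may change the length along its axis (e.g. fine -> coarse).
            out = np.apply_along_axis(op, axis, out)
        return out
    return apply
```

With a toy 1-d operator that keeps every other point, a (4, 6) grid is reduced to shape (2, 3), equal to grid[::2, ::2].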

msmjax.bspline.gridops._make_prolongate_1d(n_points_in, n_points_out, p, is_periodic)

Make a function that performs prolongation along one dimension.

Return type:

Callable[[Array], Array]

Parameters:
  • n_points_in (int)

  • n_points_out (int)

  • p (int)

  • is_periodic (bool)

msmjax.bspline.gridops.make_prolongation_operator(grid_shape_in, grid_shape_out, p, pbc)

Make a function that performs prolongation.

Parameters:
  • grid_shape_in (Sequence[int]) – Tuple of integers indicating the shape of the input (higher-level, coarser) grid.

  • grid_shape_out (Sequence[int]) – Tuple of integers indicating the shape of the target (lower-level, finer) grid.

  • p (int) – Interpolation order.

  • pbc (Sequence[bool]) – One boolean per direction signaling periodicity.

Return type:

Callable[[Array], Array]

Returns:

A function that takes in an array of shape grid_shape_in and prolongs it to a finer grid, returning an array of shape grid_shape_out.

msmjax.bspline.gridops.create_all_grid_to_grid_ops(grid_shapes, p, pbc, convolution_methods)

Create all functions needed to map from grids to grids.

Parameters:
  • grid_shapes (Sequence[tuple[int, ...]]) – Sequence of tuples of integers indicating the shape of the grids at all levels.

  • p (int) – Interpolation order.

  • pbc (Sequence[bool]) – One boolean per direction signaling periodicity.

  • convolution_methods (Sequence[Optional[Literal['scipy-direct', 'scipy-fft']]]) – Algorithm to use for convolution in the grid potential calculation. One value per grid level (i.e., different algorithms can be used at different levels).

Returns:

  • Sequence of functions for restrictions between all pairs of grids at adjacent levels.

  • Sequence of functions for prolongations between all pairs of grids at adjacent levels.

  • Sequence of functions for potential calculation via convolution with kernel stencils at all levels.

Return type:

Three-element tuple containing the sequences described above.
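How the three returned sequences might be combined is sketched below, following the standard MSM structure: a downward pass that restricts the charges, a potential calculation on every grid level, and an upward pass that prolongates and accumulates. The indexing convention is our assumption: restrictions[i] and prolongations[i] connect the grids at list positions i and i + 1, and potential_fns has one more entry than the other two. The omission of the top level in the periodic case is not reproduced.

```python
from typing import Callable, Sequence
import numpy as np

Op = Callable[[np.ndarray], np.ndarray]

def multilevel_grid_potential(
    q_finest: np.ndarray,
    restrictions: Sequence[Op],
    prolongations: Sequence[Op],
    potential_fns: Sequence[Op],
) -> np.ndarray:
    if not (len(restrictions) == len(prolongations) == len(potential_fns) - 1):
        raise ValueError("expected one more potential fn than restrictions/prolongations")
    # Downward pass: restrict the charges to each coarser grid.
    charges = [q_finest]
    for restrict in restrictions:
        charges.append(restrict(charges[-1]))
    # Grid potential at every level (e.g. convolution with a kernel stencil).
    potentials = [fn(q) for fn, q in zip(potential_fns, charges)]
    # Upward pass: start at the coarsest grid and prolongate toward the finest.
    e = potentials[-1]
    for i in reversed(range(len(prolongations))):
        e = potentials[i] + prolongations[i](e)
    return e
```

The usage check below replaces the real operators with toy ones (subsampling, repetition, identity) just to exercise the data flow.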

msmjax.bspline.gridops.suggest_max_grid_level_nonperiodic(side_lengths, n_particles, level_one_spacings, level_zero_cutoff, p)

Suggest an efficient value for the highest grid level in the nonperiodic case.

The criterion used, as per Ref. [1], is that the number of grid points at the highest level should be less than or equal to the square root of the number of particles. For low particle counts this may be unachievable: because the support of the basis functions requires the grids to extend beyond the boundary of the cell, there is a minimum number of grid points that cannot be reduced further, even by going to arbitrarily high grid levels. Therefore, additional criteria are implemented as a failsafe.

Parameters:
  • side_lengths (jax.typing.ArrayLike) – Side lengths of a cuboid region where particles may be located.

  • n_particles (int) – Number of particles.

  • level_one_spacings (jax.typing.ArrayLike) – Level-one grid spacings, one value per direction.

  • level_zero_cutoff (float) – Level-zero cutoff.

  • p (int) – Interpolation order.

Return type:

int

Returns:

Suggested value for the highest grid level.
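The main criterion can be sketched as a loop that coarsens until the total number of grid points at the highest level is at most the square root of the particle count. The failsafe shown (stop once coarsening no longer reduces the point count) and the max_level_cap argument are hypothetical stand-ins for the unspecified additional criteria. The per-level point count is passed in as a callable, and level_zero_cutoff is not modeled.

```python
import math
from typing import Callable

def suggest_max_level_sketch(
    n_particles: int,
    n_points_at_level: Callable[[int], int],
    max_level_cap: int = 30,
) -> int:
    """n_points_at_level(l) returns the total number of grid points at level l."""
    target = math.sqrt(n_particles)
    level, n_current = 1, n_points_at_level(1)
    while n_current > target and level < max_level_cap:
        n_next = n_points_at_level(level + 1)
        if n_next >= n_current:  # hypothetical failsafe: grid cannot shrink further
            break
        level, n_current = level + 1, n_next
    return level
```

With a toy count that halves per level but never drops below 8 points (mimicking the minimum grid size of a nonperiodic grid), 10 000 particles give level 5, whereas 4 particles stop at level 8, where the floor is reached.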

msmjax.bspline.gridops.find_spacings_and_max_level_periodic(side_lengths, target_level_one_spacings)

Find grid spacings and a number of grid levels compatible with periodicity.

The spacings and the number of levels are not independent; they are adjusted and determined automatically because:

  • In a periodic system, the grid spacings need to divide the side lengths without remainder.

  • In our implementation of the MSM for periodic boundary conditions, based on Ref. [2], the number of grid levels is increased until there is only one point at the highest level; this level is then not actually visited, since it would not contribute due to charge neutrality. Because enforcing this condition can force spacings far from the target ones, we allow the number of grid points along each direction to be not just a power of two, but also three times a power of two.

Parameters:
  • side_lengths – Array of side lengths, one per direction.

  • target_level_one_spacings – Array of target grid spacings at level one.

Returns:

  • Array of level-one spacings that have been adjusted to be compatible with periodic boundary conditions.

  • The highest level in the kernel splitting. This is one level higher than the highest grid level at which calculations are actually performed, because the top level (with either one or three grid points in each direction) is omitted.

Return type:

Tuple containing the two items described above.
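A possible search is sketched below. Requiring every direction to reach one or three points at the same top level means that all directions share the exponent k in n_d = 2**k or 3 * 2**k, and that the highest level is k + 1. The selection rule (per-axis log-ratio error to the target, minimized over its worst axis) is our own hypothetical heuristic, not necessarily what msmjax does.

```python
import numpy as np

def find_spacings_and_max_level_sketch(side_lengths, target_level_one_spacings, max_k=20):
    """Hypothetical heuristic: n_d = 2**k or 3 * 2**k points per axis, shared k."""
    lengths = np.asarray(side_lengths, dtype=float)
    targets = np.asarray(target_level_one_spacings, dtype=float)
    axes = np.arange(len(lengths))
    best = None
    for k in range(max_k + 1):
        # Per axis, pick whichever of 2**k and 3 * 2**k gives a spacing closest
        # (in log ratio) to the target.
        candidates = np.array([2**k, 3 * 2**k], dtype=float)
        spacings = lengths[:, None] / candidates[None, :]
        errors = np.abs(np.log(spacings / targets[:, None]))
        choice = errors.argmin(axis=1)
        cost = errors[axes, choice].max()
        if best is None or cost < best[0]:
            best = (cost, spacings[axes, choice], k)
    _, level_one_spacings, k = best
    # Level l has n_d / 2**(l-1) points, so level k + 1 has 1 or 3 points per axis.
    return level_one_spacings, k + 1
```

For side lengths (10, 30) and target spacings (1, 1), the sketch picks k = 3 with 8 and 24 level-one points, that is, spacings of 1.25 in both directions and a highest level of 4.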