msmjax.core.longrange

Generic code for the long-range part of the interaction, which is evaluated on grids.

References

[1] Hardy, D. J.; Wolff, M. A.; Xia, J.; Schulten, K.; Skeel, R. D. Multilevel Summation with B-Spline Interpolation for Pairwise Interactions in Molecular Dynamics Simulations. J. Chem. Phys. 2016, 144 (11), 114112. https://doi.org/10.1063/1.4943868.

[2] Hardy, D. J. Multilevel Summation for the Fast Evaluation of Forces for the Simulation of Biomolecules (PhD thesis), University of Illinois at Urbana-Champaign, 2006.

msmjax.core.longrange._anterpolate(basis_vals, basis_inds, charges, grid_shape)

Low-level function doing anterpolation (= calculating grid charge).

Parameters:
  • basis_vals (jax.typing.ArrayLike) – Array of shape (n_particles, support_size), where support_size is the number of grid points around one particle with non-zero values of their basis functions. Contains the values of nearby non-zero basis functions for all particles. The first axis runs over particles, the second over grid points.

  • basis_inds (jax.typing.ArrayLike) – Array of the same shape as basis_vals which, for all particles, contains the flat (!) indices of all nearby grid points with non-zero basis function values.

  • charges (jax.typing.ArrayLike) – Array of charges, shape (n_particles,).

  • grid_shape (tuple[int, ...]) – Tuple of integers representing shape of the target grid to which particle charges will be anterpolated.

Returns:

An array of the same shape as the grid_shape parameter that contains the value of the grid charge for each grid point.
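The parameter description above amounts to a scatter-add: particle \(i\) contributes its charge times the basis-function value to each grid point in its support. The following is a minimal sketch of that idea, not the library's implementation; the name anterpolate_sketch and the toy inputs are hypothetical.

```python
import math

import jax.numpy as jnp


def anterpolate_sketch(basis_vals, basis_inds, charges, grid_shape):
    """Hypothetical sketch: spread particle charges onto a grid."""
    # Contribution of particle i to support point s: q_i * phi_s(r_i).
    contributions = charges[:, None] * basis_vals
    flat_grid = jnp.zeros(math.prod(grid_shape), dtype=contributions.dtype)
    # Scatter-add into the flat grid; repeated indices accumulate.
    flat_grid = flat_grid.at[basis_inds.ravel()].add(contributions.ravel())
    return flat_grid.reshape(grid_shape)


# Toy input: two particles, two support points each, on a 2 x 2 grid.
basis_vals = jnp.array([[0.25, 0.75], [0.5, 0.5]])
basis_inds = jnp.array([[0, 1], [1, 2]])  # flat indices into the grid
charges = jnp.array([1.0, -2.0])
grid_charge = anterpolate_sketch(basis_vals, basis_inds, charges, (2, 2))
```

Because the basis values of each toy particle sum to one, the total grid charge equals the total particle charge.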

msmjax.core.longrange._interpolate_energy(gridpotential, basis_vals, basis_inds, charges)

Low-level function calculating long-range energy from grid potential.

Parameters:
  • gridpotential (jax.typing.ArrayLike) – Array of grid potential (\(e^{l+}\) in the language of the reference).

  • basis_vals (jax.typing.ArrayLike) – See _anterpolate().

  • basis_inds (jax.typing.ArrayLike) – See _anterpolate().

  • charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,).

Return type:

Array

Returns:

The scalar electrostatic energy.
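As an illustration, the energy can be pictured as interpolating the grid potential back to each particle and weighting it by the particle charge. The sketch below assumes the usual prefactor of 1/2 for a pairwise energy; that prefactor, the function name, and the toy inputs are assumptions and are not taken from the library.

```python
import jax.numpy as jnp


def interpolate_energy_sketch(gridpotential, basis_vals, basis_inds, charges):
    """Hypothetical sketch: energy from an interpolated grid potential."""
    # Gather the potential at each particle's support points,
    # giving shape (n_particles, support_size).
    e_at_support = gridpotential.ravel()[basis_inds]
    # Interpolated potential at each particle position.
    particle_potential = jnp.sum(basis_vals * e_at_support, axis=1)
    # Assumed prefactor 1/2 (double-counting convention).
    return 0.5 * jnp.sum(charges * particle_potential)


gridpotential = jnp.array([[1.0, 2.0], [3.0, 4.0]])
basis_vals = jnp.array([[0.25, 0.75], [0.5, 0.5]])
basis_inds = jnp.array([[0, 1], [1, 2]])
charges = jnp.array([1.0, -2.0])
energy = interpolate_energy_sketch(gridpotential, basis_vals, basis_inds, charges)
```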

msmjax.core.longrange._interpolate_energy_positions_gradient(gridpotential, basis_grads, basis_inds, charges)

Low-level function calculating positions gradient of long-range energy.

Implements an analytic expression for the derivative, interpolating it explicitly from the grid potential. This is more efficient than the default automatic differentiation of the energy.

Parameters:
  • gridpotential (jax.typing.ArrayLike) – Array of grid potential (\(e^{l+}\) in the language of the reference).

  • basis_grads (jax.typing.ArrayLike) – Similar to basis_vals (see _anterpolate()), but containing the gradients of the basis functions w.r.t. particle positions instead of their values. Shape (n_particles, support_size, n_dim), where n_dim is the spatial dimension of the system.

  • basis_inds (jax.typing.ArrayLike) – See _anterpolate().

  • charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,).

Return type:

Array

Returns:

The gradient of the long-range energy w.r.t. particle positions, which is an array of shape (n_particles, n_dim).
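One way to read the analytic expression is \(\partial U / \partial \mathbf{r}_i = q_i \sum_{\mathbf{m}} \nabla \varphi_{\mathbf{m}}(\mathbf{r}_i) \, e_{\mathbf{m}}\). This form assumes an energy with prefactor 1/2 and a symmetric linear map from grid charge to grid potential, neither of which is stated on this page. A minimal sketch with a hypothetical function name and toy 1-d inputs:

```python
import jax.numpy as jnp


def positions_gradient_sketch(gridpotential, basis_grads, basis_inds, charges):
    """Hypothetical sketch: dU/dr_i = q_i * sum_m grad(phi_m)(r_i) * e_m."""
    # Potential at each particle's support points: (n_particles, support_size).
    e_at_support = gridpotential.ravel()[basis_inds]
    # Contract over the support axis; the result has shape (n_particles, n_dim).
    grad_potential = jnp.einsum("psd,ps->pd", basis_grads, e_at_support)
    return charges[:, None] * grad_potential


# Toy 1-d input (n_dim = 1): two particles, two support points each.
gridpotential = jnp.array([1.0, 2.0, 3.0, 4.0])
basis_grads = jnp.array([[[-1.0], [1.0]], [[-1.0], [1.0]]])
basis_inds = jnp.array([[0, 1], [1, 2]])
charges = jnp.array([1.0, -2.0])
position_grad = positions_gradient_sketch(
    gridpotential, basis_grads, basis_inds, charges
)
```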

msmjax.core.longrange._interpolate_energy_charge_gradient(gridpotential, basis_vals, basis_inds)

Low-level function calculating charge gradient of long-range energy.

Implements an analytic expression for the derivative, interpolating it explicitly from the grid potential. This is more efficient than the default automatic differentiation of the energy.

Parameters:
  • gridpotential (jax.typing.ArrayLike) – Array of grid potential (\(e^{l+}\) in the language of the reference).

  • basis_vals (jax.typing.ArrayLike) – See _anterpolate().

  • basis_inds (jax.typing.ArrayLike) – See _anterpolate().

Return type:

Array

Returns:

The gradient of the long-range energy w.r.t. particle charges, which is an array of shape (n_particles,).
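To see why such an analytic rule can replace automatic differentiation, consider a toy model in which the whole grid pass is a symmetric matrix K and the energy carries a prefactor of 1/2. Both are assumptions made for this sketch, as are all names below. In that setting the charge gradient is simply the grid potential interpolated to each particle:

```python
import jax
import jax.numpy as jnp

basis_vals = jnp.array([[0.25, 0.75], [0.5, 0.5]])
basis_inds = jnp.array([[0, 1], [1, 2]])
# Hypothetical symmetric matrix standing in for the full grid pass.
K = jnp.array([[2.0, 1.0, 0.5], [1.0, 2.0, 1.0], [0.5, 1.0, 2.0]])


def grid_charge(charges):
    # Anterpolation onto a three-point grid (scatter-add).
    return jnp.zeros(3).at[basis_inds].add(charges[:, None] * basis_vals)


def energy(charges):
    q = grid_charge(charges)
    return 0.5 * jnp.sum(q * (K @ q))


def charge_gradient_sketch(gridpotential, basis_vals, basis_inds):
    """Hypothetical sketch: dU/dq_i = sum_m phi_m(r_i) * e_m."""
    return jnp.sum(basis_vals * gridpotential.ravel()[basis_inds], axis=1)


charges = jnp.array([1.0, -2.0])
gridpotential = K @ grid_charge(charges)
analytic = charge_gradient_sketch(gridpotential, basis_vals, basis_inds)
autodiff = jax.grad(energy)(charges)  # reference value for comparison
```

In this toy model the analytic result needs only one gather and one reduction once the grid potential is known, whereas jax.grad differentiates back through the entire energy computation.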

msmjax.core.longrange.special_periodic_convolve_scipy(data, kernel, pbc, method)

Perform a specialized case of convolution with optional wrapping.

Implemented as a wrapper around jax.scipy.signal.convolve() with, depending on periodicity, appropriate padding of the input arrays:

  • If no direction is periodic, this function is equivalent to calling jax.scipy.signal.convolve() with mode='same'.

  • Along any periodic direction, the data array is first periodically replicated as much as needed for the kernel array to not extend beyond the edges. Then, the convolution is performed with jax.scipy.signal.convolve(), before trimming the result back to the original size of data.

Parameters:
  • data (jax.typing.ArrayLike) – First input. Represents some data on a real-space grid. If a direction is periodic, it corresponds to the values contained within the unit cell along that direction.

  • kernel (jax.typing.ArrayLike) – Second input. Should have the same number of dimensions as data. Represents a finite-size kernel or filter. In the original intended use case, always has an odd number of points along each dimension (i.e., can be centered w.r.t. the points of data).

  • pbc (Sequence[bool]) – One boolean per direction signaling periodicity.

  • method (Literal['direct', 'fft']) – String indicating the method to use for calculating the convolution. Either 'direct' or 'fft'. Passed on to jax.scipy.signal.convolve(). 'fft' is usually much faster.

Return type:

Array

Returns:

An array of the same shape as data containing the convolution of the two arrays.
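The padding strategy described above can be approximated as follows. This is a sketch of an equivalent variant, not the library's code: it assumes an odd kernel size along every dimension, pads periodic directions by wrapping and non-periodic ones with zeros, and then uses mode='valid' so that the output has the shape of data. The name periodic_convolve_sketch is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve


def periodic_convolve_sketch(data, kernel, pbc, method="direct"):
    """Hypothetical sketch: 'same'-sized convolution with optional wrapping."""
    # Half-widths of the kernel (assumed odd-sized along every dimension).
    half = [(k - 1) // 2 for k in kernel.shape]
    wrap_pad = [(h, h) if p else (0, 0) for h, p in zip(half, pbc)]
    zero_pad = [(0, 0) if p else (h, h) for h, p in zip(half, pbc)]
    # Periodic directions: replicate the unit cell around the data.
    padded = jnp.pad(data, wrap_pad, mode="wrap")
    # Non-periodic directions: zero padding, as mode='same' would imply.
    padded = jnp.pad(padded, zero_pad, mode="constant")
    # 'valid' trims the result back to the original shape of data.
    return convolve(padded, kernel, mode="valid", method=method)


data = jnp.array([1.0, 2.0, 3.0, 4.0])
kernel = jnp.array([1.0, 0.0, 0.0])  # shifts the data by one grid point
periodic = periodic_convolve_sketch(data, kernel, pbc=(True,))
open_boundary = periodic_convolve_sketch(data, kernel, pbc=(False,))
reference = convolve(data, kernel, mode="same", method="direct")
```

With the shift kernel, the periodic result wraps the first value around to the end, while the non-periodic result matches mode='same' and fills in a zero instead.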

msmjax.core.longrange.make_grid_pass_fn(restriction_fns, prolongation_fns, interaction_fns)

Create a function that makes a pass through all grid levels.

In other words, create the linear operator (consisting of restriction of the grid charge to higher grid levels, calculation of potentials at all levels, and prolongation of the higher-level potentials down to lower levels) that connects the grid charge at level one to the accumulated grid potential at level one. This corresponds to the upper part of the V-cycle diagram commonly used to visualize the MSM.

Note

The operator sequences passed as parameters to this function need to respect the grid-level indexing convention. As per convention, the index 0 refers to the particle level (where interactions are computed directly without grids), whereas the lowest actual grid level is at index 1. Further, for the restriction and prolongation functions, which connect two different grid levels, the convention is that the function whose output lives on grid level \(l\) is located at index \(l\) of the sequence. For example, the operator \(\mathcal{I}^2_1\) that restricts the grid charge from level 1 to 2 would be addressed as restriction_fns[2].

Thus, the input for, e.g., four grid levels should look like this, employing placeholders where needed:

restriction_fns = [None, None, I_21, I_32, I_43]
prolongation_fns = [None, I_12, I_23, I_34, None]
interaction_fns = [None, K_1, K_2, K_3, K_4]
Parameters:
  • restriction_fns (Sequence[Callable[[jax.typing.ArrayLike], Array]]) – Sequence of restriction functions, one for each level, including placeholders (see above note). Corresponding to the upward arrows on the left side of the V-cycle diagram. Input is array of grid charge, output is array of grid charge one level higher.

  • prolongation_fns (Sequence[Callable[[jax.typing.ArrayLike], Array]]) – Sequence of prolongation functions, one for each level, including placeholders (see above note). Corresponding to the downward arrows on the right side of the V-cycle diagram. Input is array of grid potential, output is array of potential prolongated to the next lower level.

  • interaction_fns (Sequence[Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike], Array]]) – Sequence of interaction functions, one for each level, including placeholders (see above note). Corresponding to the horizontal arrows in the V-cycle diagram. Inputs are two arrays, grid charge and a kernel coefficient stencil, output is the grid potential on the same level. Mathematically: \(e^{l}_{\mathbf{m}} = \sum_{\mathbf{n}} K^l_{\mathbf{m} - \mathbf{n}} \tilde{q}^l_{\mathbf{n}}\).

Return type:

Callable[[jax.typing.ArrayLike, Sequence[Optional[jax.typing.ArrayLike]]], Array]

Returns:

A function for performing the pass through all grid levels. It calculates the accumulated grid potential \(e^{1+}\) at level one and takes two arguments:

  • The level-one grid charge \(\tilde{q}^1\).

  • A sequence of coefficient stencil arrays for the interaction kernels, one per grid level (including a placeholder at level zero, see note above on grid-level indexing convention). These are passed to the interaction_fns that were supplied to construct the grid pass function. The \(l\)-th stencil is consumed by the \(l\)-th element of interaction_fns.
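The indexing convention and the up-across-down structure can be made concrete with a small sketch. It is a plausible reading of the description above rather than the library's implementation; grid_pass_sketch and the toy operators are hypothetical, and a scalar stands in for each kernel stencil.

```python
import jax.numpy as jnp


def grid_pass_sketch(restriction_fns, prolongation_fns, interaction_fns):
    """Hypothetical sketch of a pass through all grid levels."""
    n_levels = len(interaction_fns) - 1  # index 0 is the particle level

    def grid_pass(q_lvl_one, stencils):
        # Upward leg: restrict the grid charge level by level.
        charges = [None, q_lvl_one]
        for l in range(2, n_levels + 1):
            charges.append(restriction_fns[l](charges[l - 1]))
        # Top level: interaction only.
        e = interaction_fns[n_levels](charges[n_levels], stencils[n_levels])
        # Downward leg: prolongate and add each level's own potential.
        for l in range(n_levels - 1, 0, -1):
            e = interaction_fns[l](charges[l], stencils[l]) + prolongation_fns[l](e)
        return e

    return grid_pass


# Toy 1-d operators for two grid levels (all hypothetical).
def restrict(q):  # sum neighbouring pairs, halving the grid size
    return q.reshape(-1, 2).sum(axis=1)


def prolong(e):  # repeat each value, doubling the grid size
    return jnp.repeat(e, 2)


def interact(q, stencil):  # scalar "stencil" in place of a convolution
    return stencil * q


grid_pass = grid_pass_sketch(
    [None, None, restrict], [None, prolong, None], [None, interact, interact]
)
e_one_plus = grid_pass(jnp.array([1.0, 2.0, 3.0, 4.0]), [None, 1.0, 2.0])
```

The placeholder entries (None) are never called, which is why the sequences can be indexed directly by grid level.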

msmjax.core.longrange.make_compute_u_oneplus(singleparticle_basis_fn_lvl_one, grid_pass_fn, grid_shape_lvl_one, transform_mode=None, kernel_stencils=None, kernel_stencil_construction_fn=None, use_custom_derivatives=True)

Create a function that computes the MSM long-range energy contribution.

The quantity being (approximately) calculated is called \(U^{1+}\) in reference [1].

Parameters:
  • singleparticle_basis_fn_lvl_one (Callable[[jax.typing.ArrayLike], tuple[Array, Array]]) –

    A function that, for one particle,

    1. identifies all points on the level-one grid that are sufficiently close for the particle’s position to be contained within the support of the associated basis functions, i.e., finds the set of grid points \(M = \{ \mathbf{m} : \varphi^{1}_{\mathbf{m}}(\mathbf{r}_i) \neq 0 \}\) (where \(\mathbf{r}_i\) denotes the position of particle \(i\) and \(\varphi^{1}_{\mathbf{m}}\) is the basis function centered on point \(\mathbf{m}\) of the level-one grid),

    2. evaluates the corresponding basis functions, i.e. computes \(\varphi^{1}_{\mathbf{m}}(\mathbf{r}_i)\) for all grid points \(\mathbf{m} \in M \,\).

    Inputs and outputs:

    • Input to singleparticle_basis_fn_lvl_one should be a 1-d array, shape (n_dim,), representing the coordinates of a single particle.

    • Output of singleparticle_basis_fn_lvl_one should be a tuple of two 1-d arrays, each of shape (support_size,), where support_size designates the number of non-zero basis functions around one particle (= the cardinality of \(M\) from above). The first of the two arrays contains the values of the basis functions at all grid points \(\mathbf{m} \in M\). The second array contains the corresponding set of grid point indices \(M\) as flat (!) indices into the grid.

    Note

    Regardless of the spatial dimension of the system, singleparticle_basis_fn_lvl_one should always return flat arrays.

    Warning

    If singleparticle_basis_fn_lvl_one returns indices that are out of bounds w.r.t. the grid size defined by grid_shape_lvl_one, this will result in NaNs. This is intentional because errors like particles moving outside the grid boundaries might otherwise go unnoticed.

  • grid_pass_fn (Callable[[jax.typing.ArrayLike, Sequence[Optional[jax.typing.ArrayLike]]], Array]) –

    A function that performs the entire process of moving up, across, and back down the grid hierarchy. For more details on the expected signature, see make_grid_pass_fn(), which can be used conveniently to create such a function.

    Inputs and outputs:

    • Input to grid_pass_fn should be the level-one grid charge \(\tilde{q}^1\) (an array whose shape matches the grid_shape_lvl_one parameter), and a sequence of coefficient stencils for the interaction kernels (one per grid level, including a placeholder at level zero).

    • Output of grid_pass_fn should be the level-one grid potential, also of shape grid_shape_lvl_one.

  • grid_shape_lvl_one (tuple[int, ...]) – Tuple of integers representing shape of target grid at level one, to which particle charges will be anterpolated.

  • use_custom_derivatives (bool) – Whether the returned energy function should use custom (more efficient) differentiation rules for its derivatives w.r.t. positions and charges.

  • transform_mode (Literal['ortho', 'triclinic'] | None)

  • kernel_stencils (Sequence[Optional[jax.typing.ArrayLike]])

  • kernel_stencil_construction_fn (Callable[[jax.typing.ArrayLike], Sequence[Optional[jax.typing.ArrayLike]]])

Return type:

Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike, Sequence[Optional[jax.typing.ArrayLike]]], Array]

Returns:

A function that computes the scalar energy \(U^{1+}\) and takes three arguments:

  • Array of positions, shape (n_particles, n_dim).

  • Array of charges, shape (n_particles,).

  • A sequence of coefficient stencils for the interaction kernels (one per grid level, including a placeholder at level zero). Passed to grid_pass_fn. See make_grid_pass_fn() for more details.
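To make the contract for singleparticle_basis_fn_lvl_one more tangible, here is a minimal sketch of a function with the required signature, using linear hat functions on a 1-d grid. The basis choice, the names hat_basis_1d, H and N_GRID, and the use of a NaN-filling gather to surface out-of-bounds indices are all assumptions made for illustration; the library may realize the documented NaN behavior differently.

```python
import jax
import jax.numpy as jnp

H = 1.0      # hypothetical grid spacing
N_GRID = 8   # hypothetical number of level-one grid points


def hat_basis_1d(position):
    """Hypothetical single-particle basis function (linear hats, 1-d).

    Takes coordinates of shape (n_dim,) = (1,) and returns two flat
    arrays of shape (support_size,) = (2,): values and flat indices.
    """
    x = position[0] / H
    left = jnp.floor(x).astype(jnp.int32)
    frac = x - left
    vals = jnp.stack([1.0 - frac, frac])
    inds = jnp.stack([left, left + 1])
    return vals, inds


# Map the single-particle function over all particles.
positions = jnp.array([[1.25], [4.5], [7.5]])  # last particle leaves the grid
basis_vals, basis_inds = jax.vmap(hat_basis_1d)(positions)

# A gather that fills out-of-bounds entries with NaN instead of clamping.
gridpotential = jnp.ones(N_GRID)
e_at_support = gridpotential.at[basis_inds].get(mode="fill", fill_value=jnp.nan)
particle_potential = jnp.sum(basis_vals * e_at_support, axis=1)
```

The third particle touches index 8 on an 8-point grid, so its interpolated potential becomes NaN while the other two remain finite, which is the kind of loud failure the warning above describes.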