msmjax.core.longrange
Generic code for the long-range part of the MSM, which is evaluated on grids.
References
[1] Hardy, D. J.; Wolff, M. A.; Xia, J.; Schulten, K.; Skeel, R. D. Multilevel Summation with B-Spline Interpolation for Pairwise Interactions in Molecular Dynamics Simulations. J. Chem. Phys. 2016, 144 (11), 114112. https://doi.org/10.1063/1.4943868.
[2] Hardy, D. J. Multilevel Summation for the Fast Evaluation of Forces for the Simulation of Biomolecules (PhD thesis), University of Illinois at Urbana-Champaign, 2006.
- msmjax.core.longrange._anterpolate(basis_vals, basis_inds, charges, grid_shape)
Low-level function performing anterpolation, i.e., computing the grid charge from the particle charges.
- Parameters:
basis_vals (jax.typing.ArrayLike) – Array of shape (n_particles, support_size), where support_size is the number of grid points around one particle at which the basis functions are non-zero. Contains, for all particles, the values of the nearby non-zero basis functions. The first axis runs over particles, the second over grid points.
basis_inds (jax.typing.ArrayLike) – Array of the same shape as basis_vals which, for all particles, contains the flat (!) indices of all nearby grid points with non-zero basis function values.
charges (jax.typing.ArrayLike) – Array of charges, shape (n_particles,).
grid_shape (tuple[int, ...]) – Tuple of integers representing the shape of the target grid onto which the particle charges will be anterpolated.
- Returns:
An array of shape grid_shape that contains the value of the grid charge at each grid point.
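As an illustration of what anterpolation computes, here is a minimal JAX sketch (not the msmjax implementation): each particle \(i\) contributes \(q_i \varphi^1_{\mathbf{m}}(\mathbf{r}_i)\) to every grid point \(\mathbf{m}\) in its support, and the contributions are summed with a scatter-add on the flattened grid. The function name anterpolate_sketch is hypothetical, and the sketch does no explicit bounds checking on basis_inds.

```python
import math

import jax.numpy as jnp


def anterpolate_sketch(basis_vals, basis_inds, charges, grid_shape):
    """Hypothetical sketch: q~_m = sum_i q_i * phi_m(r_i) on a grid of shape grid_shape."""
    basis_vals = jnp.asarray(basis_vals)
    basis_inds = jnp.asarray(basis_inds)
    charges = jnp.asarray(charges)
    # Weighted contributions q_i * phi_m(r_i), shape (n_particles, support_size).
    contributions = basis_vals * charges[:, None]
    # Scatter-add all contributions into the flattened grid; repeated indices accumulate.
    flat_grid = jnp.zeros(math.prod(grid_shape), dtype=contributions.dtype)
    flat_grid = flat_grid.at[basis_inds.ravel()].add(contributions.ravel())
    # basis_inds holds flat indices, so the grid shape is restored at the end.
    return flat_grid.reshape(grid_shape)


grid_charge = anterpolate_sketch(
    basis_vals=[[0.5, 0.5], [0.25, 0.75]],
    basis_inds=[[0, 1], [1, 5]],
    charges=[2.0, -1.0],
    grid_shape=(2, 3),
)
```

With these toy inputs, both particles touch flat index 1, so their contributions (1.0 and -0.25) are summed there, giving the grid charge [[1.0, 0.75, 0.0], [0.0, 0.0, -0.75]].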
- msmjax.core.longrange._interpolate_energy(gridpotential, basis_vals, basis_inds, charges)
Low-level function calculating long-range energy from grid potential.
- Parameters:
gridpotential (jax.typing.ArrayLike) – Array of grid potential (\(e^{l+}\) in the notation of reference [1]).
basis_vals (jax.typing.ArrayLike) – See _anterpolate().
basis_inds (jax.typing.ArrayLike) – See _anterpolate().
charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,).
- Return type:
Array
- Returns:
The scalar electrostatic energy.
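A minimal sketch of the energy interpolation, assuming the conventional pair-energy factor of one half, i.e. \(U^{1+} = \frac{1}{2} \sum_i q_i \sum_{\mathbf{m}} \varphi^1_{\mathbf{m}}(\mathbf{r}_i)\, e_{\mathbf{m}}\), where \(e_{\mathbf{m}}\) is the grid potential at point \(\mathbf{m}\). Whether msmjax uses exactly this normalization is an assumption here, and interpolate_energy_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def interpolate_energy_sketch(gridpotential, basis_vals, basis_inds, charges):
    """Hypothetical sketch: U = 1/2 * sum_i q_i * sum_m phi_m(r_i) * e_m."""
    flat_potential = jnp.ravel(jnp.asarray(gridpotential))
    # Gather e_m at each particle's nearby grid points, shape (n_particles, support_size).
    local_potential = flat_potential[jnp.asarray(basis_inds)]
    # Interpolated potential at each particle position, shape (n_particles,).
    potential_at_particles = jnp.sum(jnp.asarray(basis_vals) * local_potential, axis=1)
    # Assumed convention: the factor 1/2 avoids double counting pair interactions.
    return 0.5 * jnp.dot(jnp.asarray(charges), potential_at_particles)


energy = interpolate_energy_sketch(
    gridpotential=jnp.arange(6.0).reshape(2, 3),
    basis_vals=[[0.5, 0.5], [0.25, 0.75]],
    basis_inds=[[0, 1], [1, 5]],
    charges=[2.0, -1.0],
)
```

Here the interpolated potentials at the two particles are 0.5 and 4.0, so the energy is 0.5 * (2.0 * 0.5 - 1.0 * 4.0) = -1.5. This equals one half of the dot product of the anterpolated grid charge with the grid potential, which is why interpolation can be viewed as the transpose of anterpolation.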
- msmjax.core.longrange._interpolate_energy_positions_gradient(gridpotential, basis_grads, basis_inds, charges)
Low-level function calculating positions gradient of long-range energy.
Evaluates an analytic expression for the derivative by interpolating it explicitly from the grid potential. This is more efficient than applying default automatic differentiation to the energy.
- Parameters:
gridpotential (jax.typing.ArrayLike) – Array of grid potential (\(e^{l+}\) in the notation of reference [1]).
basis_grads (jax.typing.ArrayLike) – Similar to basis_vals (see _anterpolate()), but containing the gradients of the basis functions w.r.t. particle positions instead of their values. Shape (n_particles, support_size, n_dim), where n_dim is the spatial dimension of the system.
basis_inds (jax.typing.ArrayLike) – See _anterpolate().
charges (jax.typing.ArrayLike) – Array of particle charges, shape (n_particles,).
- Return type:
Array
- Returns:
The gradient of the long-range energy w.r.t. particle positions, which is an array of shape (n_particles, n_dim).
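To see why the gradient can be interpolated directly, assume (as in the multilevel construction, where prolongation is the transpose of restriction) that the whole grid pass acts as a symmetric linear operator \(A\), so that \(U^{1+} = \frac{1}{2}\, \tilde{q}^1 \cdot (A \tilde{q}^1)\) with \(e = A \tilde{q}^1\). Then \(\partial U / \partial \mathbf{r}_i = q_i \sum_{\mathbf{m}} \nabla\varphi^1_{\mathbf{m}}(\mathbf{r}_i)\, e_{\mathbf{m}}\). The sketch below implements that expression and compares it with jax.grad on a toy 1-d system; the linear "hat" basis, the matrix A, and all function names are hypothetical stand-ins, not msmjax code.

```python
import jax
import jax.numpy as jnp


def interpolate_positions_gradient_sketch(gridpotential, basis_grads, basis_inds, charges):
    """Hypothetical sketch: dU/dr_i = q_i * sum_m grad(phi_m)(r_i) * e_m."""
    flat_potential = jnp.ravel(jnp.asarray(gridpotential))
    # e_m at each particle's nearby grid points, shape (n_particles, support_size).
    local_potential = flat_potential[jnp.asarray(basis_inds)]
    # Contract over the support axis, shape (n_particles, n_dim).
    grad_interp = jnp.einsum("isd,is->id", jnp.asarray(basis_grads), local_potential)
    return jnp.asarray(charges)[:, None] * grad_interp


# Toy 1-d check (assumptions): linear "hat" basis functions on a unit-spaced grid
# and a symmetric matrix A standing in for the whole grid pass.
N_GRID = 8
_idx = jnp.arange(N_GRID, dtype=jnp.float32)
A = jnp.exp(-jnp.abs(_idx[:, None] - _idx[None, :]))


def hat_basis(x):
    # Values, gradients (support_size, n_dim) and flat indices for one particle.
    j = jnp.floor(x[0]).astype(jnp.int32)
    t = x[0] - j
    return jnp.array([1.0 - t, t]), jnp.array([[-1.0], [1.0]]), jnp.array([j, j + 1])


def grid_charge(positions, charges):
    vals, _, inds = jax.vmap(hat_basis)(positions)
    return jnp.zeros(N_GRID).at[inds.ravel()].add((vals * charges[:, None]).ravel())


def energy(positions, charges):
    qtilde = grid_charge(positions, charges)
    return 0.5 * qtilde @ (A @ qtilde)


positions = jnp.array([[1.3], [4.6], [2.1]])
charges = jnp.array([1.0, -2.0, 0.5])
_, grads, inds = jax.vmap(hat_basis)(positions)
analytic = interpolate_positions_gradient_sketch(
    A @ grid_charge(positions, charges), grads, inds, charges
)
autodiff = jax.grad(energy)(positions, charges)
```

The analytic version only needs one gather per particle from an already computed grid potential, whereas reverse-mode differentiation of the energy would also have to propagate back through the grid pass.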
- msmjax.core.longrange._interpolate_energy_charge_gradient(gridpotential, basis_vals, basis_inds)
Low-level function calculating charge gradient of long-range energy.
Evaluates an analytic expression for the derivative by interpolating it explicitly from the grid potential. This is more efficient than applying default automatic differentiation to the energy.
- Parameters:
gridpotential (jax.typing.ArrayLike) – Array of grid potential (\(e^{l+}\) in the notation of reference [1]).
basis_vals (jax.typing.ArrayLike) – See _anterpolate().
basis_inds (jax.typing.ArrayLike) – See _anterpolate().
- Return type:
Array
- Returns:
The gradient of the long-range energy w.r.t. particle charges, which is an array of shape (n_particles,).
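Under the same assumption of a symmetric grid operator \(A\) with \(U^{1+} = \frac{1}{2}\, \tilde{q}^1 \cdot (A \tilde{q}^1)\), the charge derivative is simply the interpolated potential at each particle, \(\partial U / \partial q_i = \sum_{\mathbf{m}} \varphi^1_{\mathbf{m}}(\mathbf{r}_i)\, e_{\mathbf{m}}\) (the factor one half cancels against the quadratic dependence on the charges). A hedged sketch with a jax.grad cross-check on toy data; the names and the matrix A are hypothetical.

```python
import jax
import jax.numpy as jnp


def interpolate_charge_gradient_sketch(gridpotential, basis_vals, basis_inds):
    """Hypothetical sketch: dU/dq_i = sum_m phi_m(r_i) * e_m, the potential at particle i."""
    flat_potential = jnp.ravel(jnp.asarray(gridpotential))
    # e_m at each particle's nearby grid points, shape (n_particles, support_size).
    local_potential = flat_potential[jnp.asarray(basis_inds)]
    return jnp.sum(jnp.asarray(basis_vals) * local_potential, axis=1)


# Toy check (assumptions): fixed basis data on a flat grid of 6 points and a
# symmetric matrix A standing in for the whole grid pass.
_idx = jnp.arange(6.0)
A = jnp.exp(-jnp.abs(_idx[:, None] - _idx[None, :]))
basis_vals = jnp.array([[0.5, 0.5], [0.25, 0.75]])
basis_inds = jnp.array([[0, 1], [1, 5]])


def grid_charge(charges):
    return jnp.zeros(6).at[basis_inds.ravel()].add((basis_vals * charges[:, None]).ravel())


def energy(charges):
    qtilde = grid_charge(charges)
    return 0.5 * qtilde @ (A @ qtilde)


charges = jnp.array([2.0, -1.0])
analytic = interpolate_charge_gradient_sketch(A @ grid_charge(charges), basis_vals, basis_inds)
autodiff = jax.grad(energy)(charges)
```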
- msmjax.core.longrange.special_periodic_convolve_scipy(data, kernel, pbc, method)
Perform a special case of convolution with optional periodic wrapping.
Implemented as a wrapper around jax.scipy.signal.convolve(), with the input arrays padded appropriately depending on periodicity:
If no direction is periodic, this function is equivalent to calling jax.scipy.signal.convolve() with mode='same'.
Along any periodic direction, the data array is first periodically replicated as often as needed for the kernel array not to extend beyond its edges. The convolution is then performed with jax.scipy.signal.convolve(), and the result is trimmed back to the original size of data.
- Parameters:
data (jax.typing.ArrayLike) – First input. Represents some data on a real-space grid. Along a periodic direction, it holds the values contained within the unit cell along that direction.
kernel (jax.typing.ArrayLike) – Second input. Should have the same number of dimensions as data. Represents a finite-size kernel or filter. In the original intended use case, it always has an odd number of points along each dimension (i.e., it can be centered w.r.t. the points of data).
pbc (Sequence[bool]) – One boolean per direction signaling periodicity.
method (Literal['direct', 'fft']) – String indicating the method used to calculate the convolution, either 'direct' or 'fft'. Passed on to jax.scipy.signal.convolve(). 'fft' is usually much faster.
- Return type:
Array
- Returns:
An array of the same shape as data containing the convolution of the two arrays.
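The padding strategy described above can be sketched as follows, assuming an odd kernel size along every axis: periodic axes are wrap-padded by half a kernel width, non-periodic axes are zero-padded by the same amount, and a 'valid' convolution then returns exactly the original size of data. This is an illustrative re-implementation (periodic_convolve_sketch is a hypothetical name), not the msmjax source, and it does not replicate the data more than jnp.pad does for the given pad widths.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve


def periodic_convolve_sketch(data, kernel, pbc, method="fft"):
    """Hypothetical sketch: 'same'-size convolution that wraps along periodic axes.

    Assumes an odd kernel size along every axis, so it can be centered on data.
    """
    data = jnp.asarray(data)
    kernel = jnp.asarray(kernel)
    for axis, (size, periodic) in enumerate(zip(kernel.shape, pbc)):
        half = size // 2
        widths = [(0, 0)] * data.ndim
        widths[axis] = (half, half)
        # Periodic axis: wrap around the unit cell. Non-periodic axis: zeros,
        # which reproduces mode='same' for an odd-sized kernel.
        data = jnp.pad(data, widths, mode="wrap" if periodic else "constant")
    # 'valid' drops the padded border, leaving the original shape of data.
    return convolve(data, kernel, mode="valid", method=method)


periodic = periodic_convolve_sketch(
    jnp.arange(5.0), jnp.array([1.0, 2.0, 3.0]), pbc=[True], method="direct"
)
```

For data [0, 1, 2, 3, 4] and kernel [1, 2, 3] with pbc=[True], this gives [13, 4, 10, 16, 17]. The non-periodic result (equal to mode='same') starts with 1 instead of 13, because the last data value no longer wraps around to the left of the first point.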
- msmjax.core.longrange.make_grid_pass_fn(restriction_fns, prolongation_fns, interaction_fns)
Create a function that makes a pass through all grid levels.
In other words, create the linear operator (consisting of restriction of the grid charge to higher grid levels, calculation of the potentials at all levels, and prolongation of the higher-level potentials down to lower levels) that maps the grid charge at level one to the accumulated grid potential at level one. This corresponds to the upper part of the V-cycle diagram commonly used to visualize the MSM.
Note
The operator sequences passed as parameters to this function need to respect the grid-level indexing convention. By convention, index 0 refers to the particle level (where interactions are computed directly without grids), whereas the lowest actual grid level is at index 1. Further, for the restriction and prolongation functions, which connect two different grid levels, the convention is that the function whose output lives on grid level \(l\) is located at index \(l\) of the sequence. For example, the operator \(\mathcal{I}^2_1\) that restricts the grid charge from level 1 to level 2 would be addressed as restriction_fns[2].
Thus, the input for, e.g., four grid levels should look like this, employing placeholders where needed:
    restriction_fns = [None, None, I_21, I_32, I_43]
    prolongation_fns = [None, I_12, I_23, I_34, None]
    interaction_fns = [None, K_1, K_2, K_3, K_4]
- Parameters:
restriction_fns (Sequence[Callable[[jax.typing.ArrayLike], Array]]) – Sequence of restriction functions, one for each level, including placeholders (see the note above). These correspond to the upward arrows on the left side of the V-cycle diagram. The input is an array of grid charge; the output is the array of grid charge one level higher.
prolongation_fns (Sequence[Callable[[jax.typing.ArrayLike], Array]]) – Sequence of prolongation functions, one for each level, including placeholders (see the note above). These correspond to the downward arrows on the right side of the V-cycle diagram. The input is an array of grid potential; the output is the array of potential prolongated to the next lower level.
interaction_fns (Sequence[Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike], Array]]) – Sequence of interaction functions, one for each level, including placeholders (see the note above). These correspond to the horizontal arrows in the V-cycle diagram. The inputs are two arrays, the grid charge and a kernel coefficient stencil; the output is the grid potential on the same level. Mathematically: \(e^{l}_{\mathbf{m}} = \sum_{\mathbf{n}} K^l_{\mathbf{m} - \mathbf{n}} \tilde{q}^l_{\mathbf{n}}\).
- Return type:
Callable[[jax.typing.ArrayLike, Sequence[Optional[jax.typing.ArrayLike]]], Array]
- Returns:
A function that performs the pass through all grid levels. It calculates the accumulated grid potential \(e^{1+}\) at level one and takes two arguments:
The level-one grid charge \(\tilde{q}^1\).
A sequence of coefficient stencil arrays for the interaction kernels, one per grid level (including a placeholder at level zero; see the note above on the grid-level indexing convention). These are passed to the interaction_fns that were supplied to construct the grid pass function: the \(l\)-th stencil is consumed by the \(l\)-th element of interaction_fns.
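The order of operations in such a grid pass can be sketched as a recursion over levels: \(e^{L+} = K^L \tilde{q}^L\) at the top level \(L\), then \(e^{l+} = K^l \tilde{q}^l + \mathcal{I}^l_{l+1} e^{(l+1)+}\) for \(l = L-1, \dots, 1\), where \(K^l \tilde{q}^l\) stands for the level-\(l\) interaction. The following hypothetical sketch (make_grid_pass_sketch is not the msmjax implementation) follows the indexing convention from the note above; the toy operators used to exercise it are illustrative assumptions.

```python
import jax.numpy as jnp


def make_grid_pass_sketch(restriction_fns, prolongation_fns, interaction_fns):
    """Hypothetical sketch of the up, across, and down pass through grid levels 1..L."""
    n_levels = len(interaction_fns) - 1  # index 0 is the particle level

    def grid_pass(q1, kernel_stencils):
        # Upward: restriction_fns[l] maps the grid charge from level l-1 to level l.
        charges = [None, q1]
        for level in range(2, n_levels + 1):
            charges.append(restriction_fns[level](charges[level - 1]))
        # Top level: only the interaction on that level contributes.
        e = interaction_fns[n_levels](charges[n_levels], kernel_stencils[n_levels])
        # Downward: prolongation_fns[l] maps the potential from level l+1 to level l.
        for level in range(n_levels - 1, 0, -1):
            e = interaction_fns[level](charges[level], kernel_stencils[level]) + prolongation_fns[level](e)
        return e  # accumulated level-one potential e^{1+}

    return grid_pass


def restrict_pairs(q):
    # Toy restriction: sum neighbouring pairs of fine grid points.
    return q.reshape(-1, 2).sum(axis=1)


def prolong_repeat(e):
    # Toy prolongation (transpose of restrict_pairs): copy each coarse value twice.
    return jnp.repeat(e, 2)


def scale_interaction(q, stencil):
    # Toy interaction: the "stencil" is just a scalar factor.
    return stencil * q


grid_pass = make_grid_pass_sketch(
    [None, None, restrict_pairs],
    [None, prolong_repeat, None],
    [None, scale_interaction, scale_interaction],
)
e1 = grid_pass(jnp.array([1.0, 2.0, 3.0, 4.0]), [None, 1.0, 10.0])
```

With two grid levels, the level-two charge is [3, 7], its potential is [30, 70], and after prolongation and adding the level-one interaction the result is [31, 32, 73, 74].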
- msmjax.core.longrange.make_compute_u_oneplus(singleparticle_basis_fn_lvl_one, grid_pass_fn, grid_shape_lvl_one, transform_mode=None, kernel_stencils=None, kernel_stencil_construction_fn=None, use_custom_derivatives=True)
Create a function that computes the MSM long-range energy contribution.
The quantity being (approximately) calculated is called \(U^{1+}\) in reference [1].
- Parameters:
singleparticle_basis_fn_lvl_one (Callable[[jax.typing.ArrayLike], tuple[Array, Array]]) – A function that, for one particle,
identifies all points on the level-one grid that are close enough for the particle's position to lie within the support of the associated basis functions, i.e., finds the set of grid points \(M = \{ \mathbf{m} : \varphi^{1}_{\mathbf{m}}(\mathbf{r}_i) \neq 0 \}\) (where \(\mathbf{r}_i\) denotes the position of particle \(i\) and \(\varphi^{1}_{\mathbf{m}}\) is the basis function centered on point \(\mathbf{m}\) of the level-one grid), and
evaluates the corresponding basis functions, i.e., computes \(\varphi^{1}_{\mathbf{m}}(\mathbf{r}_i)\) for all grid points \(\mathbf{m} \in M\).
Inputs and outputs:
Input to singleparticle_basis_fn_lvl_one should be a 1-d array of shape (n_dim,), representing the coordinates of a single particle.
Output of singleparticle_basis_fn_lvl_one should be a tuple of two 1-d arrays, each of shape (support_size,), where support_size is the number of non-zero basis functions around one particle (= the cardinality of \(M\) from above). The first array contains the values of the basis functions at all grid points \(\mathbf{m} \in M\). The second array contains the corresponding set of grid point indices \(M\) as flat (!) indices into the grid.
Note
Regardless of the spatial dimension of the system, singleparticle_basis_fn_lvl_one should always return flat arrays.
Warning
If singleparticle_basis_fn_lvl_one returns indices that are out of bounds w.r.t. the grid size defined by grid_shape_lvl_one, this will result in NaNs. This is intentional, because errors such as particles moving outside the grid boundaries might otherwise go unnoticed.
grid_pass_fn (Callable[[jax.typing.ArrayLike, Sequence[Optional[jax.typing.ArrayLike]]], Array]) – A function that performs the entire process of moving up, across, and back down the grid hierarchy. For more details on the expected signature, see make_grid_pass_fn(), which can conveniently be used to create such a function.
Inputs and outputs:
Input to grid_pass_fn should be the level-one grid charge \(\tilde{q}^1\) (an array whose shape matches the grid_shape_lvl_one parameter) and a sequence of coefficient stencils for the interaction kernels (one per grid level, including a placeholder at level zero).
Output of grid_pass_fn should be the level-one grid potential, also of shape grid_shape_lvl_one.
grid_shape_lvl_one (tuple[int, ...]) – Tuple of integers representing the shape of the target grid at level one, onto which the particle charges will be anterpolated.
use_custom_derivatives (bool) – Whether the returned energy function should use custom (more efficient) differentiation rules for its derivatives w.r.t. positions and charges.
transform_mode (Literal['ortho', 'triclinic'] | None)
kernel_stencils (Sequence[jax.typing.ArrayLike | None])
kernel_stencil_construction_fn (Callable[[jax.typing.ArrayLike], Sequence[jax.typing.ArrayLike | None]])
- Return type:
Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike, Sequence[Optional[jax.typing.ArrayLike]]], Array]
- Returns:
A function that computes the scalar energy \(U^{1+}\) and takes three arguments:
Array of positions, shape (n_particles, n_dim).
Array of charges, shape (n_particles,).
A sequence of coefficient stencils for the interaction kernels (one per grid level, including a placeholder at level zero). Passed to grid_pass_fn. See make_grid_pass_fn() for more details.
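How the pieces documented on this page fit together can be sketched as follows: evaluate the level-one basis functions for every particle (vectorized with jax.vmap), anterpolate, run the grid pass, and interpolate the energy. This hypothetical sketch (make_compute_u_oneplus_sketch) omits transform_mode, kernel_stencils, kernel_stencil_construction_fn and the custom derivatives of the real function, assumes the factor one half in the energy, and uses a 1-d linear "hat" basis and a single-level toy grid pass purely for illustration.

```python
import math

import jax
import jax.numpy as jnp


def make_compute_u_oneplus_sketch(singleparticle_basis_fn_lvl_one, grid_pass_fn, grid_shape_lvl_one):
    """Hypothetical sketch: basis evaluation -> anterpolation -> grid pass -> interpolation."""
    n_points = math.prod(grid_shape_lvl_one)

    def compute_u_oneplus(positions, charges, kernel_stencils):
        # Per-particle values and flat indices, each of shape (n_particles, support_size).
        basis_vals, basis_inds = jax.vmap(singleparticle_basis_fn_lvl_one)(positions)
        # Anterpolation onto the flat level-one grid.
        qtilde = jnp.zeros(n_points).at[basis_inds.ravel()].add((basis_vals * charges[:, None]).ravel())
        # Grid pass: level-one grid charge -> accumulated level-one grid potential.
        e = grid_pass_fn(qtilde.reshape(grid_shape_lvl_one), kernel_stencils)
        # Interpolation, with the assumed factor 1/2 of a pair energy.
        potential_at_particles = jnp.sum(basis_vals * e.ravel()[basis_inds], axis=1)
        return 0.5 * jnp.dot(charges, potential_at_particles)

    return compute_u_oneplus


def hat_basis_1d(x):
    # Toy basis: linear interpolation between the two neighbouring points of a unit-spaced grid.
    j = jnp.floor(x[0]).astype(jnp.int32)
    t = x[0] - j
    return jnp.array([1.0 - t, t]), jnp.array([j, j + 1])


def toy_grid_pass(q1, kernel_stencils):
    # Toy single-level "grid pass": scale the grid charge by a scalar stencil.
    return kernel_stencils[1] * q1


compute_u = make_compute_u_oneplus_sketch(hat_basis_1d, toy_grid_pass, (6,))
u = compute_u(jnp.array([[1.0], [3.5]]), jnp.array([1.0, -1.0]), [None, 2.0])
```

In this toy case the particles interpolate grid potentials of 2.0 and -1.0, so the energy is 0.5 * (1.0 * 2.0 + (-1.0) * (-1.0)) = 1.5.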