msmjax.kernels

Code for splitting interaction kernels into sums of partial kernels.

References

[1] Hardy, D. J.; Wolff, M. A.; Xia, J.; Schulten, K.; Skeel, R. D. Multilevel Summation with B-Spline Interpolation for Pairwise Interactions in Molecular Dynamics Simulations. J. Chem. Phys. 2016, 144 (11), 114112. https://doi.org/10.1063/1.4943868.

[2] Hardy, D. J. Multilevel Summation for the Fast Evaluation of Forces for the Simulation of Biomolecules (PhD thesis), University of Illinois at Urbana-Champaign, 2006.

class msmjax.kernels.SoftenerOneOverR(order)

Class for constructing and evaluating a softener for the 1/r kernel.

The softener is a function of a dimensionless argument rho that equals 1/rho for rho >= 1 and is bounded and smooth for rho < 1.

Parameters:

order (int) – Order of the Taylor polynomial in s = rho**2 that makes up the softening function for rho < 1. In line with common spline terminology, its polynomial degree in s is order - 1.
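To make the role of order concrete, here is a minimal sketch of such a softener, not msmjax's implementation. It assumes the polynomial for rho < 1 is the Taylor expansion of 1/rho = s**(-1/2) about s = 1, truncated after order terms, in the spirit of the even-powered softening in [1, 2]. The name taylor_softener is hypothetical. For order = 3 it yields 15/8 - 5/4 rho**2 + 3/8 rho**4 below rho = 1.

```python
import jax.numpy as jnp


def taylor_softener(order):
    """Hypothetical sketch of a 1/rho softener (not msmjax's implementation)."""
    # Coefficients c_k of the Taylor expansion of s**(-1/2) about s = 1,
    # i.e. binomial(-1/2, k), for k = 0, ..., order - 1.
    coeffs = [1.0]
    for k in range(1, order):
        coeffs.append(coeffs[-1] * (-0.5 - (k - 1)) / k)

    def gamma(rho):
        rho = jnp.asarray(rho, dtype=float)
        s = rho**2
        # Polynomial of degree order - 1 in (s - 1), evaluated by Horner's scheme.
        poly = jnp.zeros_like(s)
        for c in reversed(coeffs):
            poly = poly * (s - 1.0) + c
        # Avoid dividing by values below 1 in the branch that is discarded anyway.
        safe_rho = jnp.where(rho >= 1.0, rho, 1.0)
        return jnp.where(rho >= 1.0, 1.0 / safe_rho, poly)

    return gamma
```

Because the polynomial matches s**(-1/2) and its first order - 1 derivatives in s at s = 1, the softener joins the 1/rho branch smoothly at rho = 1.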

msmjax.kernels.split_one_over_r(max_level, level_zero_cutoff, softening_function)

Split the kernel 1/r into (max_level + 1) terms according to the references above.

The splitting terms sum up to the Coulomb kernel 1/r as follows: 1/r = g_0(r) + g_1(r) + g_2(r) + … + g_{max_level}(r)

Parameters:
  • max_level (int) – The number of splits to be performed.

  • level_zero_cutoff (float) – The cutoff radius of the level-zero kernel function.

  • softening_function (Callable) – The basic softening function on which this splitting is based.

Return type:

list[Callable[[jax.typing.ArrayLike], Array]]

Returns:

A list of one-argument functions g_l, with l from zero to max_level, that represent the terms in the splitting of the interaction kernel. Their respective cutoffs are already ‘built in’ (in the sense that they evaluate to zero for distances beyond their cutoff), and they take their arguments in the same length units as level_zero_cutoff.

Raises:

ValueError – If the arguments do not make sense.
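A telescoping split with these properties can be sketched as follows. This is an illustration under assumptions, not the library's code: it assumes smoothed kernels s_l(r) = gamma(r / (2**l a)) / (2**l a), with a = level_zero_cutoff, and takes g_0 = 1/r - s_0, g_l = s_{l-1} - s_l for 1 <= l < max_level, and g_{max_level} = s_{max_level - 1}. The sum telescopes to 1/r, and g_l vanishes for r >= 2**l a. The names gamma_c1 and split_sketch, and the specific ValueError checks, are hypothetical; msmjax's exact level indexing may differ.

```python
import jax.numpy as jnp


def gamma_c1(rho):
    # Hypothetical simple softener: 1/rho for rho >= 1, (3 - rho**2) / 2 below.
    rho = jnp.asarray(rho, dtype=float)
    safe = jnp.where(rho >= 1.0, rho, 1.0)
    return jnp.where(rho >= 1.0, 1.0 / safe, 1.5 - 0.5 * rho**2)


def split_sketch(max_level, a, gamma):
    """Hypothetical telescoping split of 1/r into max_level + 1 terms."""
    # Assumed validity checks, standing in for "arguments do not make sense".
    if max_level < 1 or a <= 0:
        raise ValueError("need max_level >= 1 and a positive cutoff")

    def smoothed(l):
        # s_l(r) = gamma(r / (2**l a)) / (2**l a); equals 1/r for r >= 2**l a.
        scale = 2.0**l * a
        return lambda r: gamma(jnp.asarray(r) / scale) / scale

    # Level zero: short-range remainder, zero beyond the cutoff a.
    terms = [lambda r, s0=smoothed(0): 1.0 / jnp.asarray(r) - s0(r)]
    # Intermediate levels: differences of consecutive smoothed kernels.
    for l in range(1, max_level):
        terms.append(lambda r, lo=smoothed(l - 1), hi=smoothed(l): lo(r) - hi(r))
    # Top level: the smoothest kernel, with no cutoff.
    terms.append(smoothed(max_level - 1))
    return terms
```

The default arguments in the lambdas bind each level's smoothed kernel at definition time, which avoids Python's late-binding pitfall in loops.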

msmjax.kernels._get_distances(extents_from_center, spacing_or_gridcell)

Compute distances (from the origin) of points on a regular grid.

This function was originally intended for constructing the grid of distances at which the interaction kernels are evaluated when computing kernel stencils for interpolation.

Parameters:
  • extents_from_center (Sequence[int]) – Number of points along each axis, counted from, and not including, the center point. That is to say, for extents_from_center = (s_1, s_2, s_3), the shape of the grid of points is (2 * s_1 + 1, 2 * s_2 + 1, 2 * s_3 + 1).

  • spacing_or_gridcell (jax.typing.ArrayLike) – Either a scalar grid spacing (for orthogonal axes with the same spacing along all axes), a 1-d array (for orthogonal axes with different spacings), or a 2-d array of shape (n_dim, n_dim) describing exactly one grid cell (for non-orthogonal axes).

Return type:

Array

Returns:

Distances from the origin of points on the grid.
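The three accepted forms of spacing_or_gridcell can be handled uniformly by mapping integer grid indices to Cartesian points. The following minimal sketch, named grid_distances here (a hypothetical name, not the library's function), assumes that in the 2-d case the rows of the array are the edge vectors of one grid cell; the actual row/column convention in msmjax is not stated above.

```python
import jax.numpy as jnp


def grid_distances(extents_from_center, spacing_or_gridcell):
    """Hypothetical sketch of _get_distances."""
    # Integer indices -s, ..., s along each axis, combined into shape (..., n_dim).
    axes = [jnp.arange(-s, s + 1) for s in extents_from_center]
    idx = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1).astype(float)
    h = jnp.asarray(spacing_or_gridcell, dtype=float)
    if h.ndim <= 1:
        # Orthogonal axes: scale each index by its (possibly shared) spacing.
        points = idx * h
    else:
        # Non-orthogonal axes: assume rows of h are the grid-cell edge vectors.
        points = idx @ h
    return jnp.linalg.norm(points, axis=-1)
```

For example, grid_distances((1, 2), 0.5) has shape (3, 5), with a zero at the center entry [1, 2].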

msmjax.kernels._compute_one_stencil(function_values, omega_prime, mode)

Compute one kernel stencil by convolution with universal spline coefficients.

Parameters:
  • function_values (jax.typing.ArrayLike) – Values of the function for which the stencil of interpolation coefficients is to be computed, on a regular grid.

  • omega_prime (jax.typing.ArrayLike) – 1-d sequence of universal spline coefficients for convolution with function_values.

  • mode (str) – The mode argument passed to jax.scipy.signal.convolve used internally for doing the convolution.

Return type:

Array

Returns:

The resulting kernel stencil.
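Since omega_prime is one-dimensional while function_values lives on an n-dimensional grid, one natural reading is that the coefficients are applied separably along every axis. The sketch below assumes exactly that. It builds the tensor product of omega_prime with itself (one factor per axis) and performs a single n-d convolution with jax.scipy.signal.convolve. The name compute_stencil_sketch is hypothetical, and msmjax may organize the convolution differently.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve


def compute_stencil_sketch(function_values, omega_prime, mode):
    """Hypothetical sketch of _compute_one_stencil (separable assumption)."""
    f = jnp.asarray(function_values, dtype=float)
    w = jnp.asarray(omega_prime, dtype=float)
    # Tensor product of the 1-d coefficients: one factor per grid axis.
    kernel = w
    for _ in range(f.ndim - 1):
        kernel = kernel[..., None] * w
    # One n-d convolution applies omega_prime along every axis at once.
    return convolve(f, kernel, mode=mode)
```

With a unit impulse as function_values and mode "same", the result reproduces the outer product of omega_prime with itself around the impulse.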

msmjax.kernels.make_construct_stencils(omega_prime, n_levels_intermed, include_toplevel, scaled_spacings, cell_mode, k_lowest_intermed=None, extents_from_center_intermed=None, k_toplevel=None, grid_shape_toplevel=None)

Create a function to compute all kernel stencils from a given unit cell.

The created function is suitable for passing as the kernel_stencil_construction_fn argument to msmjax.core.longrange.make_compute_u_oneplus().

The design is tailored to kernel splitting schemes like the one implemented by split_one_over_r(), in which all the intermediate-level partial kernels \(k_l(r)\), \(l = 1, ..., L - 1\), have the same functional form up to scaling, so that the kernel stencil on each subsequent level, whose grid has twice the spacing, can be computed by simply dividing the lower-level kernel stencil by two.
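The halving rule follows from the scaling of 1/r. Assume the smoothed kernels have the form s_l(r) = gamma(r / (2**l a)) / (2**l a), with intermediate terms k_l = s_{l-1} - s_l. Then s_l(2r) = s_{l-1}(r) / 2, hence k_{l+1}(2r) = k_l(r) / 2. The values of k_{l+1} on a grid with spacing 2h are therefore exactly half those of k_l on the grid with spacing h. Because convolution with omega_prime is linear, the same holds for the stencils. The numerical check below relies on that assumed telescoping form and a hypothetical softener gamma.

```python
import jax.numpy as jnp


def gamma(rho):
    # Hypothetical softener: 1/rho for rho >= 1, (3 - rho**2) / 2 below.
    safe = jnp.where(rho >= 1.0, rho, 1.0)
    return jnp.where(rho >= 1.0, 1.0 / safe, 1.5 - 0.5 * rho**2)


def smoothed(l, r, a):
    # s_l(r) = gamma(r / (2**l a)) / (2**l a)
    scale = 2.0**l * a
    return gamma(r / scale) / scale


def partial_kernel(l, r, a):
    # Intermediate-level term k_l(r) = s_{l-1}(r) - s_l(r), for l >= 1.
    return smoothed(l - 1, r, a) - smoothed(l, r, a)


a, h = 1.2, 0.4
n = jnp.arange(13.0)                      # non-negative grid indices along one axis
fine = partial_kernel(1, h * n, a)        # level 1 on the grid with spacing h
coarse = partial_kernel(2, 2 * h * n, a)  # level 2 on the grid with spacing 2h
```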

Parameters:
  • omega_prime (jax.typing.ArrayLike) – 1-d sequence of universal spline coefficients; the kernel stencils are obtained by convolving kernel values with them.

  • n_levels_intermed (int) – The number of intermediate levels, i.e. all levels \(l = 1, ..., L-1\), but excluding \(l = 0\) and \(l = L\).

  • include_toplevel (bool) – Whether the top level (\(l = L\)) should be included. Usually, this is true in systems without periodicity, and false in periodic ones.

  • scaled_spacings (jax.typing.ArrayLike) – 1-d array of grid spacing values, one per axis, expressed in fractional coordinates w.r.t. the cell.

  • cell_mode (Literal['ortho', 'triclinic']) – A string specifying assumptions on the shape of the unit cell. With 'ortho', the cell is assumed to be orthorhombic and axis-aligned, so only its diagonal is used, which reduces computational cost; with 'triclinic', a general triclinic cell is allowed.

  • k_lowest_intermed (Optional[Callable[[jax.typing.ArrayLike], Array]]) – The partial kernel at the lowest intermediate grid level (\(l = 1 < L\) if n_levels_intermed > 0). Omit if n_levels_intermed = 0, in which case \(L = 1\).

  • extents_from_center_intermed (tuple[int, ...] | None) – Number of points of the intermediate-level (\(l = 1, ..., L-1\)) stencils along each axis, counted from, and not including, the center point. That is to say, for extents_from_center_intermed = (s_1, s_2, s_3), the shape of the stencils is (2 * s_1 + 1, 2 * s_2 + 1, 2 * s_3 + 1). Required if k_lowest_intermed was given.

  • k_toplevel (Optional[Callable[[jax.typing.ArrayLike], Array]]) – The partial kernel \(k_L(r)\) at the top level.

  • grid_shape_toplevel (tuple[int, ...] | None) – The shape (number of points along each axis) of the highest-level (\(l = L\)) grid. Required if k_toplevel was given.

Return type:

Callable[[jax.typing.ArrayLike], list[Array]]

Returns:

A function of one argument (cell, of shape (n_dim, n_dim)) that returns a list of kernel stencils, one for each level, with None as a placeholder at level zero, where there is no grid.
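The layout of the returned list can be illustrated with a small sketch. It assumes the level-1 stencil and, if present, the top-level stencil have already been computed. The real created function also derives them from cell. The name assemble_stencils is hypothetical. Intermediate levels follow the halving rule described above, and the top level is appended only when available (typically for non-periodic systems).

```python
import jax.numpy as jnp


def assemble_stencils(k1_stencil, n_levels_intermed, top_stencil=None):
    """Hypothetical sketch of the list layout returned by the created function."""
    # Level 0 has no grid, so a None placeholder comes first.
    stencils = [None]
    # Intermediate levels l = 1, ..., L-1: each coarser level halves the stencil.
    for i in range(n_levels_intermed):
        stencils.append(jnp.asarray(k1_stencil) / 2**i)
    # Top level l = L, included only if a top-level stencil was provided.
    if top_stencil is not None:
        stencils.append(jnp.asarray(top_stencil))
    return stencils
```

With three intermediate levels and a top level, the list has five entries: None, then the level-1 stencil scaled by 1, 1/2, 1/4, then the top-level stencil.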