msmjax.core.shortrange
Generic implementation of the short-range part (the part evaluated without grids).
References
[1] Hardy, D. J.; Wolff, M. A.; Xia, J.; Schulten, K.; Skeel, R. D. Multilevel Summation with B-Spline Interpolation for Pairwise Interactions in Molecular Dynamics Simulations. J. Chem. Phys. 2016, 144 (11), 114112. https://doi.org/10.1063/1.4943868.
[2] Hardy, D. J. Multilevel Summation for the Fast Evaluation of Forces for the Simulation of Biomolecules (PhD thesis), University of Illinois at Urbana-Champaign, 2006.
- msmjax.core.shortrange._gen_supercell(positions, charges, cell, supercell_diag)
Replicate the unit cell and the particles it contains along its axes.
- Parameters:
  - positions (jax.typing.ArrayLike) – Array of positions, shape (n_particles, n_dim).
  - charges (jax.typing.ArrayLike) – Array of charges, shape (n_particles,).
  - cell (jax.typing.ArrayLike) – Array representing the unit cell, shape (n_dim, n_dim).
  - supercell_diag (Sequence[int]) – Sequence of positive integers, one per direction, indicating the number of times to replicate the system along that direction.
- Return type:
tuple[Array, Array, Array]
- Returns:
Tuple containing the array of positions after replication, the array of charges after replication, and the unit cell after replication.
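The replication can be sketched as follows. This is a minimal illustration, not msmjax's implementation: the name gen_supercell_sketch is hypothetical, and the sketch assumes that the rows of cell are the lattice vectors and that the particles of the original cell come first in the output.

```python
import itertools

import jax.numpy as jnp


def gen_supercell_sketch(positions, charges, cell, supercell_diag):
    """Hypothetical sketch: replicate a unit cell and its particles along its axes.

    Assumes the rows of `cell` are the lattice vectors.
    """
    positions = jnp.asarray(positions)
    charges = jnp.asarray(charges)
    cell = jnp.asarray(cell)
    n_dim = positions.shape[1]
    # Integer image indices (n_1, ..., n_d) with 0 <= n_k < supercell_diag[k];
    # the first one is (0, ..., 0), so the original particles come first.
    images = jnp.array(
        list(itertools.product(*(range(n) for n in supercell_diag))),
        dtype=cell.dtype,
    )
    # Cartesian translation vector of each image, shape (n_images, n_dim)
    translations = images @ cell
    # Shift all particles by all translations, shape (n_images * n_particles, n_dim)
    new_positions = (translations[:, None, :] + positions[None, :, :]).reshape(-1, n_dim)
    # Charges are repeated image by image, matching the ordering of the positions
    new_charges = jnp.tile(charges, translations.shape[0])
    # Each lattice vector (row) is scaled by its replication count
    new_cell = cell * jnp.asarray(supercell_diag, dtype=cell.dtype)[:, None]
    return new_positions, new_charges, new_cell
```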
- msmjax.core.shortrange._generalized_diagonal_mask(a)
Set the diagonal of the matrix a, which may be wider than tall, to zero.
Adapted from JAX-MD.
Warning
Any NaN or infinite entries (including off-diagonal ones!) are silently replaced by this function. For the diagonal, this is usually reasonable and desired: in the use case this function is designed for, diagonal elements of the input correspond to interactions of particles with themselves, which are usually considered artifactual and may be undefined. For off-diagonal elements, however, the silent replacement may obscure the origin of the bugs that made them invalid.
- Parameters:
  - a (jax.typing.ArrayLike) – Original matrix.
- Return type:
Array
- Returns:
The matrix with diagonal set to zero.
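A minimal sketch of such a mask, assuming a 2-d input whose entries (i, i) for i < n_rows form the "diagonal" (as when rows index the particles of the original cell and columns index the particles of a supercell whose first entries are the originals). Replacing NaN and infinite values with jnp.nan_to_num, which maps NaN to zero and infinities to large finite numbers, is a choice made here and not necessarily what msmjax does; the function name is hypothetical.

```python
import jax.numpy as jnp


def generalized_diagonal_mask_sketch(a):
    """Hypothetical sketch: zero the (i, i) entries of a 2-d, possibly wide, matrix."""
    a = jnp.asarray(a)
    n_rows, n_cols = a.shape
    # Replace NaN by 0 and +/-inf by large finite values everywhere (cf. the warning above),
    # so that multiplying by the mask below cannot produce NaN (0 * inf would).
    a = jnp.nan_to_num(a)
    # Rectangular "identity" with ones at (i, i) for i < n_rows, zeros elsewhere
    mask = 1 - jnp.eye(n_rows, n_cols, dtype=a.dtype)
    return mask * a
```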
- msmjax.core.shortrange._displacement_free(r_1, r_2)
Compute the distance vector between two points in free space.
- Parameters:
  - r_1 (jax.typing.ArrayLike) – First point.
  - r_2 (jax.typing.ArrayLike) – Second point.
- Return type:
Array
- Returns:
Distance vector.
- msmjax.core.shortrange._displacement_ortho(r_1, r_2, side_lengths)
Compute the distance vector between two points in an orthorhombic cell.
Handles mixed periodicity: to indicate that specific directions lack periodicity, set the corresponding elements of side_lengths to zero. Periodicity is accounted for by means of the minimum image convention, with all the known limitations this entails.
- Parameters:
  - r_1 (jax.typing.ArrayLike) – First point.
  - r_2 (jax.typing.ArrayLike) – Second point.
  - side_lengths (jax.typing.ArrayLike) – 1-d array of side lengths (one per direction).
- Return type:
Array
- Returns:
Distance vector.
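A minimal sketch of the minimum image convention in an orthorhombic, axis-aligned cell, including the zero-side-length convention for non-periodic directions. The sign convention (vector pointing from r_2 to r_1) and the function name are assumptions, not taken from msmjax.

```python
import jax.numpy as jnp


def displacement_ortho_sketch(r_1, r_2, side_lengths):
    """Hypothetical sketch: minimum image displacement r_1 - r_2 in an orthorhombic cell."""
    dr = jnp.asarray(r_1) - jnp.asarray(r_2)
    side_lengths = jnp.asarray(side_lengths, dtype=dr.dtype)
    periodic = side_lengths > 0
    # Use a dummy side length of 1 in non-periodic directions to avoid dividing by zero
    safe_lengths = jnp.where(periodic, side_lengths, 1.0)
    # Subtract the nearest integer number of box lengths (minimum image convention)
    wrapped = dr - safe_lengths * jnp.round(dr / safe_lengths)
    # Non-periodic directions keep the plain free-space displacement
    return jnp.where(periodic, wrapped, dr)
```

For example, displacement_ortho_sketch([0.9, 0.9], [0.1, 0.1], [1.0, 0.0]) wraps only the first component, giving approximately [-0.2, 0.8]. Only the single nearest image is returned, so interactions with further images inside the cutoff are missed whenever the cutoff exceeds half the smallest periodic side length.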
- msmjax.core.shortrange._displacement_general(r_1, r_2, cell)
Compute the distance vector between two points in a general triclinic cell.
Handles mixed periodicity: to indicate that specific directions lack periodicity, set the corresponding rows of cell to zero. Periodicity is accounted for by means of the minimum image convention, with all the known limitations this entails.
- Parameters:
  - r_1 (jax.typing.ArrayLike) – First point.
  - r_2 (jax.typing.ArrayLike) – Second point.
  - cell (jax.typing.ArrayLike) – Array representing the unit cell, shape (n_dim, n_dim).
- Return type:
Array
- Returns:
Distance vector.
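For a general cell, one common way to apply the minimum image convention is to round fractional coordinates. The sketch below assumes that the rows of cell are the lattice vectors and uses jnp.linalg.pinv so that zero rows (non-periodic directions) do not make the inversion fail; both choices, and the function name, are illustrative assumptions rather than msmjax's actual approach.

```python
import jax.numpy as jnp


def displacement_general_sketch(r_1, r_2, cell):
    """Hypothetical sketch: minimum image displacement r_1 - r_2 in a triclinic cell."""
    dr = jnp.asarray(r_1) - jnp.asarray(r_2)
    cell = jnp.asarray(cell, dtype=dr.dtype)
    # All-zero rows of `cell` mark non-periodic directions
    periodic = jnp.any(cell != 0, axis=1)
    # Fractional coordinates s with dr = s @ cell; the pseudo-inverse tolerates zero rows
    frac = dr @ jnp.linalg.pinv(cell)
    # Subtract the lattice translation obtained by rounding, only along periodic directions
    shift = jnp.where(periodic, jnp.round(frac), 0.0)
    return dr - shift @ cell
```

In strongly skewed cells, rounding fractional coordinates does not always find the nearest image, which is one of the known limitations of the minimum image convention mentioned above.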
- msmjax.core.shortrange._concretize_displacement_fn(pbc, cell_mode=None)
Select or construct a displacement function based on periodic boundary conditions (PBCs) and unit cell constraints.
Wraps the lower-level displacement functions, turning them into functions that have the chosen periodic boundary conditions already built in and that share a consistent signature.
- Parameters:
  - pbc (Sequence[bool]) – One boolean per direction signaling periodicity.
  - cell_mode (Optional[Literal['ortho', 'triclinic']]) – A string specifying assumptions on the shape of the unit cell: 'ortho' assumes an orthorhombic, axis-aligned cell, of which only the diagonal is considered (reducing computational cost); 'triclinic' allows a general triclinic cell. May be omitted (and is ignored) if no direction is periodic.
- Return type:
Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike, Optional[jax.typing.ArrayLike]], Array]
- Returns:
A function of two position vectors and (optionally, depending on the PBCs) a unit cell that computes the distance vector between the two positions.
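The dispatch described above can be sketched as follows. The name concretize_displacement_fn_sketch is hypothetical, and this is not msmjax's code: the sketch assumes that the rows of the cell are lattice vectors and builds the PBCs in by zeroing the side lengths or cell rows of non-periodic directions, mirroring the conventions documented for _displacement_ortho() and _displacement_general().

```python
import jax.numpy as jnp


def concretize_displacement_fn_sketch(pbc, cell_mode=None):
    """Hypothetical sketch: return a displacement fn with the PBCs built in."""
    periodic = jnp.asarray(pbc, dtype=bool)

    if not any(pbc):
        # Free space: cell_mode is ignored, and the cell argument is optional and unused
        def displacement_free(r_1, r_2, cell=None):
            return jnp.asarray(r_1) - jnp.asarray(r_2)

        return displacement_free

    if cell_mode == "ortho":

        def displacement_ortho(r_1, r_2, cell):
            dr = jnp.asarray(r_1) - jnp.asarray(r_2)
            # Only the diagonal of the cell is read; a zero side length means non-periodic
            side_lengths = jnp.where(periodic, jnp.diag(jnp.asarray(cell)), 0.0)
            safe_lengths = jnp.where(side_lengths > 0, side_lengths, 1.0)
            wrapped = dr - safe_lengths * jnp.round(dr / safe_lengths)
            return jnp.where(side_lengths > 0, wrapped, dr)

        return displacement_ortho

    if cell_mode == "triclinic":

        def displacement_triclinic(r_1, r_2, cell):
            dr = jnp.asarray(r_1) - jnp.asarray(r_2)
            # Zero the rows (lattice vectors) of non-periodic directions
            cell = jnp.where(periodic[:, None], jnp.asarray(cell), 0.0)
            frac = dr @ jnp.linalg.pinv(cell)
            shift = jnp.where(periodic, jnp.round(frac), 0.0)
            return dr - shift @ cell

        return displacement_triclinic

    raise ValueError(f"cell_mode must be 'ortho' or 'triclinic', got {cell_mode!r}")
```

Resolving the periodicity once, at construction time, keeps the Python-level branching out of the returned functions, which then consist only of jax.numpy operations and can be passed to jax.jit or jax.vmap as they are.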
- msmjax.core.shortrange.make_eval_pair_pot(kernel_fn, pbc, cell_mode=None, supercell_diag=None, per_particle=False, extra_uncharged_interaction=None)
Transform an interaction kernel into a function acting on a particle system.
In other words, given a distance-dependent interaction kernel \(k(r)\), construct another function that computes the total system energy, \(\frac{1}{2} \sum_i \sum_{j \neq i} q_i q_j k(r_{ij})\), by mapping \(k(r)\) over all particle pairs.
Warning
Distance computations under periodic boundary conditions are handled by means of the minimum image convention, with the known limitations this entails. If the cutoff radius of kernel_fn is too large for the unit cell, or the unit cell is too strongly deformed, results will be incorrect. If you know beforehand that this is an issue, you can remedy it with the supercell_diag parameter (see below).
- Parameters:
  - kernel_fn (Callable[[jax.typing.ArrayLike], Array]) – A function of a single scalar distance argument, corresponding to \(k(r)\) in the above formula.
  - pbc (Sequence[bool]) – One boolean per direction signaling periodicity.
  - cell_mode (Optional[Literal['ortho', 'triclinic']]) – A string specifying assumptions on the shape of the unit cell: 'ortho' assumes an orthorhombic, axis-aligned cell, of which only the diagonal is considered (reducing computational cost); 'triclinic' allows a general triclinic cell. May be omitted (and is ignored) if no direction is periodic.
  - supercell_diag (Optional[Sequence[int]]) – An optional sequence of positive integers, one per direction. If supplied, pairwise interactions are computed between the particles in the original cell and all particles in a supercell created by repeating the cell the given number of times along each direction. This can be used to ensure that all interactions with neighbors are taken into account when the cutoff of kernel_fn is too large for the original, non-replicated cell.
  - extra_uncharged_interaction (Optional[Callable[[jax.typing.ArrayLike], Array]])
- Return type:
Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike, Optional[jax.typing.ArrayLike]], Array]
- Returns:
A function that takes arrays of particle positions and charges, and the unit cell, as arguments and computes the energy for the whole system of particles.
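The energy formula above can be illustrated with a minimal free-space sketch that maps a scalar kernel over all pairs with jax.vmap. It is hypothetical (the name make_eval_pair_pot_free_sketch is not part of msmjax): it omits PBCs, supercells, and the per_particle and extra_uncharged_interaction options, and the dummy displacement used on the diagonal is a device chosen here to keep values and gradients finite for kernels that are singular at r = 0.

```python
import jax
import jax.numpy as jnp


def make_eval_pair_pot_free_sketch(kernel_fn):
    """Hypothetical free-space sketch: (1/2) sum_i sum_{j != i} q_i q_j k(r_ij)."""

    def eval_pair_pot(positions, charges):
        positions = jnp.asarray(positions)
        charges = jnp.asarray(charges)
        n = positions.shape[0]
        off_diag = ~jnp.eye(n, dtype=bool)
        # All pairwise displacement vectors r_i - r_j, shape (n, n, n_dim)
        dr = positions[:, None, :] - positions[None, :, :]
        # Substitute a dummy nonzero vector on the diagonal BEFORE taking the norm,
        # so that neither the norm nor kernel_fn produce NaN/inf (values or gradients)
        dr = jnp.where(off_diag[:, :, None], dr, 1.0)
        r = jnp.linalg.norm(dr, axis=-1)
        # kernel_fn takes a scalar distance; vmap it over both pair indices
        k = jax.vmap(jax.vmap(kernel_fn))(r)
        pair = charges[:, None] * charges[None, :] * k
        return 0.5 * jnp.sum(jnp.where(off_diag, pair, 0.0))

    return eval_pair_pot
```

Masking the displacements before the norm, rather than only masking the final pair energies, matters for jax.grad: a NaN or infinity in the unselected branch of a single jnp.where still propagates into the gradient.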
- msmjax.core.shortrange.make_eval_pair_pot_neighborlist(kernel_fn, pbc, cell_mode=None, safe_eval_distance=1.0, extra_uncharged_interaction=None)
Transform an interaction kernel into a function acting on a particle system.
Like make_eval_pair_pot(), but with a neighbor list.
Warning
Distance computations under periodic boundary conditions are handled by means of the minimum image convention, with the known limitations this entails. If the cutoff radius of kernel_fn is too large for the unit cell, or the unit cell is too strongly deformed, results will be incorrect. Unlike make_eval_pair_pot(), this neighbor-list version does not have a built-in supercell generation feature.
- Parameters:
  - kernel_fn (Callable[[jax.typing.ArrayLike], Array]) – A function of a single scalar distance argument; see the documentation of make_eval_pair_pot().
  - pbc (Sequence[bool]) – One boolean per direction signaling periodicity.
  - cell_mode (Optional[Literal['ortho', 'triclinic']]) – A string specifying assumptions on the shape of the unit cell: 'ortho' assumes an orthorhombic, axis-aligned cell, of which only the diagonal is considered (reducing computational cost); 'triclinic' allows a general triclinic cell. May be omitted (and is ignored) if no direction is periodic.
  - safe_eval_distance (float) – A distance at which kernel_fn evaluates to a finite result (neither NaN nor infinite). Apart from this, it can be arbitrary: it is used only internally, and its exact value has no further effect.
  - extra_uncharged_interaction (Optional[Callable[[jax.typing.ArrayLike], Array]])
- Return type:
Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike, tuple[jax.typing.ArrayLike, jax.typing.ArrayLike], jax.typing.ArrayLike, Optional[jax.typing.ArrayLike]], Array]
- Returns:
A function that takes arrays of particle positions and charges, and the unit cell, and additionally a neighbor list and pairwise weights, as arguments and computes the energy for the whole system of particles.
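A free-space sketch of the neighbor-list variant, to show where a safe evaluation distance comes into play. The conventions here are assumptions, not documented above: the neighbor list is a pair (i, j) of index arrays, each listed pair carries a weight (e.g. 0.5 if both (i, j) and (j, i) are listed), and padding entries have weight zero and may point at the same particle (i == j), where the distance would otherwise be zero. The function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def make_eval_pair_pot_neighborlist_sketch(kernel_fn, safe_eval_distance=1.0):
    """Hypothetical free-space sketch: energy from a weighted pair list."""

    def eval_pair_pot(positions, charges, neighbors, weights):
        positions = jnp.asarray(positions)
        charges = jnp.asarray(charges)
        weights = jnp.asarray(weights)
        i, j = (jnp.asarray(idx) for idx in neighbors)
        dr = positions[i] - positions[j]
        # Assumed convention: weight 0 marks padding entries, which may have i == j
        valid = weights != 0
        # Give padding pairs a displacement of length safe_eval_distance, so that
        # kernel_fn (and its gradient) stays finite there
        dummy = jnp.zeros_like(dr).at[:, 0].set(safe_eval_distance)
        r = jnp.linalg.norm(jnp.where(valid[:, None], dr, dummy), axis=-1)
        k = jax.vmap(kernel_fn)(r)
        # Weighted sum over the listed pairs
        return jnp.sum(jnp.where(valid, weights * charges[i] * charges[j] * k, 0.0))

    return eval_pair_pot
```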
- msmjax.core.shortrange.make_compute_u_zero(kernel_fns, pair_map_fn)
Create a function that computes the MSM short-range energy contribution.
The precise quantity being computed is
\[U^0 = \frac{1}{2} \sum_i \sum_{j \neq i} q_i q_j k_0(r_{ij}) - \frac{1}{2} \sum_{l=1}^L \sum_i q_i^2 k_{l}(r)\big\rvert_{r=0} \, ,\]
which consists of the pair interaction term for level zero and a correction term for self-interaction at the higher levels.
This function is a high-level wrapper that constructs the evaluation function for \(U^0\) from two ingredients: the kernel functions \(k_l(r)\) at all levels, and a function that takes care of computing pair distances (including accounting for periodic boundary conditions) and evaluating such a kernel function over all pairs of particles in a charged system.
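To see why the self-interaction correction in \(U^0\) is needed, assume that the full kernel is split as \(k(r) = \sum_{l=0}^{L} k_l(r)\) (the multilevel splitting of [1, 2]), with \(k_l\) finite at \(r = 0\) for \(l \geq 1\), and that the grid-based levels effectively evaluate their energy over all pairs, including \(j = i\) (up to the grid approximation error):
\[U^l = \frac{1}{2} \sum_i \sum_j q_i q_j k_l(r_{ij}) = \frac{1}{2} \sum_i \sum_{j \neq i} q_i q_j k_l(r_{ij}) + \frac{1}{2} \sum_i q_i^2 k_l(r)\big\rvert_{r=0} \, , \quad l = 1, \dots, L \, .\]
Adding these to \(U^0\), the self-interaction terms cancel:
\[U^0 + \sum_{l=1}^L U^l = \frac{1}{2} \sum_i \sum_{j \neq i} q_i q_j \sum_{l=0}^L k_l(r_{ij}) = \frac{1}{2} \sum_i \sum_{j \neq i} q_i q_j k(r_{ij}) \, ,\]
which is the total pair energy with the full kernel.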
- Parameters:
  - kernel_fns (Sequence[Callable[[jax.typing.ArrayLike], Array]]) – List of functions of a single scalar distance argument, one for each MSM level, corresponding to the partial kernels into which the full interaction kernel is split. They are expected to be natively broadcastable over array inputs.
  - pair_map_fn (Callable[[Callable[[jax.typing.ArrayLike], Array]], Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike, ...], Array]]) – A function of a single argument that transforms a distance-dependent interaction kernel into a pairwise evaluation function acting across a system of charged particles.
    The input to pair_map_fn should be a single-argument function of a scalar distance argument.
    The return value of pair_map_fn should be a function that computes the first term in the formula above. It takes two arrays (positions of shape (n_particles, n_dim) and charges of shape (n_particles,)), plus optionally additional keyword arguments. Natural uses of such keyword arguments are a unit cell for systems with periodicity, or a neighbor list.
    Note
    Periodic boundary conditions are handled through an appropriate definition of pair_map_fn, which is applied to the zeroth element of the kernel_fns argument:
        compute_pair_term = pair_map_fn(kernel_fns[0])
    The most convenient way to obtain a pair_map_fn with the appropriate signature is to close make_eval_pair_pot() or make_eval_pair_pot_neighborlist() over their extra arguments, e.g.:
        pair_map_fn = functools.partial(make_eval_pair_pot, pbc=(True, True, False), cell_mode='ortho')
- Return type:
Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike, ...], Array]
- Returns:
A function with the same signature as the one returned by pair_map_fn(kernel_fns[0]) that computes \(U^0\) from positions, charges, and optional additional keyword arguments.
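A minimal sketch of such a wrapper, assuming that the higher-level kernels are finite at r = 0. Both make_compute_u_zero_sketch and the free-space dense_pair_map_fn (standing in for, e.g., a functools.partial of make_eval_pair_pot()) are hypothetical names, not msmjax's implementation.

```python
import jax.numpy as jnp


def make_compute_u_zero_sketch(kernel_fns, pair_map_fn):
    """Hypothetical sketch of the documented U^0 formula."""
    # Level-0 pair term; pair_map_fn handles distances (and PBCs, if any)
    compute_pair_term = pair_map_fn(kernel_fns[0])

    def compute_u_zero(positions, charges, *args, **kwargs):
        charges = jnp.asarray(charges)
        pair_term = compute_pair_term(positions, charges, *args, **kwargs)
        # Self-interaction correction: (1/2) * sum_i q_i^2 * sum_{l>=1} k_l(0)
        k_self = sum(k(jnp.asarray(0.0)) for k in kernel_fns[1:])
        return pair_term - 0.5 * jnp.sum(charges**2) * k_self

    return compute_u_zero


def dense_pair_map_fn(kernel_fn):
    """Hypothetical free-space pair_map_fn: sums over all pairs j != i, no PBCs."""

    def eval_pair_pot(positions, charges):
        positions, charges = jnp.asarray(positions), jnp.asarray(charges)
        off_diag = ~jnp.eye(positions.shape[0], dtype=bool)
        dr = positions[:, None, :] - positions[None, :, :]
        # Dummy nonzero displacement on the diagonal keeps kernel_fn finite there
        r = jnp.linalg.norm(jnp.where(off_diag[:, :, None], dr, 1.0), axis=-1)
        # kernel_fn is assumed to broadcast over arrays, as required for kernel_fns
        pair = charges[:, None] * charges[None, :] * kernel_fn(r)
        return 0.5 * jnp.sum(jnp.where(off_diag, pair, 0.0))

    return eval_pair_pot
```

With a two-level split k = k_0 + k_1, adding to U^0 a level-1 energy evaluated over all pairs, including i = j, recovers the full Coulomb pair energy, since the self-interaction terms cancel.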