"""Implementation of the Lennard Jones potential and related functions."""
import argparse
import pathlib
from typing import Callable, Optional, Tuple
import jax
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
from clinamen2.utils.script_functions import cma_parser
# Sampling-mode identifiers (presumably: start a fresh sampling run vs.
# continue a previous one -- confirm against the calling scripts).
FRESH_SAMPLING, CONTINUE_SAMPLING = "fresh", "continue"
[docs]
def get_max_span_from_lj_cluster(
    lj_cluster: npt.ArrayLike, verbose=False
) -> float:
    """Evaluate LJ cluster positions and return the largest span.

    The span along an axis is the distance between the smallest and the
    largest coordinate of all spheres along that axis.

    Args:
        lj_cluster: Positions of the LJ spheres, either flat
            (x1, y1, z1, x2, ...) or already shaped (n_atoms, 3).
        verbose: If True, print the span along each axis.

    Returns:
        The largest of the x, y and z spans.
    """
    # Array and shape are passed positionally: the ``newshape`` keyword is
    # deprecated since NumPy 2.1.
    positions = np.reshape(lj_cluster, (-1, 3))
    # Per-axis span; equal to max(|min - x|) because every x >= min.
    spans = positions.max(axis=0) - positions.min(axis=0)
    if verbose:
        for axis_name, span in zip("xyz", spans):
            print(f"max {axis_name} span = {span}")
    return np.max(spans)
[docs]
class PositionException(Exception):
    """Signals that a position crossed the allowed position threshold."""
[docs]
def create_position_filter(
    position_bounds: npt.ArrayLike,
    exception: Callable[..., BaseException] = BaseException,
) -> Tuple[Callable, Callable]:
    """Create functions that flag configurations leaving a bounding box.

    Args:
        position_bounds: Position components that may not be crossed.
            Shape is (3, 2): 3 components, lower and upper bound for each.
            A flat sequence ``x1 x2 y1 y2 z1 z2`` is accepted as well.
        exception: Exception class identifying an issue with the
            calculation. It is instantiated by ``batch_filter`` only.

    Returns:
        tuple
            - single_filter(loss, additional, inputs) for one configuration
            - batch_filter(loss, additional, inputs) for a batch
        Both return ``(loss, additional, inputs)`` unchanged apart from
        setting ``"exception"`` in the ``additional`` dict of every
        configuration with a coordinate outside the bounds.
    """
    # Coerce to an array so lists (e.g. from the command line) can be
    # indexed per component; reshape also accepts the flat 6-value form.
    bounds = np.reshape(np.asarray(position_bounds), (3, 2))
    lower = bounds[:, 0]
    upper = bounds[:, 1]

    def _out_of_bounds(configuration) -> bool:
        """Return True if any coordinate lies outside its component bounds."""
        positions = np.reshape(configuration, (-1, 3))
        # (n_atoms, 3) compared against (3,) broadcasts per component.
        return bool((positions < lower).any() or (positions > upper).any())

    def batch_filter(
        loss: float, additional: list, inputs: Tuple
    ) -> Tuple[float, list, Tuple]:
        for i, inp in enumerate(inputs):
            if _out_of_bounds(inp):
                additional[i]["exception"] = exception("Out of bounds.")
        return loss, additional, inputs

    def single_filter(
        loss: float, additional: dict, inputs: Tuple
    ) -> Tuple[float, dict, Tuple]:
        if _out_of_bounds(inputs):
            # NOTE(review): unlike batch_filter this always records a
            # PositionException and ignores ``exception``; kept unchanged
            # because callers may rely on either exception type.
            additional["exception"] = PositionException("Out of bounds.")
        return loss, additional, inputs

    return single_filter, batch_filter
[docs]
def create_evaluate_lj_potential(
    n_atoms: int = 38,
    identifier: Optional[str] = None,
    wales_path: pathlib.Path = None,
    n_eval_batch=100,
) -> Tuple[np.ndarray, Callable, Callable]:
    """Create an LJ evaluation function from a Cambridge database entry.

    Load the coordinates of the ground state of an n-atom LJ cluster
    from an entry in
    http://doye.chem.ox.ac.uk/jon/structures/LJ.html
    and return them together with a single and a batched evaluation
    function. The energy is sum_{i<j} (r_ij^-12 - 2 r_ij^-6), whose pair
    minimum lies at r = 1.

    Args:
        n_atoms: Number of atoms in the cluster (int or str); used to build
            the file name of the database entry.
        identifier: Additional identifier of a specific configuration.
            For example "i" for "38i". Default is None.
        wales_path: Directory containing the Wales potential data, as a
            pathlib.Path or str.
        n_eval_batch: Maximum number of configurations evaluated per
            vmapped call. Must be a positive integer.

    Returns:
        tuple
            - Coordinates of the specified cluster, shape as read by
              ``np.loadtxt`` and rescaled to the units used by the
              evaluation functions.
            - evaluate_lj_potential(positions) -> (energy, {})
            - vmapped_eval_wrapper(positions) -> (losses, additional)

    Raises:
        ValueError: If there is a problem with an argument or the
            requested coordinate file does not exist.

    References:
        [1] The Cambridge Cluster Database, D. J. Wales, J. P. K. Doye,
        A. Dullweber, M. P. Hodges, F. Y. Naumkin F. Calvo, J. Hernández-Rojas
        and T. F. Middleton, URL http://www-wales.ch.cam.ac.uk/CCD.html.
    """
    if not isinstance(n_atoms, (int, str)):
        raise ValueError("n must be an integer or a string")
    if wales_path is None:
        raise ValueError("wales_path must point to the Wales potential data.")
    if n_eval_batch < 1:
        # A non-positive batch size would never advance the batching loop.
        raise ValueError("n_eval_batch must be a positive integer.")

    cluster = str(n_atoms) if identifier is None else str(n_atoms) + identifier
    # Accept str as well as pathlib.Path (e.g. a command-line string).
    wales_path = pathlib.Path(wales_path)
    filename = wales_path / cluster
    if not filename.exists():
        raise ValueError(
            f"Coordinates for {n_atoms} atoms with identifier"
            f" {identifier} not found in {wales_path}."
        )
    coordinates = np.loadtxt(filename)
    # These coordinates use the sigma = 1 convention -> rescale so that the
    # pair minimum (at 2^(1/6) sigma) lies at distance 1.
    coordinates /= 2 ** (1.0 / 6.0)

    @jax.jit
    def evaluate_lj_potential(positions) -> Tuple:
        """Calculate the LJ energy of one configuration.

        Args:
            positions: Flat coordinate vector (reshaped to (-1, 3)).

        Returns:
            tuple
                - energy: The energy of the configuration.
                - An empty dict of additional information.
        """
        # Compute all distances between pairs without iterating.
        positions = positions.reshape((-1, 3))
        delta = positions[:, jnp.newaxis, :] - positions
        r2 = (delta * delta).sum(axis=2)
        # Take only the upper triangle (combinations of two atoms).
        indices = jnp.triu_indices(r2.shape[0], k=1)
        rm2 = 1.0 / r2[indices]
        # Compute the potential energy recycling some calculations.
        rm6 = rm2 * rm2 * rm2
        return (rm6 * (rm6 - 2.0)).sum(), {}

    vmapped_eval_batch = jax.vmap(evaluate_lj_potential, in_axes=0)

    def vmapped_eval_wrapper(positions) -> Tuple:
        """Evaluate many configurations in chunks of ``n_eval_batch``.

        Args:
            positions: LJ sphere positions, one configuration per row.

        Returns:
            tuple containing
                - Loss of every configuration.
                - Additional information: one independent empty dict per
                  configuration.
        """
        n_configs = positions.shape[0]
        losses = []
        for start in range(0, n_configs, n_eval_batch):
            batch_losses, _ = vmapped_eval_batch(
                positions[start : start + n_eval_batch, ...]
            )
            losses.extend(batch_losses.tolist())
        loss = jnp.asarray(losses)
        # Build distinct dicts: ``[{}] * n`` would alias a single dict, so
        # recording e.g. an exception for one configuration would mark all.
        additional = [{} for _ in range(n_configs)]
        return loss, additional

    return coordinates, evaluate_lj_potential, vmapped_eval_wrapper
[docs]
def lj_argparse():
    """Argument parser for LJ evolution scripts.

    Extends the parser returned by ``cma_parser`` with LJ-specific
    options and parses the command line.

    Returns:
        Parsed known arguments; unrecognized command-line arguments are
        silently ignored.
    """
    parser = cma_parser()
    parser.add_argument(
        "-a",
        "--atom_count",
        type=int,
        default=38,
        help="Number of atoms in the cluster.",
    )
    parser.add_argument(
        "-i",
        "--identifier",
        type=str,
        default=None,
        help="Identifier for cluster configuration, e.g. 'i'.",
    )
    parser.add_argument(
        "-c",
        "--configuration",
        type=str,
        default="cube",
        help="Shape of initial configuration (cube, sphere, packmol).",
    )
    parser.add_argument(
        "-b",
        "--bounds",
        type=float,
        # Exactly one lower and one upper bound per Cartesian component.
        nargs=6,
        help="Positions bounds as x1 x2 y1 y2 z1 z2",
        default=None,
    )
    parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Do not write generation data to disk. Only end result.",
    )
    parser.add_argument(
        "-j",
        "--json_output",
        type=str,
        help="JSON file to store results in.",
        default=None,
    )
    parser.add_argument(
        "-w",
        "--wales_path",
        type=str,
        help="Path to Wales potential data.",
        default=None,
    )
    parser.add_argument(
        "-e",
        "--packmol_executable",
        type=str,
        help="Packmol executable with full path.",
        default=None,
    )
    parser.add_argument(
        "--packmol_tolerance",
        type=float,
        help="Packmol tolerance parameter.",
        default=1.0,
    )
    parser.add_argument(
        "--packmol_side_length",
        type=float,
        help="Packmol side_length parameter.",
        default=3.0,
    )
    parser.add_argument(
        "--packmol_seed",
        # Packmol seeds are integers; parsing as float turned e.g. "42"
        # into 42.0, inconsistent with the integer default.
        type=int,
        help="Random seed for packmol. Default is -1",
        default=-1,
    )
    parser.add_argument(
        "--continue_evolution",
        action="store_true",
        help="Set this flag to continue a stopped evolution.",
    )
    parser.add_argument(
        "--generation_checkpoint",
        type=int,
        default=None,
        help="Generation checkpoint to continue from. Use 'last_gen' if None.",
    )
    args, _ = parser.parse_known_args()
    return args