"""Read and write CMA-ES results."""
import json
import pathlib
from collections import deque
from dataclasses import asdict
from typing import NamedTuple, Tuple
import dill
import numpy as np
import numpy.typing as npt
from clinamen2.cmaes.params_and_state import (
AlgorithmParameters,
AlgorithmState,
)
from clinamen2.cmaes.termination_criterion import Criterion
class JSONEncoder(json.JSONEncoder):
"""Class that extends JSONEncoder to handle different data types."""
def default(self, o):
"""Return a json-izable version of o or delegate on the base class."""
if isinstance(o, np.generic):
# Deal with non-serializable types such as numpy.int64
return o.item()
elif isinstance(o, np.ndarray):
nruter = {
"main_type": "NumPy/" + o.dtype.name,
"data": o.tolist(),
}
return nruter
elif isinstance(o, deque):
nruter = {
"main_type": "deque/" + str(o.maxlen),
"data": list(o),
}
return nruter
return json.JSONEncoder.default(self, o)
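# Illustrative use of JSONEncoder (a sketch; the payload is a placeholder):
#
#     payload = {"mean": np.arange(3), "history": deque([1.0, 2.0], maxlen=10)}
#     text = json.dumps(payload, cls=JSONEncoder)
#
# NumPy arrays are written as {"main_type": "NumPy/<dtype>", "data": [...]}
# and deques as {"main_type": "deque/<maxlen>", "data": [...]}.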
class JSONDecoder(json.JSONDecoder):
"""Class that extends the JSONDecoder to handle different data types."""
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(
self, object_hook=self.object_hook, *args, **kwargs
)
def object_hook(self, obj):
"""Reencode numpy arrays from dictionary."""
try:
main_type, *extra = obj["main_type"].split("/")
if main_type == "NumPy":
return np.asarray(obj["data"], dtype=extra[0])
elif main_type == "deque":
# A maxlen of None is encoded as the string "None".
maxlen = None if extra[0] == "None" else int(extra[0])
return deque(obj["data"], maxlen=maxlen)
except (KeyError, ValueError):
return obj
# Fall through for dictionaries with an unrecognized "main_type".
return obj
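# Illustrative round trip through JSONEncoder and JSONDecoder (a sketch):
#
#     text = json.dumps({"mean": np.arange(3)}, cls=JSONEncoder)
#     restored = json.loads(text, cls=JSONDecoder)
#     # restored["mean"] is again a numpy.ndarray with the original dtype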
class CMAFileHandler:
"""Handle input and output files.
Args:
label: Label of the evolution run.
target_dir: Full path to target directory. Default is the current
working directory.
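Example:
Illustrative; the label and directory are placeholders.
handler = CMAFileHandler(
label="demo", target_dir=pathlib.Path("results")
)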
"""
def __init__(
self,
label: str = None,
target_dir: pathlib.Path = pathlib.Path.cwd(),
):
"""Constructor
Returns:
Name of generated file.
"""
self.label = label
self.target_dir = target_dir
def get_evolution_filename(
self,
label: str = None,
) -> pathlib.Path:
"""Return a pathlib.Path object representing the evolution file.
Args:
label: Label of the evolution run.
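Example:
Illustrative; the label is a placeholder. The returned path has no
file extension; the save methods append '.json' or '.dill'.
CMAFileHandler(label="demo").get_evolution_filename()
# -> <target_dir>/evolution_demo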
"""
filename = "evolution"
if label is not None:
filename = filename + "_" + label
elif self.label is not None:
filename = filename + "_" + self.label
filename = self.target_dir / filename
return filename
def get_generation_filename(
self,
generation: int,
label: str = None,
) -> pathlib.Path:
"""Return a pathlib.Path object representing the generation file.
Args:
generation: Generation index.
label: Label of the evolution run.
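Example:
Illustrative; label and index are placeholders. The returned path has
no file extension; the save methods append '.json' or '.dill'.
CMAFileHandler(label="demo").get_generation_filename(3)
# -> <target_dir>/generation_demo_3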
"""
filename = "generation"
if label is not None:
filename = filename + "_" + label + "_" + str(generation)
elif self.label is not None:
filename = filename + "_" + self.label + "_" + str(generation)
else:
filename = filename + "_" + str(generation)
filename = self.target_dir / filename
return filename
def save_evolution(
self,
initial_parameters: AlgorithmParameters,
initial_state: AlgorithmState,
termination_criterion: Criterion = None,
label: str = None,
additional: dict = None,
file_format: str = "json",
json_encoder: json.JSONEncoder = JSONEncoder,
) -> pathlib.Path:
"""Function that writes the initial evolution to file.
Accepts any data that is serializable using dill. It might be
necessary to use a different JSONEncoder. If 'additional' contains
dataclasses, use `dataclasses.asdict()` and take care to manually
re-cast after loading.
Serialized data contains:
| 'initial_parameters': AlgorithmParameters
| 'initial_state': AlgorithmState
| 'termination_criterion': Criterion (optional)
| any additional compatible information (arbitrary objects require file_format='dill')
Filenames:
| `evolution.json`
| `evolution_<label>.json`
Args:
initial_parameters: The initial parameters to start the evolution.
initial_state: The initial state to start the evolution.
termination_criterion: The termination criterion set up for the
evolution. Can be a combined criterion.
label: String to be added to filename.
additional: Additional information to be saved with the
initial evolution properties.
file_format: Indicate the file format to be used for serialization.
The options are 'json' and 'dill'. Default is 'json'.
json_encoder: Encoder to use if the default file_handling.JSONEncoder
does not offer the required functionality, e.g., for JAX data types.
Returns:
Path of the generated file (without the format-specific extension).
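Example:
Illustrative sketch; `params`, `state`, and `criterion` stand for
AlgorithmParameters, AlgorithmState, and Criterion objects created
elsewhere.
handler = CMAFileHandler(label="demo")
handler.save_evolution(
initial_parameters=params,
initial_state=state,
termination_criterion=criterion,
additional={"seed": 42},
)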
"""
filename = self.get_evolution_filename(label)
if additional is None:
additional = {}
if not isinstance(additional, dict):
raise TypeError(
"Parameter 'additional' must be a dictionary or None."
)
if file_format == "json":
if initial_parameters is not None:
initial_parameters = asdict(initial_parameters)
if initial_state is not None:
initial_state = asdict(initial_state)
data = {
"initial_parameters": initial_parameters,
"initial_state": initial_state,
**additional,
}
if termination_criterion is not None:
data["termination_criterion"] = termination_criterion
if file_format == "dill":
with open(str(filename) + ".dill", "wb") as f:
dill.dump(data, file=f, protocol=5)
elif file_format == "json":
with open(
str(filename) + ".json", "w", encoding="UTF-8"
) as json_file:
json.dump(data, json_file, cls=json_encoder)
else:
raise NotImplementedError(
f"Format {file_format} has not been implemented."
)
return filename
def load_evolution(
self,
label: str = None,
file_format: str = "json",
) -> Tuple[AlgorithmParameters, AlgorithmState, Criterion, dict]:
"""Function that loads an evolution from file.
Objects loaded:
| 'initial_parameters': AlgorithmParameters
| 'initial_state': AlgorithmState
| 'termination_criterion': Criterion
| any additional compatible information
Args:
label: String to be added to filename.
file_format: Indicate the file format used for serialization.
The options are 'json' and 'dill'. Default is 'json'.
Returns:
tuple containing
- AlgorithmParameters object
- AlgorithmState object
- Criterion object
- Dictionary containing additional objects if present.
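Example:
Illustrative; assumes an evolution file previously written by
save_evolution() with the same label and file format.
parameters, state, criterion, extra = CMAFileHandler(
label="demo"
).load_evolution()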
"""
filename = self.get_evolution_filename(label)
if file_format == "dill":
with open(str(filename) + ".dill", "rb") as f:
loaded = dill.load(f)
elif file_format == "json":
with open(
str(filename) + ".json", "r", encoding="UTF-8"
) as json_file:
loaded = json.load(json_file, cls=JSONDecoder)
else:
raise ValueError(f"Format {file_format} not understood.")
initial_parameters = loaded.pop("initial_parameters")
initial_state = loaded.pop("initial_state")
try:
termination_criterion = loaded.pop("termination_criterion")
except KeyError:
termination_criterion = None
if file_format == "json":
initial_parameters = AlgorithmParameters(**initial_parameters)
initial_state = AlgorithmState(**initial_state)
return initial_parameters, initial_state, termination_criterion, loaded
def save_generation(
self,
current_state: AlgorithmState,
population: npt.ArrayLike,
loss: npt.ArrayLike,
termination_state: NamedTuple = None,
label: str = None,
additional: dict = None,
file_format: str = "json",
json_encoder: json.JSONEncoder = JSONEncoder,
) -> pathlib.Path:
"""Function that writes a generation to file.
Accepts any data that is serializable using dill.
Serialized data contains:
| 'current_state': AlgorithmState
| 'population': numpy.typing.ArrayLike
| 'loss': numpy.typing.ArrayLike
| 'termination_state': NamedTuple
| any additional compatible information in a dictionary
Filenames:
| `generation_<number>.json`
| `generation_<label>_<number>.json`
Args:
current_state: The current state of the evolution.
population: The population of individuals of the generation.
loss: Loss of each individual within the population.
termination_state: State of termination criterion.
label: String to be added to filename.
additional: Additional information to be saved with the
generation data.
file_format: Indicate the file format to be used for serialization.
The options are 'json' and 'dill'. Default is 'json'.
json_encoder: Encoder to use if the default file_handling.JSONEncoder
does not offer the required functionality, e.g., for JAX data types.
Returns:
Path of the generated file (without the format-specific extension).
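Example:
Illustrative sketch; `state`, `population`, and `loss` stand for
objects produced by the evolution loop.
handler = CMAFileHandler(label="demo")
handler.save_generation(
current_state=state,
population=population,
loss=loss,
)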
"""
if isinstance(current_state, AlgorithmState):
generation = current_state.generation
elif isinstance(current_state, dict):
generation = current_state["generation"]
else:
generation = 0
filename = self.get_generation_filename(generation, label=label)
if file_format == "json":
if current_state is not None:
current_state = asdict(current_state)
if termination_state is not None:
termination_state = termination_state._asdict()
if additional is None:
additional = {}
if not isinstance(additional, dict):
raise TypeError(
"Parameter 'additional' must be a dictionary or None."
)
data = {
"current_state": current_state,
"population": population,
"loss": loss,
"termination_state": termination_state,
**additional,
}
if file_format == "dill":
with open(f"{filename}.dill", "wb") as f:
dill.dump(data, file=f, protocol=5)
elif file_format == "json":
with open(f"{filename}.json", "w", encoding="UTF-8") as json_file:
json.dump(data, json_file, cls=json_encoder)
else:
raise NotImplementedError(
f"Format {file_format} has not been implemented."
)
return filename
def load_generation(
self,
generation: int,
label: str = None,
allow_legacy: bool = True,
file_format: str = "json",
) -> Tuple[AlgorithmState, npt.ArrayLike, npt.ArrayLike, NamedTuple, dict]:
"""Function that loads a generation from file.
Args:
generation: ID of generation to be loaded.
label: String to be added to filename.
allow_legacy: If True, older results containing 'fitness'
instead of 'loss' can be loaded. Default is True.
file_format: Indicate the file format used for serialization.
The options are 'json' and 'dill'. Default is 'json'.
Returns:
tuple containing
- AlgorithmState object
- Array of population
- Array of loss values
- NamedTuple of termination state
- Dictionary containing additional objects if present.
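Example:
Illustrative; assumes generation 5 of the run labeled "demo" was
previously saved in JSON format.
state, population, loss, term_state, extra = CMAFileHandler(
label="demo"
).load_generation(generation=5)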
"""
filename = self.get_generation_filename(
generation=generation, label=label
)
if file_format == "dill":
with open(str(filename) + ".dill", "rb") as f:
loaded = dill.load(f)
elif file_format == "json":
with open(
str(filename) + ".json", "r", encoding="UTF-8"
) as json_file:
loaded = json.load(json_file, cls=JSONDecoder)
else:
raise NotImplementedError(
f"Format {file_format} has not been implemented."
)
current_state = loaded.pop("current_state")
population = loaded.pop("population")
if file_format == "json":
if current_state is not None:
current_state = AlgorithmState(**current_state)
try:
termination_state = loaded.pop("termination_state")
except KeyError:
termination_state = None
try:
loss = loaded.pop("loss")
except KeyError as exc: # to be able to handle older results
if allow_legacy:
loss = loaded.pop("fitness")
else:
raise KeyError(
"Legacy data may not be loaded. Check settings."
) from exc
return current_state, population, loss, termination_state, loaded
def update_evolution(
self,
label: str = None,
additional: dict = None,
) -> pathlib.Path:
"""Function that calls save_evolution(file_format='json').
Accepts all data that is serializable using json.
Values for keys that are already present in the existing additional
information will be replaced.
Filenames:
| `evolution.json`
| `evolution_<label>.json`
Args:
label: Label of the evolution run.
additional: Additional information to be saved with the
evolution properties.
Returns:
Path of the evolution file (without extension).
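Example:
Illustrative; merges the given keys into the additional information
already stored in the JSON evolution file.
handler = CMAFileHandler(label="demo")
handler.update_evolution(additional={"note": "restarted"})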
"""
try:
evolution = self.load_evolution(label=label)
except FileNotFoundError:
print("Evolution update only implemented for json format.")
raise
existing_additional = evolution[3] if evolution[3] is not None else {}
updated_additional = existing_additional.copy()
updated_additional.update(additional or {})
filename = self.save_evolution(
initial_parameters=evolution[0],
initial_state=evolution[1],
termination_criterion=evolution[2],
label=label,
additional=updated_additional,
)
return filename