# Source code for clinamen2.utils.jax_data

"""JSONEncoder that handles JAX arrays.

    Moved out of file_handling so that module does not depend on jax.numpy.
"""
import json
from collections import deque

import jax.numpy as jnp
import numpy as np

# If a use case arises, add a decoder that restores the JAX array datatype.


class JSONEncoderwithJNP(json.JSONEncoder):
    """Class that extends JSONEncoder to handle different data types.

    Separate version that encodes JNP arrays to numpy: JAX arrays are
    serialized exactly like NumPy arrays (same ``"NumPy/<dtype>"`` tag).
    """

    def default(self, o):
        """Return a json-izable version of o or delegate on the base class.

        Args:
            o: Object the standard JSON encoder could not serialize.

        Returns:
            - For NumPy scalars (``np.generic``): the equivalent Python
              scalar obtained via ``.item()``.
            - For NumPy arrays and JAX arrays: a dict with ``"main_type"``
              set to ``"NumPy/<dtype name>"`` and ``"data"`` holding the
              array contents as (nested) Python lists.
            - For ``collections.deque``: a dict with ``"main_type"`` set to
              ``"deque/<maxlen>"`` (``"deque/None"`` when unbounded) and
              ``"data"`` holding the elements as a list.

        Raises:
            TypeError: Raised by the base class for any other type.
        """
        if isinstance(o, np.generic):
            # Deal with non-serializable types such as numpy.int64
            return o.item()
        # ``jnp.ndarray`` exists in both old and new JAX releases (it is an
        # alias of ``jax.Array`` since the removal of ``jnp.DeviceArray``),
        # so this check works regardless of the installed JAX version.
        if isinstance(o, (np.ndarray, jnp.ndarray)):
            # Convert to NumPy first so JAX arrays are encoded exactly like
            # NumPy arrays; this is a no-copy pass-through for np.ndarray.
            array = np.asarray(o)
            return {
                "main_type": "NumPy/" + array.dtype.name,
                "data": array.tolist(),
            }
        if isinstance(o, deque):
            return {
                "main_type": "deque/" + str(o.maxlen),
                "data": list(o),
            }
        return super().default(o)