"""Plot functions for CMA-ES results.
"""
import pathlib
from numbers import Number
from typing import Tuple
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from clinamen2.utils.file_handling import CMAFileHandler
[docs]
class CMAPlotHelper:
"""Setup figure parameters and plot common results.
Args:
label: Label identifying the evolution and its results.
input_dir: Directory to read data from.
handler: Object to access the stored evolution and generations.
generation_bounds: First and last generation to include. Count starts
at 1, because generation 0 is the founder.
If generation_bounds[1] is passed in with value -1, the last
generation is read from the evolution element "last_gen" if
present.
losses: Loss of all read individuals.
generations: Indices of all read generations.
step_sizes: Step size of AlgorithmState.
output_dir: Directory to save files to. Will be created if it does not
exist. Default is None.
information_list: List of dictionaries with additional information per
generation.
additional_info_types: Contains keys identifying additional information
and the corresponding types as values for all additional
information that is of type numpy.ndarray, list of Numbers or
Number.
figsize: Default figure size to be used when no alternative is passed
to a plotting function.
colors: Dictionary of colors to be used in the plots.
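fontsizes: Dictionary of fontsizes to be used in the plots.
json: If True, the evolution and generations are loaded with file format
"json", otherwise with "dill". Default is True.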
"""
def __init__(
    self,
    label: str,
    input_dir: pathlib.Path,
    generation_bounds: Tuple[int, int],
    output_dir: pathlib.Path = None,
    figsize: Tuple[float, float] = (10, 8),
    colors: dict = None,
    fontsizes: dict = None,
    json=True,
):
    """Load the stored evolution and generations and set plot defaults.

    Args:
        label: Label identifying the evolution and its results.
        input_dir: Directory to read data from.
        generation_bounds: First and last generation to include. If the
            second entry is -1 and the loaded evolution provides a
            positive "last_gen", that value is used as last generation.
        output_dir: Directory to save files to. Will be created if it
            does not exist. If None, input_dir is used. Default is None.
        figsize: Default figure size to be used when no alternative is
            passed to a plotting function.
        colors: Dictionary of colors to be used in the plots. If None,
            the default palette defined in this function is used.
        fontsizes: Dictionary of fontsizes to be used in the plots. If
            None, the default sizes defined in this function are used.
        json: If True, data is loaded with file format "json", otherwise
            with "dill". Default is True.
    """
    # The defaults are created per call. A dict literal in the signature
    # would be one shared object stored on every instance, so changing
    # the colors of one helper would silently change all others.
    if colors is None:
        colors = {
            "dark_line": "#440154",
            "medium_line": "#31688e",
            "light_line": "#6ece58",
            "dark_area": "#3e4989",
            "medium_area": "#1f9e89",
            "light_area": "#b5de2b",
            "highlight": "#fde725",
        }
    if fontsizes is None:
        fontsizes = {
            "title": 36,
            "ax_label": 28,
            "tick_param": 24,
            "legend": 24,
        }
    self.json = json
    self.colors = colors
    self.fontsizes = fontsizes
    self.label = label
    self.handler = CMAFileHandler(label=label, target_dir=input_dir)
    file_format = "json" if json else "dill"
    self.evolution = self.handler.load_evolution(file_format=file_format)
    # Element 3 of the loaded evolution is queried for "last_gen"; a
    # missing key is treated as "last generation unknown" (0).
    try:
        self.last_gen = self.evolution[3]["last_gen"]
    except KeyError:
        self.last_gen = 0
    # Resolve the -1 placeholder only if a last generation is known.
    if generation_bounds[1] == -1 and self.last_gen > 0:
        generation_bounds = (generation_bounds[0], self.last_gen)
    self.generation_bounds = generation_bounds
    (
        self.losses,
        self.generations,
        self.step_sizes,
        self.information_list,
    ) = self.get_data_from_generations()
    self.figsize = figsize
    self.output_dir = output_dir if output_dir is not None else input_dir
    if output_dir is not None:
        output_dir.mkdir(parents=True, exist_ok=True)
    self.additional_info_types = self.classify_additional_information()
[docs]
def get_data_from_generations(self):
    """Read all stored generations within the configured bounds.

    Generations for which the handler raises FileNotFoundError are
    skipped, because not every generation must have been saved.

    Returns:
        Tuple of losses, generation indices and step sizes (each passed
        through np.asarray) and the list of information objects, with
        one entry per generation that could be loaded.
    """
    fmt = "json" if self.json else "dill"
    first_gen, last_gen = (
        self.generation_bounds[0],
        self.generation_bounds[1],
    )

    # One (loss, generation, step size, information) record per loaded
    # generation, kept in loading order.
    records = []
    for gen in range(first_gen, last_gen + 1):
        try:
            state, _, gen_loss, _, gen_info = self.handler.load_generation(
                generation=gen, label=self.label, file_format=fmt
            )
        except FileNotFoundError:
            continue
        records.append((gen_loss, gen, state.step_size, gen_info))

    return (
        np.asarray([record[0] for record in records]),
        np.asarray([record[1] for record in records]),
        np.asarray([record[2] for record in records]),
        [record[3] for record in records],  # list of dictionaries
    )
[docs]
def plot_stepsize_ax2(
    self,
    ax: matplotlib.axes.Axes,
    x: npt.ArrayLike,
    selected_generations: npt.ArrayLike,
    y_units: str = r"",  # e.g. r" / \si{\angstrom}"
) -> list:
    r"""Plot the step size on a twin y-axis of the given axes.

    Args:
        ax: Main axes of the figure. A twin sharing its x-axis is created.
        x: Values for the x-axis.
        selected_generations: Index into the stored step sizes that
            selects the entries to be plotted.
        y_units: Units of the step size to be appended to the y label,
            e.g. ' / \si{\angstrom}'. Default is empty string.

    Returns:
        The line objects returned by the plot call on the twin axes, so
        that the caller can add them to a common legend.
    """
    line_color = self.colors["medium_line"]
    ax2 = ax.twinx()
    ax2.set_ylabel(
        r"$\sigma$" + y_units,
        fontsize=self.fontsizes["ax_label"],
        color=line_color,
    )
    # Ticks and tick labels of the twin axis share the color of the curve.
    ax2.tick_params(
        axis="y",
        labelsize=self.fontsizes["tick_param"],
        colors=line_color,
    )
    ln = ax2.plot(
        x,
        self.step_sizes[selected_generations],
        color=line_color,
        linestyle="dashed",
        label=r"$\sigma$",
    )
    return ln
[docs]
def plot_mean_loss_per_generation(
    self,
    generation_bounds: Tuple[int, int] = None,
    figsize: Tuple[float, float] = None,
    show_sigma_e: bool = True,
    sigma_e_mult: int = 3,
    show_min_e: bool = False,
    show_sigma: bool = False,
    show_legend: bool = True,
    y_units: str = "",
    y_lim: Tuple[float, float] = None,
    ref_val: float = None,
    save_fig: bool = False,
    fig_type: str = "pdf",
) -> Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
    """Plot mean loss per generation.

    Args:
        generation_bounds: First and last generation to include. If None,
            the bounds of the helper are used. Default is None.
        figsize: Figure size. If None, the default figure size of the
            helper is used. Default is None.
        show_sigma_e: If True, sigma_e_mult * std deviation is plotted
            additionally around the mean. Default is True.
        sigma_e_mult: If sigma_e is shown, this defines the width.
        show_min_e: If True, the minimum loss within each generation will
            be plotted. Default is False.
        show_sigma: If True, plot step size on axis y2. Default is False.
        show_legend: If True, plot a legend. Default is True.
        y_units: Units of loss to be added to y_label. Default is empty
            string.
        y_lim: Tuple of y_min and y_max to restrict the plot.
            Default is None.
        ref_val: Value to be plotted as a reference line. Default is None.
        save_fig: If True the figure will be saved to file. Default is
            False.
        fig_type: File type to be saved. Default is 'pdf'.

    Returns:
        The created figure and its main axes.
    """
    if generation_bounds is None:
        generation_bounds = self.generation_bounds
    if figsize is None:
        figsize = self.figsize

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    ax.xaxis.set_tick_params(labelsize=self.fontsizes["ax_label"])
    ax.set_xlabel(r"Generation", fontsize=self.fontsizes["ax_label"])
    ax.yaxis.set_tick_params(labelsize=self.fontsizes["ax_label"])
    ax.set_ylabel(r"Loss" + y_units, fontsize=self.fontsizes["ax_label"])

    # Stored generations inside the requested bounds: the mask selects the
    # generation numbers (x values), the index selects the matching rows.
    in_bounds = np.greater_equal(
        self.generations, generation_bounds[0]
    ) & np.less_equal(self.generations, generation_bounds[1])
    generation_index = np.argwhere(in_bounds).flatten()
    x = np.asarray(self.generations[in_bounds])

    losses = self.losses[generation_index]
    loss_mean = losses.mean(axis=1)

    # lns collects the artists that get an entry in the legend.
    lns = ax.plot(
        x, loss_mean, color=self.colors["dark_line"], label="mean"
    )
    if show_sigma_e:
        spread = sigma_e_mult * losses.std(axis=1)
        upper = loss_mean + spread
        lower = loss_mean - spread
        ax.fill_between(
            x,
            loss_mean,
            upper,
            color=self.colors["highlight"],
            alpha=0.75,
        )
        lns += ax.plot(
            x,
            upper,
            color=self.colors["light_line"],
            label=f"{sigma_e_mult}" + r"$\sigma_{E}$",
        )
        ax.fill_between(
            x,
            loss_mean,
            lower,
            color=self.colors["highlight"],
            alpha=0.75,
        )
        # The lower bound is drawn without label: one legend entry is
        # enough for the band.
        ax.plot(
            x,
            lower,
            color=self.colors["light_line"],
        )
    if show_min_e:
        lns += [
            ax.scatter(
                x,
                losses.min(axis=1),
                marker="x",
                color=self.colors["dark_line"],
                label="min",
            )
        ]
    if show_sigma:
        lns += self.plot_stepsize_ax2(
            ax=ax, x=x, selected_generations=generation_index
        )
    if y_lim is not None:
        ax.set_ylim(y_lim)
    if ref_val is not None:
        lns += [
            ax.axhline(
                y=ref_val,
                linestyle="dotted",
                color="black",
                alpha=0.85,
                label="ref",
            )
        ]
    if show_legend:
        labs = [handle.get_label() for handle in lns]
        ax.legend(lns, labs, fontsize=self.fontsizes["legend"])
    # Title and saving act on the objects created here instead of on the
    # pyplot "current" axes/figure, which is the twin axis after the step
    # size has been plotted.
    ax.set_title(
        f"Mean loss per generation ({self.label})",
        fontsize=self.fontsizes["title"],
    )
    if save_fig:
        fig.savefig(
            self.output_dir / (self.label + "_mean_loss." + fig_type),
            format=fig_type,
        )
    return fig, ax
[docs]
def plot_loss_boxplots(
    self,
    generations: list,
    figsize: Tuple[float, float] = None,
    y_units: str = "",
    color_by_key: bool = False,
    color_key: str = "",
    save_fig: bool = False,
    fig_type: str = "pdf",
) -> None:
    """Plot the loss of selected generations as boxplots with all points.

    Args:
        generations: Generations to be plotted. More than five will
            not produce a satisfactory result.
        figsize: Figure size. If None, the default figure size of the
            helper is used. Default is None.
        y_units: Units of loss to be added to y_label. Default is empty
            string.
        color_by_key: If True, the scatter plot will be colored according
            to the additional information identified by color_key.
            Default is False.
        color_key: Key identifying the additional information used for
            coloring. Only read if color_by_key is True.
        save_fig: If True the figure will be saved to file. Default is
            False.
        fig_type: File type to be saved. Default is 'pdf'.

    Raises:
        ValueError: If a generation is not found in the data or if the
            information identified by color_key is not of type Number.
        KeyError: If color_key is not present in the additional
            information.
    """
    if figsize is None:
        figsize = self.figsize

    stored_generations = np.asarray(self.generations)
    # check if all generations are available in the data
    for g in generations:
        if g not in stored_generations:
            raise ValueError(f"Generation {g} not found.")

    # Validate the coloring request before any figure is created.
    if color_by_key:
        try:
            info_type = self.additional_info_types[color_key]
        except KeyError as exc:
            raise KeyError(
                f"Key {color_key} not present in additional information."
            ) from exc
        if info_type != Number:
            raise ValueError(
                f"Additional information {color_key} is not of type "
                "Number and cannot be used for coloring."
            )

    # Position of each requested generation within the stored data.
    positions = [
        int(np.argwhere(stored_generations == g).flatten()[0])
        for g in generations
    ]
    data = [np.asarray(self.losses[p]).flatten() for p in positions]

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    ax.xaxis.set_tick_params(labelsize=self.fontsizes["ax_label"])
    ax.set_xlabel(r"Generation", fontsize=self.fontsizes["ax_label"])
    ax.yaxis.set_tick_params(labelsize=self.fontsizes["ax_label"])
    ax.set_ylabel(r"Loss" + y_units, fontsize=self.fontsizes["ax_label"])

    _ = ax.boxplot(
        data,
        patch_artist=True,
        boxprops=dict(facecolor=self.colors["light_area"], alpha=0.3),
        zorder=-1,
    )
    # Boxes are drawn at x = 1..N; label these positions with the
    # generation numbers once the ticks are fixed.
    ax.set_xticks(list(range(1, len(data) + 1)))
    ax.set_xticklabels([str(g) for g in generations])

    if color_by_key:
        # One flat array per generation. A (1, n) array would be read as
        # a single RGB(A) color by scatter for n = 3 or n = 4.
        coloring = [
            np.asarray(
                [
                    float(d[color_key])
                    for d in self.information_list[p]["information"]
                ]
            )
            for p in positions
        ]
        # Common color limits: the single colorbar must be valid for the
        # points of every generation, not only for the last one drawn.
        finite = [v for c in coloring for v in c if np.isfinite(v)]
        vmin, vmax = (min(finite), max(finite)) if finite else (None, None)
        scatter = None
        for i, d in enumerate(data):
            scatter = ax.scatter(
                [i + 1] * d.shape[0],
                d,
                alpha=1.0,
                cmap="plasma",
                c=coloring[i],
                vmin=vmin,
                vmax=vmax,
            )
        if scatter is not None:
            cbar = fig.colorbar(scatter, ax=ax)
            cbar.ax.tick_params(labelsize=self.fontsizes["tick_param"])
            cbar.set_label(
                label=color_key, size=self.fontsizes["ax_label"]
            )
    else:
        for i, d in enumerate(data):
            _ = ax.scatter(
                [i + 1] * d.shape[0],
                d,
                color=self.colors["dark_area"],
                alpha=1.0,
            )
    ax.set_title(
        f"Loss of individuals per generations {generations}",
        fontsize=self.fontsizes["title"],
    )
    if save_fig:
        fig.savefig(
            self.output_dir / (self.label + "_loss_boxplots." + fig_type),
            format=fig_type,
        )
[docs]
def plot_loss_with_errorbars(
    self,
    generation: int,
    key: str,
    figsize: Tuple[float, float] = None,
    errorbars: bool = True,
    y_units: str = "",
    save_fig: bool = False,
    fig_type: str = "pdf",
) -> None:
    """Plot the loss of every individual of one generation.

    Args:
        generation: Generation to be plotted.
        key: Identifies the additional information to be used as
            errorbars. Only read if errorbars is True.
        figsize: Figure size. If None, the default figure size of the
            helper is used. Default is None.
        errorbars: If True, errorbars are plotted.
        y_units: Units of loss to be added to y_label. Default is empty
            string.
        save_fig: If True the figure will be saved to file. Default is
            False.
        fig_type: File type to be saved. Default is 'pdf'.

    Raises:
        ValueError: If the generation is not found in the data or if the
            information identified by key is not of type Number.
        KeyError: If key is not present in the additional information.
    """
    # check if the generation is available in the data
    if generation not in (self.generations):
        raise ValueError(f"Generation {generation} not found.")
    # argwhere returns a 2-D array; take its first entry explicitly
    # instead of converting the whole array to int.
    index_of_g = int(
        np.argwhere(np.asarray(self.generations) == generation).flatten()[0]
    )
    loss = self.losses[index_of_g]

    extra_data = None
    if errorbars:
        try:
            if self.additional_info_types[key] != Number:
                raise ValueError(
                    f"Additional information {key} is not of type Number "
                    "and cannot be plotted as errorbars."
                )
            # One symmetric error per individual, as a flat array.
            extra_data = np.asarray(
                [
                    float(d[key])
                    for d in self.information_list[index_of_g][
                        "information"
                    ]
                ]
            )
        except KeyError as exc:
            raise KeyError(
                f"Key {key} not present in additional information."
            ) from exc

    if figsize is None:
        figsize = self.figsize
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    ax.xaxis.set_tick_params(labelsize=self.fontsizes["ax_label"])
    ax.set_xlabel(r"Individual", fontsize=self.fontsizes["ax_label"])
    ax.yaxis.set_tick_params(labelsize=self.fontsizes["ax_label"])
    ax.set_ylabel(r"Loss" + y_units, fontsize=self.fontsizes["ax_label"])

    x = np.arange(loss.shape[0])
    ax.set_xticks(x)
    ax.scatter(x, loss, c=self.colors["dark_area"], s=50)
    title = f"Individual loss in generation {generation}"
    if errorbars:
        ax.errorbar(
            x,
            loss,
            yerr=extra_data,
            fmt="none",
            capsize=10,
            ecolor=self.colors["medium_area"],
        )
        title += " with errorbars."
    ax.set_title(
        title,
        fontsize=self.fontsizes["title"],
    )
    if save_fig:
        fig.savefig(
            self.output_dir
            / (self.label + "_loss_with_errorbars." + fig_type),
            format=fig_type,
        )
[docs]
def plot_loss_per_generation(
    self,
    generation_bounds: Tuple[int, int] = None,
    figsize: Tuple[float, float] = None,
    y_units: str = "",
    y_lim: Tuple[float, float] = None,
    show_legend: bool = True,
    ref_val: float = None,
    kwargs: dict = None,
    save_fig: bool = False,
    fig_type: str = "pdf",
) -> None:
    """Plot loss per individual per generation.

    Args:
        generation_bounds: First and last generation to include. If None,
            the bounds of the helper are used. Default is None.
        figsize: Figure size. If None, the default figure size of the
            helper is used. Default is None.
        y_units: Units of loss to be added to y_label. Default is empty
            string.
        y_lim: Limit to restrict y axis. Default is None.
        show_legend: If True, plot a legend. Default is True.
        ref_val: Value to be plotted as a reference line. Default is None.
        kwargs: Dictionary with additional keywords for the scatter plot
            of all individuals. 'marker' defaults to 'x' and 'alpha' to
            0.5. The passed dictionary is not modified.
        save_fig: If True the figure will be saved to file. Default is
            False.
        fig_type: File type to be saved. Default is 'pdf'.
    """
    if generation_bounds is None:
        generation_bounds = self.generation_bounds
    if figsize is None:
        figsize = self.figsize

    # Merge into a new dict so the caller's dictionary stays untouched.
    scatter_kwargs = {"marker": "x", "alpha": 0.5}
    if kwargs is not None:
        scatter_kwargs.update(kwargs)
    # Only apply the default color if the caller did not choose one;
    # passing it twice would raise an error.
    if "color" not in scatter_kwargs and "c" not in scatter_kwargs:
        scatter_kwargs["color"] = self.colors["medium_line"]

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    ax.xaxis.set_tick_params(labelsize=self.fontsizes["ax_label"])
    ax.set_xlabel(r"Generation", fontsize=self.fontsizes["ax_label"])
    ax.yaxis.set_tick_params(labelsize=self.fontsizes["ax_label"])
    ax.set_ylabel(r"Loss" + y_units, fontsize=self.fontsizes["ax_label"])

    # Stored generations inside the requested bounds and their rows.
    in_bounds = np.greater_equal(
        self.generations, generation_bounds[0]
    ) & np.less_equal(self.generations, generation_bounds[1])
    selected_generations = self.generations[in_bounds]
    losses = self.losses[np.argwhere(in_bounds).flatten()]

    # Repeat the generation number once per individual so that x has the
    # same shape as losses.
    x = np.tile(
        selected_generations,
        (losses.shape[1], 1),
    ).T
    ax.scatter(
        x,
        losses,
        **scatter_kwargs,
    )
    ax.scatter(
        x[:, 0],
        losses.max(axis=1),
        marker=".",
        color=self.colors["dark_line"],
        label="max",
    )
    ax.scatter(
        x[:, 0],
        losses.min(axis=1),
        marker=".",
        color=self.colors["light_line"],
        label="min",
    )
    if ref_val is not None:
        ax.axhline(
            y=ref_val,
            linestyle="dotted",
            color="black",
            alpha=0.85,
            label="ref",
        )
    if y_lim is not None:
        ax.set_ylim(y_lim)
    ax.set_title(
        f"Loss per generation ({self.label})",
        fontsize=self.fontsizes["title"],
    )
    if show_legend:
        ax.legend(fontsize=self.fontsizes["legend"])
    if save_fig:
        fig.savefig(
            self.output_dir / (self.label + "_loss." + fig_type),
            format=fig_type,
        )
[docs]
def plot_additional_per_individual(
    self,
    key: str = None,
    generation_bounds: Tuple[int, int] = None,
    figsize: Tuple[float, float] = None,
    y_units: str = "",
    show_legend: bool = True,
    ref_val: float = None,
    kwargs: dict = None,
    save_fig: bool = False,
    fig_type: str = "pdf",
) -> None:
    """Plot additional info per individual per generation.

    Args:
        key: Key that identifies the additional information.
        generation_bounds: First and last generation to include. If None,
            the bounds of the helper are used. Default is None.
        figsize: Figure size. If None, the default figure size of the
            helper is used. Default is None.
        y_units: Units of the plotted information to be added to y_label.
            Default is empty string.
        show_legend: If True, plot a legend. Default is True.
        ref_val: Value to be plotted as a reference line. Default is None.
        kwargs: Dictionary with additional keywords for the scatter plot
            of all individuals. 'marker' defaults to 'x' and 'alpha' to
            0.5. The passed dictionary is not modified.
        save_fig: If True the figure will be saved to file. Default is
            False.
        fig_type: File type to be saved. Default is 'pdf'.

    Raises:
        KeyError: If key is not present in the additional information.
        ValueError: If the information identified by key is not of type
            Number.
    """
    if generation_bounds is None:
        generation_bounds = self.generation_bounds
    if figsize is None:
        figsize = self.figsize

    # Validate the key before any figure is created.
    try:
        info_type = self.additional_info_types[key]
    except KeyError as exc:
        raise KeyError(
            f"Key {key} not present in additional information."
        ) from exc
    if info_type != Number:
        raise ValueError(
            f"Additional information {key} is not of type Number and "
            "cannot be plotted per individual."
        )

    # Merge into a new dict so the caller's dictionary stays untouched.
    scatter_kwargs = {"marker": "x", "alpha": 0.5}
    if kwargs is not None:
        scatter_kwargs.update(kwargs)
    # Only apply the default color if the caller did not choose one;
    # passing it twice would raise an error.
    if "color" not in scatter_kwargs and "c" not in scatter_kwargs:
        scatter_kwargs["color"] = self.colors["medium_line"]

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    ax.xaxis.set_tick_params(labelsize=self.fontsizes["ax_label"])
    ax.set_xlabel(r"Generation", fontsize=self.fontsizes["ax_label"])
    ax.yaxis.set_tick_params(labelsize=self.fontsizes["ax_label"])
    ax.set_ylabel(f"{key}" + y_units, fontsize=self.fontsizes["ax_label"])

    in_bounds = np.greater_equal(
        self.generations, generation_bounds[0]
    ) & np.less_equal(self.generations, generation_bounds[1])
    selected_generations = self.generations[in_bounds]
    # Select by position in the stored data, not by generation number:
    # the two differ when loading did not start at generation 1 or when
    # a generation file was missing.
    generation_index = np.argwhere(in_bounds).flatten()
    data = np.asarray(
        [
            [
                float(d[key])
                for d in self.information_list[i]["information"]
            ]
            for i in generation_index
        ]
    )

    # Repeat the generation number once per individual so that x has the
    # same shape as data.
    x = np.tile(
        selected_generations,
        (data.shape[1], 1),
    ).T
    ax.scatter(
        x,
        data,
        **scatter_kwargs,
    )
    ax.scatter(
        x[:, 0],
        data.max(axis=1),
        marker=".",
        color=self.colors["dark_line"],
        label="max",
    )
    ax.scatter(
        x[:, 0],
        data.min(axis=1),
        marker=".",
        color=self.colors["light_line"],
        label="min",
    )
    if ref_val is not None:
        ax.axhline(
            y=ref_val,
            linestyle="dotted",
            color="black",
            alpha=0.85,
            label="ref",
        )
    ax.set_title(
        f"{key} per generation ({self.label})",
        fontsize=self.fontsizes["title"],
    )
    if show_legend:
        ax.legend(fontsize=self.fontsizes["legend"])
    if save_fig:
        fig.savefig(
            self.output_dir / (self.label + f"_{key}." + fig_type),
            format=fig_type,
        )