Source code for clinamen2.cmaes.cmaes_criteria

""" Implementation of the termination criteria as described in

    References:

    [1] N. Hansen, 2016, arXiv:1604.00772 [cs.LG]
"""
from collections import deque
from typing import NamedTuple

import numpy as np
import numpy.linalg as lin
import numpy.typing as npt
import scipy.linalg

from clinamen2.cmaes.params_and_state import (
    AlgorithmParameters,
    AlgorithmState,
)
from clinamen2.cmaes.termination_criterion import Criterion


[docs] class ConditionCovState(NamedTuple): """NamedTuple to keep track of condition number Args: cond: Condition number of C. """ cond: float = None
[docs] class ConditionCovCriterion(Criterion): """ConditionCov criterion Stop if the condition number of the covariance matrix exceeds 1e14. Args: parameters: Initial, immutable parameters of the CMA-ES run. threshold: Upper limit for accepted condition number. Default is 1e14. """ def __init__( self, parameters: AlgorithmParameters, threshold: float = 1e14 ): self.threshold = threshold super().__init__(parameters=parameters)
[docs] def init(self) -> ConditionCovState: """Initialize the associated NamedTuple.""" return ConditionCovState(cond=None)
[docs] def update( self, criterion_state: ConditionCovState, state: AlgorithmState, population: npt.ArrayLike, loss: npt.ArrayLike, ) -> ConditionCovState: """Return an updated associated NamedTuple. Args: criterion_state: The associated NamedTuple representing the current stopping criterion state. state: New algorithm state based on which to update the stopping criterion. population: New population based on which to update the stopping criterion. loss: New loss values based on which to update the stopping criterion. """ cond = lin.cond(state.cholesky_factor) ** 2 return ConditionCovState(cond=cond)
[docs] def met(self, criterion_state: ConditionCovState) -> bool: """Decide if the stopping criterion is met. Args: criterion_state: The associated NamedTuple representing the current stopping criterion state. """ return criterion_state.cond > self.threshold
[docs] class EqualFunValuesState(NamedTuple): """NamedTuple to keep track of function value staleness. Args: fun_values: Function values of past generations to be taken into account. """ fun_values: deque = None
[docs] class EqualFunValuesCriterion(Criterion): """EqualFunValues criterion Stop if the range of the loss within a certain range of generations is close to zero. Args: parameters: The algorithm parameters. generation_span: Number of generations over which the range of function values is to be taken into account. Default is 10 + ceil(30 dimension / pop_size). atol: Tolerance for 'zero'. Default is 1e-15. """ def __init__( self, parameters: AlgorithmParameters, generation_span: int = 0, atol: float = 1e-15, ): self.generation_span = int( generation_span if generation_span > 1 else 10 + np.ceil(30 * parameters.dimension / parameters.pop_size) ) self.atol = atol super().__init__(parameters=parameters)
[docs] def init(self) -> EqualFunValuesState: """Initialize the associated NamedTuple.""" print(f"span is {self.generation_span}") return EqualFunValuesState( fun_values=deque(maxlen=self.generation_span) )
[docs] def update( self, criterion_state: EqualFunValuesState, state: AlgorithmState, population: npt.ArrayLike, loss: npt.ArrayLike, ) -> EqualFunValuesState: """Return an updated associated NamedTuple. Args: criterion_state: The associated NamedTuple representing the current stopping criterion state. state: New algorithm state based on which to update the stopping criterion. population: New population based on which to update the stopping criterion. loss: New loss values based on which to update the stopping criterion. """ vals_deque = criterion_state.fun_values vals_deque.append(loss.min()) return EqualFunValuesState(fun_values=vals_deque)
[docs] def met(self, criterion_state: EqualFunValuesState) -> bool: """Decide if the stopping criterion is met. Args: criterion_state: The associated NamedTuple representing the current stopping criterion state. """ fun_values = np.asarray(criterion_state.fun_values) return ( len(criterion_state.fun_values) == self.generation_span and abs(fun_values.max() - fun_values.min()) < self.atol )
[docs] class TolXUpState(NamedTuple): """NamedTuple to keep track of function value staleness. Args: compare_to: Previous value of tolxup saved for comparison. latest_diff: Latest calculated absolute difference in tolxup values. """ compare_to: float = None latest_diff: float = None
[docs] class TolXUpCriterion(Criterion): """TolXUp criterion Stop if sigma times max(diag(D)) exceeds a threshold when compared to the previous generation. Args: parameters: Initial, immutable parameters of the CMA-ES run. threshold: Upper limit for accepted difference. Default is 1e4. interpolative: Control how the matrix norm is calculated. True: 'scipy.linalg.interpolative.estimate_spectral_norm' False: 'scipy.linalg.norm' Default is False. Use True for large matrices. """ def __init__( self, parameters: AlgorithmParameters, threshold: float = 1e4, interpolative: bool = False, ): self.threshold = threshold self.interpolative = interpolative if interpolative: self.norm_func = scipy.linalg.interpolative.estimate_spectral_norm self.norm_params = {} else: self.norm_func = scipy.linalg.norm self.norm_params = {"axis": (0, 1), "ord": 2} super().__init__(parameters=parameters)
[docs] def init(self) -> TolXUpState: """Initialize the associated NamedTuple.""" return TolXUpState()
[docs] def update( self, criterion_state: TolXUpState, state: AlgorithmState, population: npt.ArrayLike, loss: npt.ArrayLike, ) -> TolXUpState: """Return an updated associated NamedTuple. Args: criterion_state: The associated NamedTuple representing the current stopping criterion state. state: New algorithm state based on which to update the stopping criterion. population: New population based on which to update the stopping criterion. loss: New loss values based on which to update the stopping criterion. """ compare_to = state.step_size * self.norm_func( state.cholesky_factor, **self.norm_params ) if criterion_state.compare_to is None: latest_diff = 0.0 else: latest_diff = abs(criterion_state.latest_diff - compare_to) return TolXUpState(compare_to=compare_to, latest_diff=latest_diff)
[docs] def met(self, criterion_state: TolXUpState) -> bool: """Decide if the stopping criterion is met. Args: criterion_state: The associated NamedTuple representing the current stopping criterion state. """ return criterion_state.latest_diff > self.threshold