Source code for clinamen2.runner.basic_runner

"""Implementation of the Runner classes to be used with Dask."""
import logging
import os
import pathlib
import pickle
import subprocess
import tempfile
import time
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import Callable, Optional

import dask.distributed
import jinja2
import numpy.typing as npt
from dask.distributed import Client, as_completed

# Record returned by a loss evaluation:
#   loss (float): the computed loss value.
#   information (dict): any additional data reported alongside the loss.
WorkerResult = namedtuple("WorkerResult", "loss information")


class Runner(ABC):
    """Abstract class to interface to any queue.

    Inherit from this class to implement any loss evaluation. Subclasses
    must implement :meth:`submit` and :meth:`pop`.

    Usage:
        - instantiate
        - submit() or submit_batch(): Send structure(s) to Runner
        - pop(): Retrieve latest result from Runner
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def submit(self, individual: npt.ArrayLike):
        """Function to submit one structure to the Runner.

        Args:
            individual: 1D array representing an individual structure.

        Returns:
            ID of input array
        """

    @abstractmethod
    def pop(self):
        """Function to fetch results from the Runner.

        Returns:
            - ID of input array
            - Calculation result
        """

    def submit_batch(self, individuals: npt.ArrayLike):
        """Function that sequentially calls submit() for all individuals.

        Args:
            individuals: Input to be submitted. It is iterated over and each
                element is passed to :meth:`submit` in order.

        Returns:
            List of IDs of input arrays, in submission order.
        """
        return [self.submit(individual) for individual in individuals]
class FunctionRunner(Runner):
    """Simple Runner for local function evaluation using Dask.

    Args:
        evaluator_function: Function that calculates the loss. Needs to
            return a tuple of (float, {}).
        workers: Number of workers to be started in parallel.
        scheduler_file: Path to file identifying the dask scheduler.
        convert_input: Function that takes the input array and returns the
            object that is expected by the Runner. Default is identity.
    """

    def __init__(
        self,
        evaluator_function: Callable,
        workers: int,
        scheduler_file: Optional[pathlib.Path] = None,
        convert_input: Callable = lambda x: x,
    ):
        self.evaluator_function = evaluator_function
        client_params = {
            "threads_per_worker": 1,
            "n_workers": workers,
            "silence_logs": logging.ERROR,
        }
        if scheduler_file is not None:
            client_params["scheduler_file"] = scheduler_file
        self.dask_client = Client(**client_params)
        # Iterator yielding futures in completion order; submit() adds to it.
        self.futures = as_completed([])
        self.convert_input = convert_input
        super().__init__()

    def submit(self, individual: npt.ArrayLike):
        """Submit one individual for evaluation on the Dask cluster.

        Args:
            individual: 1D array representing an individual structure. It is
                passed through ``convert_input`` before being handed to
                ``evaluator_function``.

        Returns:
            Key of the Dask future tracking this evaluation.
        """
        future = self.dask_client.submit(
            self.evaluator_function, self.convert_input(individual)
        )
        self.futures.add(future)
        return future.key

    def pop(self):
        """Yield submitted futures in the order they complete.

        Yields:
            Dask futures; ``future.key`` is the ID returned by :meth:`submit`
            and ``future.result()`` the value of ``evaluator_function``.
        """
        yield from self.futures

    def __del__(self):
        # __init__ may have failed before dask_client was assigned (e.g. if
        # Client() raised); don't raise a second, misleading AttributeError
        # from the finalizer in that case.
        client = getattr(self, "dask_client", None)
        if client is not None:
            client.close()
class ScriptRunner(Runner):
    """Runner for distributed script evaluations using Dask.

    Usage:
        This class provides the functionality to evaluate individuals in a
        distributed manner using Dask workers. The provided script is executed
        on the workers using the script_run_command in a temporary directory.
        This directory contains the output of the convert_input function
        saved with pickle and named "input". The script is expected to write
        a pickled WorkerResult object named "result", or in case of failure
        save the pickled exception as "result".

    Args:
        script_text: Text of the script to be executed on the workers. Before
            execution the script will be rendered with jinja using the given
            script_config as context.
        script_run_command: Command line command to execute the script on the
            workers, it has to contain {SCRIPTFILE}, which will be replaced by
            the file name of the actual script. E.g "python {SCRIPTFILE}"
        script_config: Dictionary of jinja keyword - value pairs to be used
            for script rendering. Default is None, in which case the script
            is rendered with an empty context.
        convert_input: Function that takes the input array and returns the
            object that is expected by the Runner. Default is identity.
        scheduler_info_path: The path to the Dask scheduler descriptor file.
            Default is None, which starts the Dask scheduler locally.
    """

    def __init__(
        self,
        script_text: str,
        script_run_command: str,
        script_config: Optional[dict] = None,
        convert_input: Callable = lambda x: x,
        scheduler_info_path: Optional[str] = None,
    ):
        env = jinja2.Environment()
        self.raw_script_text = script_text
        self.template = env.from_string(script_text)
        self.script_config = script_config
        self.script_run_command = script_run_command
        if scheduler_info_path is None:
            self.dask_client = Client()
        else:
            self.dask_client = Client(scheduler_file=scheduler_info_path)
        # Iterator yielding futures in completion order; submit() adds to it.
        self.futures = as_completed([])
        self.convert_input = convert_input
        super().__init__()

    @staticmethod
    def script_driver(payload):
        """Run one rendered script on a Dask worker and return its result.

        Args:
            payload: Tuple of (rendered script text, run command containing
                ``{SCRIPTFILE}``, input object to be pickled as "input").

        Returns:
            The unpickled content of the "result" file written by the script.

        Raises:
            Exception: The exception object the script pickled as its result.
            FileNotFoundError: If the script exits without writing "result";
                the message includes the script's return code.
        """
        script_text, script_run_command, data = payload
        worker = dask.distributed.get_worker()
        with tempfile.TemporaryDirectory(
            dir=worker.local_directory
        ) as foldername:
            script_path = os.path.join(foldername, "script")
            with open(script_path, "w") as tmpscript:
                tmpscript.write(script_text)
            with open(os.path.join(foldername, "input"), "wb") as tmpdata:
                pickle.dump(data, tmpdata)
            # shell=True is intentional: script_run_command is a user-supplied
            # command-line template. Block until the script exits instead of
            # polling once per second, which delayed every result by up to 1 s.
            proc = subprocess.run(
                script_run_command.format(SCRIPTFILE=script_path),
                shell=True,
                cwd=foldername,
                check=False,
            )
            result_path = os.path.join(foldername, "result")
            try:
                with open(result_path, "rb") as resultfile:
                    # NOTE: unpickling can execute arbitrary code; the file is
                    # written by the user's own script, so it is trusted here.
                    result = pickle.load(resultfile)
            except FileNotFoundError as err:
                # Keep the exception type/errno/filename, but say why.
                raise FileNotFoundError(
                    err.errno,
                    f"Script exited with return code {proc.returncode} "
                    "without writing a result file",
                    result_path,
                ) from err
        if isinstance(result, Exception):
            raise result
        return result

    def _render_script(self, script_config=None) -> str:
        """Render the jinja template with the effective script config.

        A per-call ``script_config`` takes precedence over the one given at
        construction. If both are None, render with an empty context, because
        ``Template.render(None)`` raises TypeError.
        """
        config = self.script_config if script_config is None else script_config
        return self.template.render({} if config is None else config)

    def peek_script(self, script_config=None) -> str:
        """Function to check what the jinja script rendering would result in.

        Args:
            script_config: Optional config overriding the instance's
                script_config for this rendering only.

        Returns:
            Rendered script text.
        """
        return self._render_script(script_config)

    def submit(self, individual, script_config=None):
        """Submit one individual to be evaluated by the script on a worker.

        Args:
            individual: 1D array representing an individual structure. It is
                passed through ``convert_input`` before being pickled.
            script_config: Optional config overriding the instance's
                script_config for this submission only.

        Returns:
            Key of the Dask future tracking this evaluation.
        """
        future = self.dask_client.submit(
            self.script_driver,
            (
                self._render_script(script_config),
                self.script_run_command,
                self.convert_input(individual),
            ),
        )
        self.futures.add(future)
        return future.key

    def submit_batch(self, individuals: npt.ArrayLike, script_config=None):
        """Submit every individual, all rendered with the same script_config.

        Returns:
            List of future keys, in submission order.
        """
        return [self.submit(individual, script_config) for individual in individuals]

    def pop(self):
        """Yield submitted futures in the order they complete."""
        yield from self.futures

    def __del__(self):
        # __init__ may have failed before dask_client was assigned.
        try:
            self.dask_client.close()
        except AttributeError:
            pass