diff --git a/.idea/runConfigurations/autora_theorist_basic_usage.xml b/.idea/runConfigurations/autora_theorist_basic_usage.xml new file mode 100644 index 000000000..25e5314c1 --- /dev/null +++ b/.idea/runConfigurations/autora_theorist_basic_usage.xml @@ -0,0 +1,25 @@ + + + + + \ No newline at end of file diff --git a/autora/controller/__main__.py b/autora/controller/__main__.py new file mode 100644 index 000000000..c7c73ba8d --- /dev/null +++ b/autora/controller/__main__.py @@ -0,0 +1,60 @@ +import logging +import os +import pathlib +from typing import Optional + +import typer +import yaml + +from autora.controller import Controller +from autora.theorist.__main__ import _configure_logger + +_logger = logging.getLogger(__name__) + + +def main( + manager: pathlib.Path=typer.Argument(..., help="Manager path"), + directory: pathlib.Path=typer.Argument(..., help="Directory path"), + step_name: Optional[str] = typer.Argument(None, help="Name of step"), + verbose: bool = typer.Option(False, help="Turns on info logging level."), + debug: bool = typer.Option(False, help="Turns on debug logging level."), +): + _logger.debug("initializing") + _configure_logger(debug, verbose) + controller_ = _load_manager(manager) + + _logger.debug(f"loading manager state from {directory=}") + controller_.load(directory) + + if step_name is not None: + controller_ = _set_next_step_name(controller_, step_name) + + _logger.info("running next step") + next(controller_) + + _logger.debug(f"last result: {controller_.state.history[-1]}") + + _logger.info("writing out results") + controller_.dump(directory) + + return + + +def _load_manager(path: pathlib.Path) -> Controller: + _logger.debug(f"_load_manager: loading from {path=} (currently in {os.getcwd()})") + with open(path, "r") as f: + controller_ = yaml.load(f, yaml.Loader) + assert isinstance( + controller_, Controller + ), f"controller type {type(controller_)=} unsupported" + return controller_ + + +def _set_next_step_name(controller: 
Controller, step_name: str): + _logger.info(f"setting next {step_name=}") + controller.planner = lambda _: step_name + return controller + + +if __name__ == "__main__": + typer.run(main) diff --git a/autora/controller/base.py b/autora/controller/base.py index fcb661698..a6852b824 100644 --- a/autora/controller/base.py +++ b/autora/controller/base.py @@ -2,16 +2,14 @@ from __future__ import annotations import logging -from typing import Callable, Mapping, Optional, TypeVar - -_logger = logging.getLogger(__name__) +from typing import Callable, Generic, Mapping, Optional +from autora.controller.protocol import State -State = TypeVar("State") -ExecutorName = TypeVar("ExecutorName", bound=str) +_logger = logging.getLogger(__name__) -class BaseController: +class BaseController(Generic[State]): """ Runs an experimentalist, theorist and experiment runner in a loop. @@ -36,8 +34,8 @@ class BaseController: def __init__( self, state: State, - planner: Callable[[State], ExecutorName], - executor_collection: Mapping[ExecutorName, Callable[[State], State]], + planner: Callable[[State], str], + executor_collection: Mapping[str, Callable[[State], State]], monitor: Optional[Callable[[State], None]] = None, ): """ diff --git a/autora/controller/controller.py b/autora/controller/controller.py index 5ba2d8427..dbf6b8f9c 100644 --- a/autora/controller/controller.py +++ b/autora/controller/controller.py @@ -2,13 +2,15 @@ from __future__ import annotations import logging +import pathlib from typing import Callable, Dict, Optional from sklearn.base import BaseEstimator -from autora.controller.base import BaseController, ExecutorName +from autora.controller.base import BaseController from autora.controller.executor import make_online_executor_collection from autora.controller.planner import last_result_kind_planner +from autora.controller.serializer import HistorySerializer from autora.controller.state import History from autora.experimentalist.pipeline import Pipeline from autora.variable 
import VariableCollection @@ -16,7 +18,7 @@ _logger = logging.getLogger(__name__) -class Controller(BaseController): +class Controller(BaseController[History]): """ Runs an experimentalist, experiment runner, and theorist in order. @@ -33,13 +35,13 @@ class Controller(BaseController): def __init__( self, - metadata: Optional[VariableCollection], + metadata: Optional[VariableCollection] = None, theorist: Optional[BaseEstimator] = None, experimentalist: Optional[Pipeline] = None, experiment_runner: Optional[Callable] = None, params: Optional[Dict] = None, monitor: Optional[Callable[[History], None]] = None, - planner: Callable[[History], ExecutorName] = last_result_kind_planner, + planner: Callable[[History], str] = last_result_kind_planner, ): """ Args: @@ -102,3 +104,14 @@ def __init__( def seed(self, **kwargs): for key, value in kwargs.items(): self.state = self.state.update(**{key: value}) + + def load(self, directory: pathlib.Path): + serializer = HistorySerializer(directory) + state = serializer.load() + self.state = state + return + + def dump(self, directory: pathlib.Path): + serializer = HistorySerializer(directory) + serializer.dump(self.state) + return diff --git a/autora/controller/executor.py b/autora/controller/executor.py index 27f0c7615..2a170e0d6 100644 --- a/autora/controller/executor.py +++ b/autora/controller/executor.py @@ -6,11 +6,13 @@ import copy import logging +import pprint from functools import partial from types import MappingProxyType from typing import Callable, Dict, Iterable, Literal, Optional, Tuple, Union import numpy as np +import pandas as pd from sklearn.base import BaseEstimator from autora.controller.protocol import SupportsControllerState @@ -27,15 +29,27 @@ def experimentalist_wrapper( params_ = resolve_state_params(params, state) new_conditions = pipeline(**params_) - assert isinstance(new_conditions, Iterable) - # If the pipeline gives us an iterable, we need to make it into a concrete array. 
- # We can't move this logic to the Pipeline, because the pipeline doesn't know whether - # it's within another pipeline and whether it should convert the iterable to a - # concrete array. - new_conditions_values = list(new_conditions) - new_conditions_array = np.array(new_conditions_values) + if isinstance(new_conditions, pd.DataFrame): + new_conditions_array = new_conditions + elif isinstance(new_conditions, np.ndarray): + _logger.warning( + f"{new_conditions=} is an ndarray, so variable confusion is a possibility" + ) + new_conditions_array = new_conditions + elif isinstance(new_conditions, np.recarray): + new_conditions_array = new_conditions + elif isinstance(new_conditions, Iterable): + # If the pipeline gives us an iterable, we need to make it into a concrete array. + # We can't move this logic to the Pipeline, because the pipeline doesn't know whether + # it's within another pipeline and whether it should convert the iterable to a + # concrete array. + new_conditions_values = list(new_conditions) + new_conditions_array = np.array(new_conditions_values) + else: + raise NotImplementedError( + f"Can't handle experimentalist output {new_conditions=}" + ) - assert isinstance(new_conditions_array, np.ndarray) # Check the object is bounded new_state = state.update(conditions=[new_conditions_array]) return new_state @@ -46,8 +60,15 @@ def experiment_runner_wrapper( """Interface for running the experiment runner callable.""" params_ = resolve_state_params(params, state) x = state.conditions[-1] - y = callable(x, **params_) - new_observations = np.column_stack([x, y]) + output = callable(x, **params_) + + if isinstance(x, pd.DataFrame): + new_observations = output + elif isinstance(x, np.ndarray): + new_observations = np.column_stack([x, output]) + else: + raise NotImplementedError(f"type {x=} not supported") + new_state = state.update(observations=[new_observations]) return new_state @@ -59,13 +80,35 @@ def theorist_wrapper( params_ = resolve_state_params(params, 
state) metadata = state.metadata observations = state.observations - all_observations = np.row_stack(observations) - n_xs = len(metadata.independent_variables) - x, y = all_observations[:, :n_xs], all_observations[:, n_xs:] - if y.shape[1] == 1: - y = y.ravel() + + if isinstance(observations[-1], pd.DataFrame): + all_observations = pd.concat(observations) + iv_names = [iv.name for iv in metadata.independent_variables] + dv_names = [dv.name for dv in metadata.dependent_variables] + x, y = all_observations[iv_names], all_observations[dv_names] + elif isinstance(observations[-1], np.ndarray): + all_observations = np.row_stack(observations) + n_xs = len(metadata.independent_variables) + x, y = all_observations[:, :n_xs], all_observations[:, n_xs:] + if y.shape[1] == 1: + y = y.ravel() + else: + raise NotImplementedError(f"type {observations[-1]=} not supported") + new_theorist = copy.deepcopy(estimator) new_theorist.fit(x, y, **params_) + + try: + _logger.debug( + f"fitted {new_theorist=}\nnew_theorist.__dict__:" + f"\n{pprint.pformat(new_theorist.__dict__)}" + ) + except AttributeError: + _logger.debug( + f"fitted {new_theorist=} " + f"new_theorist has no __dict__ attribute, so no results are shown" + ) + new_state = state.update(theories=[new_theorist]) return new_state diff --git a/autora/controller/protocol.py b/autora/controller/protocol.py index 1285f1f07..fa66130fb 100644 --- a/autora/controller/protocol.py +++ b/autora/controller/protocol.py @@ -128,9 +128,7 @@ def __call__(self, __state: State, __params: Dict) -> State: ... 
-ExecutorName = TypeVar("ExecutorName", bound=str) - -ExecutorCollection = Mapping[ExecutorName, Executor] +ExecutorCollection = Mapping[str, Executor] @runtime_checkable diff --git a/autora/controller/serializer/__init__.py b/autora/controller/serializer/__init__.py index 319d3432c..eba7007a7 100644 --- a/autora/controller/serializer/__init__.py +++ b/autora/controller/serializer/__init__.py @@ -1,15 +1,12 @@ import pickle import tempfile +from abc import abstractmethod from pathlib import Path -from typing import Mapping, NamedTuple, Optional, Type, Union +from typing import Generic, Mapping, NamedTuple, Union import numpy as np -from autora.controller.protocol import ( - ResultKind, - SupportsControllerStateHistory, - SupportsLoadDump, -) +from autora.controller.protocol import ResultKind, State, SupportsLoadDump from autora.controller.serializer import yaml_ as YAMLSerializer from autora.controller.state import History @@ -25,7 +22,17 @@ class _LoadSpec(NamedTuple): mode: str -class HistorySerializer: +class StateSerializer(Generic[State]): + @abstractmethod + def load(self) -> State: + ... + + @abstractmethod + def dump(self, ___state: State): + ... 
+ + +class HistorySerializer(StateSerializer[History]): """Serializes and deserializes History objects.""" def __init__( @@ -51,7 +58,7 @@ def __init__( ".pickle": _LoadSpec(pickle, "rb"), } - def dump(self, data_collection: SupportsControllerStateHistory): + def dump(self, data_collection: History): """ Args: @@ -129,7 +136,7 @@ def dump(self, data_collection: SupportsControllerStateHistory): with open(Path(path, filename), mode) as f: serializer.dump(container, f) - def load(self, cls: Type[SupportsControllerStateHistory] = History): + def load(self) -> History: """ Examples: @@ -185,7 +192,7 @@ def load(self, cls: Type[SupportsControllerStateHistory] = History): loaded_object = serializer.load(f) data.append(loaded_object) - data_collection = cls(history=data) + data_collection = History(history=data) return data_collection diff --git a/autora/theorist/__main__.py b/autora/theorist/__main__.py new file mode 100644 index 000000000..1ab6aeb81 --- /dev/null +++ b/autora/theorist/__main__.py @@ -0,0 +1,116 @@ +import logging +import pathlib +import pickle +import pprint +from typing import Dict, Optional + +import pandas as pd +import typer +import yaml +from pandas import DataFrame +from sklearn.base import BaseEstimator + +from autora.variable import VariableCollection + +_logger = logging.getLogger(__name__) + + +def main( + variables: pathlib.Path, + regressor: pathlib.Path, + parameters: pathlib.Path, + data: pathlib.Path, + output: pathlib.Path, + verbose: bool = False, + debug: bool = False, + overwrite: bool = False, +): + # Initialization + _configure_logger(debug, verbose) + + # Data Loading + variables_ = _load_variables(variables) + parameters_ = _load_parameters(parameters) + regressor_ = _load_regressor(regressor) + data_ = _load_data(data) + + # Fitting + model = _fit_model(data_, parameters_, regressor_, variables_) + + # Writing results + _dump_model(model, output, overwrite) + + return + + +def _configure_logger(debug, verbose): + if debug: + 
logging.basicConfig(level=logging.DEBUG) + _logger.debug("using DEBUG logging level") + if verbose: + logging.basicConfig(level=logging.INFO) + _logger.info("using INFO logging level") + + +def _load_variables(path: pathlib.Path) -> VariableCollection: + _logger.debug(f"load_variables: loading from {path=}") + variables_: VariableCollection + with open(path, "r") as fv: + variables_ = yaml.load(fv, yaml.Loader) + assert isinstance(variables_, VariableCollection) + return variables_ + + +def _load_parameters(path: pathlib.Path) -> Dict: + _logger.debug(f"load_parameters: loading from {path=}") + with open(path, "r") as fp: + parameters_: Optional[Dict] = yaml.load(fp, yaml.Loader) + if parameters_ is None: + parameters_ = dict() + return parameters_ + + +def _load_regressor(path: pathlib.Path) -> BaseEstimator: + with open(path, "r") as f: + regressor_ = yaml.load(f, yaml.Loader) + return regressor_ + + +def _load_data(data: pathlib.Path) -> DataFrame: + _logger.debug(f"load_data: loading from {data=}") + with open(data, "r") as fd: + data_: DataFrame = pd.read_csv(fd) + return data_ + + +def _fit_model(data, parameters, regressor, variables) -> BaseEstimator: + model = regressor.set_params(**parameters) + x = data[[v.name for v in variables.independent_variables]] + y = data[[v.name for v in variables.dependent_variables]] + _logger.debug(f"fitting the regressor with x's:\n{x}\nand y's:\n{y}") + model.fit(x, y) + try: + _logger.info( + f"fitted {model=}\nmodel.__dict__:" f"\n{pprint.pformat(model.__dict__)}" + ) + except AttributeError: + _logger.warning( + f"fitted {model=} " + f"model has no __dict__ attribute, so no results are shown" + ) + return model + + +def _dump_model(model, output, overwrite): + if overwrite: + mode = "wb" + _logger.info(f"overwriting {output=} if it already exists") + else: + mode = "xb" + _logger.info(f"writing to new file {output=}") + with open(output, mode) as o: + pickle.dump(model, o) + + +if __name__ == "__main__": + 
typer.run(main) diff --git a/autora/theorist/darts/plot_utils.py b/autora/theorist/darts/plot_utils.py deleted file mode 100755 index d256bf412..000000000 --- a/autora/theorist/darts/plot_utils.py +++ /dev/null @@ -1,1129 +0,0 @@ -import os -import typing -from typing import Optional - -import imageio -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pandas -import seaborn as sns -import torch.nn -from matplotlib import pyplot -from matplotlib.gridspec import GridSpec - -import autora.config as aer_config -import autora.theorist.darts.darts_config as darts_config -from autora.theorist.object_of_study import Object_Of_Study - - -def generate_darts_summary_figures( - figure_names: typing.List[str], - titles: typing.List[str], - filters: typing.List[str], - title_suffix: str, - study_name: str, - y_name: str, - y_label: str, - y_sem_name: str, - x1_name: str, - x1_label: str, - x2_name: str, - x2_label: str, - x_limit: typing.List[float], - y_limit: typing.List[float], - best_model_name: str, - figure_size: typing.Tuple[int, int], - y_reference: Optional[typing.List[float]] = None, - y_reference_label: str = "", - arch_samp_filter: Optional[str] = None, -): - """ - Generates a summary figure for a given DARTS study. - The figure can be composed of different summary plots. 
- - Arguments: - figure_names: list of strings with the names of the figures to be generated - titles: list of strings with the titles of the figures to be generated - filters: list of strings with the theorist filters to be used to select the models to be - used in the figures - title_suffix: string with the suffix to be added to the titles of the figures - study_name: string with the name of the study (used to identify the study folder) - y_name: string with the name of the y-axis variable - y_label: string with the label of the y-axis variable - y_sem_name: string with the name of the y-axis coding the standard error of the mean - x1_name: string with the name of the (first) x-axis variable - x1_label: string with the label of the (first) x-axis variable - x2_name: string with the name of the second x-axis variable - x2_label: string with the label of the second x-axis variable - x_limit: list with the limits of the x-axis - y_limit: list with the limits of the y-axis - best_model_name: string with the name of the best model to be highlighted in the figure - figure_size: list with the size of the figure - y_reference: list with the values of the reference line - y_reference_label: string with the label of the reference line - arch_samp_filter: string with the name of the filter to be used to select the - samples of the architecture - - """ - - for idx, (figure_name, title, theorist_filter) in enumerate( - zip(figure_names, titles, filters) - ): - - print("##########################: " + figure_name) - title = title + title_suffix - if idx > 0: # after legend - show_legend = False - figure_dimensions = figure_size - else: - show_legend = True - figure_dimensions = (6, 6) - if idx > 1: # after original darts - y_label = " " - - plot_darts_summary( - study_name=study_name, - title=title, - y_name=y_name, - y_label=y_label, - y_sem_name=y_sem_name, - x1_name=x1_name, - x1_label=x1_label, - x2_name=x2_name, - x2_label=x2_label, - metric="mean_min", - x_limit=x_limit, 
- y_limit=y_limit, - best_model_name=best_model_name, - theorist_filter=theorist_filter, - arch_samp_filter=arch_samp_filter, - figure_name=figure_name, - figure_dimensions=figure_dimensions, - legend_loc=aer_config.legend_loc, - legend_font_size=aer_config.legend_font_size, - axis_font_size=aer_config.axis_font_size, - title_font_size=aer_config.title_font_size, - show_legend=show_legend, - y_reference=y_reference, - y_reference_label=y_reference_label, - save=True, - ) - - -def plot_darts_summary( - study_name: str, - y_name: str, - x1_name: str, - x2_name: str = "", - y_label: str = "", - x1_label: str = "", - x2_label: str = "", - y_sem_name: Optional[str] = None, - metric: str = "min", - y_reference: Optional[typing.List[float]] = None, - y_reference_label: str = "", - figure_dimensions: Optional[typing.Tuple[int, int]] = None, - title: str = "", - legend_loc: int = 0, - legend_font_size: int = 8, - axis_font_size: int = 10, - title_font_size: int = 10, - show_legend: bool = True, - y_limit: Optional[typing.List[float]] = None, - x_limit: Optional[typing.List[float]] = None, - theorist_filter: Optional[str] = None, - arch_samp_filter: Optional[str] = None, - best_model_name: Optional[str] = None, - save: bool = False, - figure_name: str = "figure", -): - """ - Generates a single summary plot for a given DARTS study. 
- - Arguments: - study_name: string with the name of the study (used to identify the study folder) - y_name: string with the name of the y-axis variable - x1_name: string with the name of the (first) x-axis variable - x2_name: string with the name of the second x-axis variable - y_label: string with the label of the y-axis variable - x1_label: string with the label of the (first) x-axis variable - x2_label: string with the label of the second x-axis variable - y_sem_name: string with the name of the y-axis coding the standard error of the mean - metric: string with the metric to be used to select the best model - y_reference: list with the values of the reference line - y_reference_label: string with the label of the reference line - figure_dimensions: list with the size of the figure - title: string with the title of the figure - legend_loc: integer with the location of the legend - legend_font_size: integer with the font size of the legend - axis_font_size: integer with the font size of the axis - title_font_size: integer with the font size of the title - show_legend: boolean with the flag to show the legend - y_limit: list with the limits of the y-axis - x_limit: list with the limits of the x-axis - theorist_filter: string with the name of the filter to be used to select the theorist - arch_samp_filter: string with the name of the filter to be used to select the architecture - best_model_name: string with the name of the best model to be highlighted in the figure - save: boolean with the flag to save the figure - figure_name: string with the name of the figure - """ - - palette = "PuBu" - - if figure_dimensions is None: - figure_dimensions = (4, 3) - - if y_label == "": - y_label = y_name - - if x1_label == "": - x1_label = x1_name - - if x2_label == "": - x2_label = x2_name - - if y_reference_label == "": - y_reference_label = "Data Generating Model" - - # determine directory for study results and figures - results_path = ( - aer_config.studies_folder - + 
study_name - + "/" - + aer_config.models_folder - + aer_config.models_results_folder - ) - - figures_path = ( - aer_config.studies_folder - + study_name - + "/" - + aer_config.models_folder - + aer_config.models_results_figures_folder - ) - - # read in all csv files - files = list() - for file in os.listdir(results_path): - if file.endswith(".csv"): - if "model_" not in file: - continue - - if theorist_filter is not None: - if theorist_filter not in file: - continue - files.append(os.path.join(results_path, file)) - - print("Found " + str(len(files)) + " files.") - - # generate a plot dictionary - plot_dict: typing.Dict[typing.Optional[str], typing.List] = dict() - plot_dict[darts_config.csv_arch_file_name] = list() - plot_dict[y_name] = list() - plot_dict[x1_name] = list() - if x2_name != "": - plot_dict[x2_name] = list() - if y_sem_name is not None: - plot_dict[y_sem_name] = list() - - # load csv files into a common dictionary - for file in files: - data = pandas.read_csv(file, header=0) - - valid_data = list() - - # filter for arch samp - if arch_samp_filter is not None: - for idx, arch_file_name in enumerate(data[darts_config.csv_arch_file_name]): - arch_samp = int( - float(arch_file_name.split("_sample", 1)[1].split("_", 1)[0]) - ) - if arch_samp == arch_samp_filter: - valid_data.append(idx) - else: - for idx in range(len(data[darts_config.csv_arch_file_name])): - valid_data.append(idx) - - plot_dict[darts_config.csv_arch_file_name].extend( - data[darts_config.csv_arch_file_name][valid_data] - ) - if y_name in data.keys(): - plot_dict[y_name].extend(data[y_name][valid_data]) - else: - raise Exception( - 'Could not find key "' + y_name + '" in the data file: ' + str(file) - ) - if x1_name in data.keys(): - plot_dict[x1_name].extend(data[x1_name][valid_data]) - else: - raise Exception( - 'Could not find key "' + x1_name + '" in the data file: ' + str(file) - ) - if x2_name != "": - if x2_name in data.keys(): - plot_dict[x2_name].extend(data[x2_name][valid_data]) 
- else: - raise Exception( - 'Could not find key "' - + x2_name - + '" in the data file: ' - + str(file) - ) - if y_sem_name is not None: - # extract seed number from model file name - - if y_sem_name in data.keys(): - plot_dict[y_sem_name].extend(data[y_sem_name]) - elif y_sem_name == "seed": - y_sem_list = list() - for file_name in data[darts_config.csv_arch_file_name][valid_data]: - y_sem_list.append( - int(float(file_name.split("_s_", 1)[1].split("_sample", 1)[0])) - ) - plot_dict[y_sem_name].extend(y_sem_list) - - else: - - raise Exception( - 'Could not find key "' - + y_sem_name - + '" in the data file: ' - + str(file) - ) - - model_name_list = plot_dict[darts_config.csv_arch_file_name] - x1_data = np.asarray(plot_dict[x1_name]) - y_data = np.asarray(plot_dict[y_name]) - if x2_name == "": # determine for each value of x1 the corresponding y - x1_data = np.asarray(plot_dict[x1_name]) - x1_unique = np.sort(np.unique(x1_data)) - - y_plot = np.empty(x1_unique.shape) - y_plot[:] = np.nan - y_sem_plot = np.empty(x1_unique.shape) - y_sem_plot[:] = np.nan - y2_plot = np.empty(x1_unique.shape) - y2_plot[:] = np.nan - x1_plot = np.empty(x1_unique.shape) - x1_plot[:] = np.nan - for idx_unique, x1_unique_val in enumerate(x1_unique): - y_match = list() - model_name_match = list() - for idx_data, x_data_val in enumerate(x1_data): - if x1_unique_val == x_data_val: - y_match.append(y_data[idx_data]) - model_name_match.append(model_name_list[idx_data]) - x1_plot[idx_unique] = x1_unique_val - - if metric == "min": - y_plot[idx_unique] = np.min(y_match) - idx_target = np.argmin(y_match) - legend_label_spec = " (min)" - elif metric == "max": - y_plot[idx_unique] = np.max(y_match) - idx_target = np.argmax(y_match) - legend_label_spec = " (max)" - elif metric == "mean": - y_plot[idx_unique] = np.mean(y_match) - idx_target = 0 - legend_label_spec = " (avg)" - elif metric == "mean_min": - y_plot[idx_unique] = np.mean(y_match) - y2_plot[idx_unique] = np.min(y_match) - idx_target = 
np.argmin(y_match) - legend_label_spec = " (avg)" - legend_label2_spec = " (min)" - elif metric == "mean_max": - y_plot[idx_unique] = np.mean(y_match) - y2_plot[idx_unique] = np.max(y_match) - idx_target = np.argmax(y_match) - legend_label_spec = " (avg)" - legend_label2_spec = " (max)" - else: - raise Exception( - 'Argument "metric" may either be "min", "max", "mean", "mean_min" or "min_max".' - ) - - # compute standard error along given dimension - if y_sem_name is not None: - y_sem_data = np.asarray(plot_dict[y_sem_name]) - y_sem_unique = np.sort(np.unique(y_sem_data)) - y_sem = np.empty(y_sem_unique.shape) - # first average y over all other variables - for idx_y_sem_unique, y_sem_unique_val in enumerate(y_sem_unique): - y_sem_match = list() - for idx_y_sem, ( - y_sem_data_val, - x1_data_val, - y_data_val, - ) in enumerate(zip(y_sem_data, x1_data, y_data)): - if ( - y_sem_unique_val == y_sem_data_val - and x1_unique_val == x1_data_val - ): - y_sem_match.append(y_data_val) - y_sem[idx_y_sem_unique] = np.mean(y_sem_match) - # now compute sem - y_sem_plot[idx_unique] = np.nanstd(y_sem) / np.sqrt(len(y_sem)) - - print( - x1_label - + " = " - + str(x1_unique_val) - + " (" - + str(y_plot[idx_unique]) - + "): " - + model_name_match[idx_target] - ) - - else: # determine for each combination of x1 and x2 (unique rows) the lowest y - x2_data = np.asarray(plot_dict[x2_name]) - x2_unique = np.sort(np.unique(x2_data)) - - y_plot = list() - y_sem_plot = list() - y2_plot = list() - x1_plot = list() - x2_plot = list() - for idx_x2_unique, x2_unique_val in enumerate(x2_unique): - - # collect all x1 and y values matching the current x2 value - model_name_x2_match = list() - y_x2_match = list() - x1_x2_match = list() - for idx_x2_data, x2_data_val in enumerate(x2_data): - if x2_unique_val == x2_data_val: - model_name_x2_match.append(model_name_list[idx_x2_data]) - y_x2_match.append(y_data[idx_x2_data]) - x1_x2_match.append(x1_data[idx_x2_data]) - - # now determine unique x1 values 
for current x2 value - x1_unique = np.sort(np.unique(x1_x2_match)) - x1_x2_plot = np.empty(x1_unique.shape) - x1_x2_plot[:] = np.nan - y_x2_plot = np.empty(x1_unique.shape) - y_x2_plot[:] = np.nan - y_sem_x2_plot = np.empty(x1_unique.shape) - y_sem_x2_plot[:] = np.nan - y2_x2_plot = np.empty(x1_unique.shape) - y2_x2_plot[:] = np.nan - for idx_x1_unique, x1_unique_val in enumerate(x1_unique): - y_x2_x1_match = list() - model_name_x2_x1_match = list() - for idx_x1_data, x1_data_val in enumerate(x1_x2_match): - if x1_unique_val == x1_data_val: - model_name_x2_x1_match.append(model_name_x2_match[idx_x1_data]) - y_x2_x1_match.append(y_x2_match[idx_x1_data]) - x1_x2_plot[idx_x1_unique] = x1_unique_val - - if metric == "min": - y_x2_plot[idx_x1_unique] = np.min(y_x2_x1_match) - idx_target = np.argmin(y_x2_x1_match) - legend_label_spec = " (min)" - elif metric == "max": - y_x2_plot[idx_x1_unique] = np.max(y_x2_x1_match) - idx_target = np.argmax(y_x2_x1_match) - legend_label_spec = " (max)" - elif metric == "mean": - y_x2_plot[idx_x1_unique] = np.mean(y_x2_x1_match) - idx_target = 0 - legend_label_spec = " (avg)" - elif metric == "mean_min": - y_x2_plot[idx_x1_unique] = np.mean(y_x2_x1_match) - y2_x2_plot[idx_x1_unique] = np.min(y_x2_x1_match) - idx_target = np.argmin(y_x2_x1_match) - legend_label_spec = " (avg)" - legend_label2_spec = " (min)" - elif metric == "mean_max": - y_x2_plot[idx_x1_unique] = np.mean(y_x2_x1_match) - y2_x2_plot[idx_x1_unique] = np.max(y_x2_x1_match) - idx_target = np.argmax(y_x2_x1_match) - legend_label_spec = " (avg)" - legend_label2_spec = " (max)" - else: - raise Exception( - 'Argument "metric" may either be "min", "max", "mean", ' - '"mean_min" or "min_max".' 
- ) - - # compute standard error along given dimension - if y_sem_name is not None: - y_sem_data = np.asarray(plot_dict[y_sem_name]) - y_sem_unique = np.sort(np.unique(y_sem_data)) - y_sem = np.empty(y_sem_unique.shape) - # first average y over all other variables - for idx_y_sem_unique, y_sem_unique_val in enumerate(y_sem_unique): - y_sem_match = list() - for idx_y_sem, ( - y_sem_data_val, - x1_data_val, - x2_data_val, - y_data_val, - ) in enumerate(zip(y_sem_data, x1_data, x2_data, y_data)): - if ( - y_sem_unique_val == y_sem_data_val - and x1_unique_val == x1_data_val - and x2_unique_val == x2_data_val - ): - y_sem_match.append(y_data_val) - y_sem[idx_y_sem_unique] = np.nanmean(y_sem_match) - # now compute sem - y_sem_x2_plot[idx_x1_unique] = np.nanstd(y_sem) / np.sqrt( - len(y_sem) - ) - - if metric == "mean_min" or metric == "mean_max": - best_val_str = str(y2_x2_plot[idx_x1_unique]) - else: - best_val_str = str(y_x2_plot[idx_x1_unique]) - - print( - x1_label - + " = " - + str(x1_unique_val) - + ", " - + x2_label - + " = " - + str(x2_unique_val) - + " (" - + best_val_str - + "): " - + model_name_x2_x1_match[idx_target] - ) - - y_plot.append(y_x2_plot) - y2_plot.append(y2_x2_plot) - y_sem_plot.append(y_sem_x2_plot) - x1_plot.append(x1_x2_plot) - x2_plot.append(x2_unique_val) - # plot - # plt.axhline - - # determine best model coordinates - best_model_x1 = None - best_model_x2 = None - best_model_y = None - if best_model_name is not None: - theorist = best_model_name.split("weights_", 1)[1].split("_v_", 1)[0] - if theorist_filter is not None: - if theorist_filter == theorist: - determine_best_model = True - else: - determine_best_model = False - else: - determine_best_model = True - - if determine_best_model: - idx = plot_dict[darts_config.csv_arch_file_name].index(best_model_name) - best_model_x1 = plot_dict[x1_name][idx] - best_model_x2 = plot_dict[x2_name][idx] - best_model_y = plot_dict[y_name][idx] - - fig, ax = pyplot.subplots(figsize=figure_dimensions) - 
- if x2_name == "": - - colors = sns.color_palette(palette, 10) - color = colors[-1] - full_label = "Reconstructed Model" + legend_label_spec - sns.lineplot( - x=x1_plot, - y=y_plot, - marker="o", - linewidth=2, - ax=ax, - label=full_label, - color=color, - ) - - # draw error bars - if y_sem_name is not None: - ax.errorbar(x=x1_plot, y=y_plot, yerr=y_sem_plot, color=color) - - # draw second y value - if metric == "mean_min" or metric == "mean_max": - full_label = "Reconstructed Model" + legend_label2_spec - ax.plot(x1_plot, y2_plot, "*", linewidth=2, label=full_label, color=color) - - if show_legend: - handles, _ = ax.get_legend_handles_labels() - ax.legend(handles=handles, loc=legend_loc) - plt.setp(ax.get_legend().get_texts(), fontsize=legend_font_size) - - # draw selected model - if best_model_x1 is not None and best_model_y is not None: - ax.plot( - best_model_x1, - best_model_y, - "o", - fillstyle="none", - color="black", - markersize=10, - ) - - ax.set_xlabel(x1_label, fontsize=axis_font_size) - ax.set_ylabel(y_label, fontsize=axis_font_size) - ax.set_title(title, fontsize=title_font_size) - - if y_limit is not None: - ax.set_ylim(y_limit) - - if x_limit is not None: - ax.set_xlim(x_limit) - - # generate legend - # ax.scatter(x1_plot, y_plot, marker='.', c='r') - # g = sns.relplot(data=data_plot, x=x1_label, y=y_label, ax=ax) - # g._legend.remove() - if y_reference is not None: - ax.axhline( - y_reference, c="black", linestyle="dashed", label=y_reference_label - ) - - if show_legend: - # generate legend - handles, _ = ax.get_legend_handles_labels() - ax.legend(handles=handles, loc=legend_loc) - plt.setp(ax.get_legend().get_texts(), fontsize=legend_font_size) - else: - - colors = sns.color_palette(palette, len(x2_plot)) - - for idx, x2 in enumerate(x2_plot): - - x1_plot_line = x1_plot[idx] - y_plot_line = y_plot[idx] - label = x2_label + "$ = " + str(x2) + "$" + legend_label_spec - color = colors[idx] - - sns.lineplot( - x=x1_plot_line, - y=y_plot_line, - 
marker="o", - linewidth=2, - ax=ax, - label=label, - color=color, - alpha=1, - ) - - # draw error bars - if y_sem_name is not None: - y_sem_plot_line = y_sem_plot[idx] - ax.errorbar( - x=x1_plot_line, - y=y_plot_line, - yerr=y_sem_plot_line, - color=color, - alpha=1, - ) - - # # draw second y value on top - # for idx, x2 in enumerate(x2_plot): - # x1_plot_line = x1_plot[idx] - # color = colors[idx] - # - # if metric == 'mean_min' or metric == 'mean_max': - # y2_plot_line = y2_plot[idx] - # label = x2_label + '$ = ' + str(x2) + "$" + legend_label2_spec - # ax.plot(x1_plot_line, y2_plot_line, '*', linewidth=2, label=label, color=color) - - # draw selected model - if best_model_x1 is not None and best_model_y is not None: - ax.plot( - best_model_x1, - best_model_y, - "o", - fillstyle="none", - color="black", - markersize=10, - ) - - for idx, x2 in enumerate(x2_plot): - if best_model_x2 == x2: - color = colors[idx] - ax.plot( - best_model_x1, - best_model_y, - "*", - linewidth=2, - label="Best Model", - color=color, - ) - - if y_reference is not None: - ax.axhline( - y_reference, c="black", linestyle="dashed", label=y_reference_label - ) - - handles, _ = ax.get_legend_handles_labels() - leg = ax.legend( - handles=handles, loc=legend_loc, bbox_to_anchor=(1.05, 1) - ) # , title='Legend' - plt.setp(ax.get_legend().get_texts(), fontsize=legend_font_size) - - if not show_legend: - leg.remove() - - if y_limit is not None: - ax.set_ylim(y_limit) - - if x_limit is not None: - ax.set_xlim(x_limit) - - sns.despine(trim=True) - ax.set_ylabel(y_label, fontsize=axis_font_size) - ax.set_xlabel(x1_label, fontsize=axis_font_size) - ax.set_title(title, fontsize=title_font_size) - plt.show() - - # save plot - if save: - if not os.path.exists(figures_path): - os.mkdir(figures_path) - fig.savefig(os.path.join(figures_path, figure_name)) - - -def plot_model_graph( - study_name: str, - arch_weights_name: str, - model_weights_name: str, - object_of_study: Object_Of_Study, - figure_name: str 
= "graph", -): - """ - Plot the graph of the DARTS model. - - Arguments: - study_name: name of the study (used to identify the relevant study folder) - arch_weights_name: name of the architecture weights file - model_weights_name: name of the model weights file (that contains the trained parameters) - object_of_study: name of the object of study - figure_name: name of the figure - """ - - import os - - import autora.theorist.darts.utils as utils - import autora.theorist.darts.visualize as viz - - figures_path = ( - aer_config.studies_folder - + study_name - + "/" - + aer_config.models_folder - + aer_config.models_results_figures_folder - ) - - model = load_model( - study_name, model_weights_name, arch_weights_name, object_of_study - ) - - (n_params_total, n_params_base, param_list) = model.countParameters( - print_parameters=True - ) - genotype = model.genotype() - filepath = os.path.join(figures_path, figure_name) - viz.plot( - genotype.normal, - filepath, - file_format="png", - view_file=True, - full_label=True, - param_list=param_list, - input_labels=object_of_study.__get_input_labels__(), - out_dim=object_of_study.__get_output_dim__(), - out_fnc=utils.get_output_str(object_of_study.__get_output_type__()), - ) - - -# old - - -def load_model( - study_name: str, - model_weights_name: str, - arch_weights_name: str, - object_of_study: Object_Of_Study, -) -> torch.nn.Module: - """ - Load the model. 
- - Arguments: - study_name: name of the study (used to identify the relevant study folder) - model_weights_name: name of the model weights file (that contains the trained parameters) - arch_weights_name: name of the architecture weights file - object_of_study: name of the object of study - - Returns: - model: DARTS model - """ - - import os - - import torch - - import autora.theorist.darts.utils as utils - from autora.theorist.darts.model_search import Network - - num_output = object_of_study.__get_output_dim__() - num_input = object_of_study.__get_input_dim__() - k = int(float(arch_weights_name.split("_k_", 1)[1].split("_s_", 1)[0])) - - results_weights_path = ( - aer_config.studies_folder - + study_name - + "/" - + aer_config.models_folder - + aer_config.models_results_weights_folder - ) - - model_path = os.path.join(results_weights_path, model_weights_name + ".pt") - arch_path = os.path.join(results_weights_path, arch_weights_name + ".pt") - criterion = utils.sigmid_mse - model = Network(num_output, criterion, steps=k, n_input_states=num_input) - utils.load(model, model_path) - alphas_normal = torch.load(arch_path) - model.fix_architecture(True, new_weights=alphas_normal) - - return model - - -class DebugWindow: - """ - A window with plots that are used for debugging. - """ - - def __init__( - self, - num_epochs: int, - numArchEdges: int = 1, - numArchOps: int = 1, - ArchOpsLabels: typing.Tuple = (), - fitPlot3D: bool = False, - show_arch_weights: bool = True, - ): - """ - Initializes the debug window. 
- - Arguments: - num_epochs: number of architecture training epochs - numArchEdges: number of architecture edges - numArchOps: number of architecture operations - ArchOpsLabels: list of architecture operation labels - fitPlot3D: if True, the 3D plot of the fit is shown - show_arch_weights: if True, the architecture weights are shown - """ - - # initialization - matplotlib.use("TkAgg") # need to add this for PyCharm environment - - plt.ion() - - # SETTINGS - self.show_arch_weights = show_arch_weights - self.fontSize = 10 - - self.performancePlot_limit = (0, 1) - self.modelFitPlot_limit = (0, 500) - self.mismatchPlot_limit = (0, 1) - self.architectureWeightsPlot_limit = (0.1, 0.2) - - self.numPatternsShown = 100 - - # FIGURE - self.fig = plt.figure() - self.fig.set_size_inches(13, 7) - - if self.show_arch_weights is False: - numArchEdges = 0 - - # set up grid - numRows = np.max((1 + np.ceil((numArchEdges + 1) / 4), 2)) - gs = GridSpec(numRows.astype(int), 4, figure=self.fig) - - self.fig.subplots_adjust( - left=0.1, bottom=0.1, right=0.90, top=0.9, wspace=0.4, hspace=0.5 - ) - self.modelGraph = self.fig.add_subplot(gs[1, 0]) - self.performancePlot = self.fig.add_subplot(gs[0, 0]) - self.modelFitPlot = self.fig.add_subplot(gs[0, 1]) - if fitPlot3D: - self.mismatchPlot = self.fig.add_subplot(gs[0, 2], projection="3d") - else: - self.mismatchPlot = self.fig.add_subplot(gs[0, 2]) - self.examplePatternsPlot = self.fig.add_subplot(gs[0, 3]) - - self.architecturePlot = [] - - for edge in range(numArchEdges): - row = np.ceil((edge + 2) / 4).astype(int) - col = (edge + 1) % 4 - self.architecturePlot.append(self.fig.add_subplot(gs[row, col])) - - self.colors = ( - "black", - "red", - "green", - "blue", - "purple", - "orange", - "brown", - "pink", - "grey", - "olive", - "cyan", - "yellow", - "skyblue", - "coral", - "magenta", - "seagreen", - "sandybrown", - ) - - # PERFORMANCE PLOT - x = 1 - y = 1 - (self.train_error,) = self.performancePlot.plot(x, y, "r-") - 
(self.valid_error,) = self.performancePlot.plot(x, y, "b", linestyle="dashed") - - # set labels - self.performancePlot.set_xlabel("Epoch", fontsize=self.fontSize) - self.performancePlot.set_ylabel("Cross-Entropy Loss", fontsize=self.fontSize) - self.performancePlot.set_title("Performance", fontsize=self.fontSize) - self.performancePlot.legend( - (self.train_error, self.valid_error), ("training error", "validation error") - ) - - # adjust axes - self.performancePlot.set_xlim(0, num_epochs) - self.performancePlot.set_ylim( - self.performancePlot_limit[0], self.performancePlot_limit[1] - ) - - # MODEL FIT PLOT - x = 1 - y = 1 - (self.BIC,) = self.modelFitPlot.plot(x, y, color="black") - (self.AIC,) = self.modelFitPlot.plot(x, y, color="grey") - - # set labels - self.modelFitPlot.set_xlabel("Epoch", fontsize=self.fontSize) - self.modelFitPlot.set_ylabel("Information Criterion", fontsize=self.fontSize) - self.modelFitPlot.set_title("Model Fit", fontsize=self.fontSize) - self.modelFitPlot.legend((self.BIC, self.AIC), ("BIC", "AIC")) - - # adjust axes - self.modelFitPlot.set_xlim(0, num_epochs) - self.modelFitPlot.set_ylim( - self.modelFitPlot_limit[0], self.modelFitPlot_limit[1] - ) - - # RANGE PREDICTION FIT PLOT - x = 1 - y = 1 - if fitPlot3D: - x = np.arange(0, 1, 0.1) - y = np.arange(0, 1, 0.1) - X, Y = np.meshgrid(x, y) - Z = X * np.exp(-X - Y) - - self.range_target = self.mismatchPlot.plot_surface(X, Y, Z) - self.range_prediction = self.mismatchPlot.plot_surface(X, Y, Z) - self.mismatchPlot.set_zlim( - self.mismatchPlot_limit[0], self.mismatchPlot_limit[1] - ) - - # set labels - self.mismatchPlot.set_xlabel("Stimulus 1", fontsize=self.fontSize) - self.mismatchPlot.set_ylabel("Stimulus 2", fontsize=self.fontSize) - self.mismatchPlot.set_zlabel("Outcome Value", fontsize=self.fontSize) - - else: - (self.range_target,) = self.mismatchPlot.plot(x, y, color="black") - (self.range_prediction,) = self.mismatchPlot.plot(x, y, "--", color="red") - - # set labels - 
self.mismatchPlot.set_xlabel("Stimulus Value", fontsize=self.fontSize) - self.mismatchPlot.set_ylabel("Outcome Value", fontsize=self.fontSize) - self.mismatchPlot.legend( - (self.range_target, self.range_prediction), ("target", "prediction") - ) - - self.mismatchPlot.set_title("Target vs. Prediction", fontsize=self.fontSize) - - # adjust axes - self.mismatchPlot.set_xlim(0, 1) - self.mismatchPlot.set_ylim(0, 1) - - # ARCHITECTURE WEIGHT PLOT - if self.show_arch_weights: - - self.architectureWeights = [] - for idx, architecturePlot in enumerate(self.architecturePlot): - plotWeights = [] - x = 1 - y = 1 - for op in range(numArchOps): - (plotWeight,) = architecturePlot.plot(x, y, color=self.colors[op]) - plotWeights.append(plotWeight) - - # set legend - if idx == 0: - architecturePlot.legend( - plotWeights, ArchOpsLabels, prop={"size": 6} - ) - - # add labels - architecturePlot.set_ylabel("Weight", fontsize=self.fontSize) - architecturePlot.set_title( - "(" + str(idx) + ") Edge Weight", fontsize=self.fontSize - ) - if idx == len(self.architecturePlot) - 1: - architecturePlot.set_xlabel("Epoch", fontsize=self.fontSize) - - # adjust axes - architecturePlot.set_xlim(0, num_epochs) - architecturePlot.set_ylim( - self.architectureWeightsPlot_limit[0], - self.architectureWeightsPlot_limit[1], - ) - - self.architectureWeights.append(plotWeights) - - # draw - plt.draw() - - def update( - self, - train_error: Optional[np.array] = None, - valid_error: Optional[np.array] = None, - weights: Optional[np.array] = None, - BIC: Optional[np.array] = None, - AIC: Optional[np.array] = None, - model_graph: Optional[str] = None, - range_input1: Optional[np.array] = None, - range_input2: Optional[np.array] = None, - range_target: Optional[np.array] = None, - range_prediction: Optional[np.array] = None, - target: Optional[np.array] = None, - prediction: Optional[np.array] = None, - ): - """ - Update the debug plot with new data. 
- - Arguments: - train_error: training error - valid_error: validation error - weights: weights of the model - BIC: Bayesian information criterion of the model - AIC: Akaike information criterion of the model - model_graph: the graph of the model - range_input1: the range of the first input - range_input2: the range of the second input - range_target: the range of the target - range_prediction: the range of the prediction - target: the target - prediction: the prediction - """ - - # update training error - if train_error is not None: - self.train_error.set_xdata( - np.linspace(1, len(train_error), len(train_error)) - ) - self.train_error.set_ydata(train_error) - - # update validation error - if valid_error is not None: - self.valid_error.set_xdata( - np.linspace(1, len(valid_error), len(valid_error)) - ) - self.valid_error.set_ydata(valid_error) - - # update BIC - if BIC is not None: - self.BIC.set_xdata(np.linspace(1, len(BIC), len(BIC))) - self.BIC.set_ydata(BIC) - - # update AIC - if AIC is not None: - self.AIC.set_xdata(np.linspace(1, len(AIC), len(AIC))) - self.AIC.set_ydata(AIC) - - # update target vs. 
prediction plot - if ( - range_input1 is not None - and range_target is not None - and range_prediction is not None - and range_input2 is None - ): - self.range_target.set_xdata(range_input1) - self.range_target.set_ydata(range_target) - self.range_prediction.set_xdata(range_input1) - self.range_prediction.set_ydata(range_prediction) - elif ( - range_input1 is not None - and range_target is not None - and range_prediction is not None - and range_input2 is not None - ): - - # update plot - self.mismatchPlot.cla() - self.range_target = self.mismatchPlot.plot_surface( - range_input1, range_input2, range_target, color=(0, 0, 0, 0.5) - ) - self.range_prediction = self.mismatchPlot.plot_surface( - range_input1, range_input2, range_prediction, color=(1, 0, 0, 0.5) - ) - - # set labels - self.mismatchPlot.set_xlabel("Stimulus 1", fontsize=self.fontSize) - self.mismatchPlot.set_ylabel("Stimulus 2", fontsize=self.fontSize) - self.mismatchPlot.set_zlabel("Outcome Value", fontsize=self.fontSize) - self.mismatchPlot.set_title("Target vs. Prediction", fontsize=self.fontSize) - - # update example pattern plot - if target is not None and prediction is not None: - - # select limited number of patterns - self.numPatternsShown = np.min((self.numPatternsShown, target.shape[0])) - target = target[0 : self.numPatternsShown, :] - prediction = prediction[0 : self.numPatternsShown, :] - - im = np.concatenate((target, prediction), axis=1) - self.examplePatternsPlot.cla() - self.examplePatternsPlot.imshow(im, interpolation="nearest", aspect="auto") - x = np.ones(target.shape[0]) * (target.shape[1] - 0.5) - y = np.linspace(1, target.shape[0], target.shape[0]) - self.examplePatternsPlot.plot(x, y, color="red") - - # set labels - self.examplePatternsPlot.set_xlabel("Output", fontsize=self.fontSize) - self.examplePatternsPlot.set_ylabel("Pattern", fontsize=self.fontSize) - self.examplePatternsPlot.set_title( - "Target vs. 
Prediction", fontsize=self.fontSize - ) - - if self.show_arch_weights: - # update weights - if weights is not None: - for plotIdx, architectureWeights in enumerate(self.architectureWeights): - for lineIdx, plotWeight in enumerate(architectureWeights): - plotWeight.set_xdata( - np.linspace(1, weights.shape[0], weights.shape[0]) - ) - plotWeight.set_ydata(weights[:, plotIdx, lineIdx]) - - # draw current graph - if model_graph is not None: - im = imageio.imread(model_graph) - self.modelGraph.cla() - self.modelGraph.imshow(im) - self.modelGraph.axis("off") - - # re-draw plot - plt.draw() - plt.pause(0.02) diff --git a/poetry.lock b/poetry.lock index 25083ab1a..f1b64e806 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "anyio" @@ -313,7 +313,7 @@ unicode-backport = ["unicodedata2"] name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1312,6 +1312,31 @@ importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} [package.extras] testing = ["coverage", "pyyaml"] +[[package]] +name = "markdown-it-py" +version = "2.2.0" +description = "Python port of markdown-it. Markdown parsing, done right!" 
+category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "markdown-it-py-2.2.0.tar.gz", hash = "sha256:7c9a5e412688bc771c67432cbfebcdd686c93ce6484913dccf06cb5a0bea35a1"}, + {file = "markdown_it_py-2.2.0-py3-none-any.whl", hash = "sha256:5a35f8d1870171d9acc47b99612dc146129b631baf04970128b568f190d0cc30"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["attrs", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "markupsafe" version = "2.1.1" @@ -1452,6 +1477,18 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "mergedeep" version = "1.3.4" @@ -2349,7 +2386,7 @@ files = [ name = "pygments" version = "2.14.0" description = "Pygments is a syntax highlighting package written in Python." 
-category = "dev" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2530,7 +2567,7 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2870,6 +2907,26 @@ files = [ {file = "rfc3986_validator-0.1.1.tar.gz", hash = "sha256:3d44bde7921b3b9ec3ae4e3adca370438eccebc676456449b145d533b240d055"}, ] +[[package]] +name = "rich" +version = "13.3.3" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "rich-13.3.3-py3-none-any.whl", hash = "sha256:540c7d6d26a1178e8e8b37e9ba44573a3cd1464ff6348b99ee7061b95d1c6333"}, + {file = "rich-13.3.3.tar.gz", hash = "sha256:dc84400a9d842b3a9c5ff74addd8eb798d155f36c1c91303888e0a66850d2a15"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0,<3.0.0" +pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "scikit-learn" version = "1.2.2" @@ -3008,6 +3065,18 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-g testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +[[package]] +name 
= "shellingham" +version = "1.5.0.post1" +description = "Tool to Detect Surrounding Shell" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "shellingham-1.5.0.post1-py2.py3-none-any.whl", hash = "sha256:368bf8c00754fd4f55afb7bbb86e272df77e4dc76ac29dbcbb81a59e9fc15744"}, + {file = "shellingham-1.5.0.post1.tar.gz", hash = "sha256:823bc5fb5c34d60f285b624e7264f4dda254bc803a3774a147bf99c0e3004a28"}, +] + [[package]] name = "six" version = "1.16.0" @@ -3255,6 +3324,27 @@ lint = ["black (>=22.6.0)", "mdformat (>0.7)", "ruff (>=0.0.156)"] test = ["pre-commit", "pytest"] typing = ["mypy (>=0.990)"] +[[package]] +name = "typer" +version = "0.7.0" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "typer-0.7.0-py3-none-any.whl", hash = "sha256:b5e704f4e48ec263de1c0b3a2387cd405a13767d2f907f44c1a08cbad96f606d"}, + {file = "typer-0.7.0.tar.gz", hash = "sha256:ff797846578a9f2a201b53442aedeb543319466870fbe1c701eab66dd7681165"}, +] + +[package.dependencies] +click = ">=7.1.1,<9.0.0" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] + [[package]] name = "typing-extensions" version = "4.4.0" @@ -3463,4 +3553,4 @@ tinkerforge = ["tinkerforge"] [metadata] lock-version = "2.0" python-versions = ">=3.8.10,<3.11" -content-hash = 
"462fb631f876e5145435416ab989b93a1b71cef94cb8d1afbf3e88762719824d" +content-hash = "cc78e9181360d93680975fa7f6918744d64267fff2ee382872c023e2b2ceca54" diff --git a/pyproject.toml b/pyproject.toml index 90b4895cc..18bd166cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,12 @@ sympy = "^1.10.1" tinkerforge = {version = "^2.1.25", optional=true} torch = "1.13.1" tqdm = "^4.64.0" +typer = "^0.7.0" +rich = "^13.3.3" +shellingham = "^1.5.0.post1" +pyyaml = "^6.0" + + [tool.poetry.extras] tinkerforge = ["tinkerforge"] diff --git a/tests/cli/controller/basic-usage/.gitignore b/tests/cli/controller/basic-usage/.gitignore new file mode 100644 index 000000000..c375d5ba0 --- /dev/null +++ b/tests/cli/controller/basic-usage/.gitignore @@ -0,0 +1 @@ +history diff --git a/tests/cli/controller/basic-usage/controller.yml b/tests/cli/controller/basic-usage/controller.yml new file mode 100644 index 000000000..b0e9e4c32 --- /dev/null +++ b/tests/cli/controller/basic-usage/controller.yml @@ -0,0 +1,17 @@ +!!python/object:autora.controller.controller.Controller + _experiment_runner_callable: null + _experimentalist_pipeline: null + _theorist_estimator: null + executor_collection: + experiment_runner: &id001 !!python/name:autora.controller.executor.no_op '' + experimentalist: *id001 + theorist: *id001 + monitor: null + planner: !!python/name:autora.controller.planner.last_result_kind_planner '' + state: !!python/object:autora.controller.state.history.History + _history: + - !!python/object:autora.controller.state.history.Result + data: {} + kind: !!python/object/apply:autora.controller.protocol.ResultKind + - PARAMS + diff --git a/tests/cli/controller/custom-function-with-cylc-slurm/README.md b/tests/cli/controller/custom-function-with-cylc-slurm/README.md new file mode 100644 index 000000000..b4f0b25d0 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc-slurm/README.md @@ -0,0 +1,16 @@ +# Example of running the controller CLI under cylc + +Requires a conda 
environment called `autora-cylc` with the following dependencies: +- `autora` 3.0.0a0+ +- `cylc-flow` + +```bash +conda activate /users/jholla10/anaconda/autora-cylc +``` + +Run and show output from this directory using +```bash +cylc install . +cylc play custom-function-with-cylc-slurm +cylc tui custom-function-with-cylc-slurm +``` diff --git a/tests/cli/controller/custom-function-with-cylc-slurm/controller.yml b/tests/cli/controller/custom-function-with-cylc-slurm/controller.yml new file mode 100644 index 000000000..feff9f3d1 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc-slurm/controller.yml @@ -0,0 +1,41 @@ +!!python/object:autora.controller.controller.Controller +_experiment_runner_callable: &id002 !!python/name:func.plus_1 '' +_experimentalist_pipeline: &id004 !!python/object:autora.experimentalist.pipeline.Pipeline + params: {} + steps: + - !!python/tuple + - ten_random_samples + - !!python/name:func.ten_random_samples '' +_theorist_estimator: &id006 !!python/object:sklearn.linear_model._base.LinearRegression + _sklearn_version: 1.2.2 + copy_X: true + fit_intercept: true + n_jobs: null + positive: false +executor_collection: + experiment_runner: !!python/object/apply:functools.partial + args: + - &id001 !!python/name:autora.controller.executor.experiment_runner_wrapper '' + state: !!python/tuple + - *id001 + - !!python/tuple [] + - callable: *id002 + - null + experimentalist: !!python/object/apply:functools.partial + args: + - &id003 !!python/name:autora.controller.executor.experimentalist_wrapper '' + state: !!python/tuple + - *id003 + - !!python/tuple [] + - pipeline: *id004 + - null + theorist: !!python/object/apply:functools.partial + args: + - &id005 !!python/name:autora.controller.executor.theorist_wrapper '' + state: !!python/tuple + - *id005 + - !!python/tuple [] + - estimator: *id006 + - null +monitor: null +planner: !!python/name:autora.controller.planner.last_result_kind_planner '' diff --git 
a/tests/cli/controller/custom-function-with-cylc-slurm/environment.yml b/tests/cli/controller/custom-function-with-cylc-slurm/environment.yml new file mode 100644 index 000000000..d41c6fac8 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc-slurm/environment.yml @@ -0,0 +1,244 @@ +name: /users/jholla10/anaconda/autora-cylc +channels: + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - aiofiles=0.7.0=pyhd8ed1ab_0 + - alembic=1.10.3=pyhd8ed1ab_0 + - aniso8601=7.0.0=py_0 + - ansimarkup=1.5.0=pyh44b312d_0 + - anyio=3.6.2=pyhd8ed1ab_0 + - argon2-cffi=21.3.0=pyhd8ed1ab_0 + - argon2-cffi-bindings=21.2.0=py39hb9d737c_3 + - async-timeout=4.0.2=pyhd8ed1ab_0 + - async_generator=1.10=py_0 + - atk-1.0=2.38.0=hd4edc92_1 + - attrs=22.2.0=pyh71513ae_0 + - beautifulsoup4=4.12.1=pyha770c72_0 + - bleach=6.0.0=pyhd8ed1ab_0 + - blinker=1.6=pyhd8ed1ab_0 + - brotli=1.0.9=h166bdaf_8 + - brotli-bin=1.0.9=h166bdaf_8 + - brotlipy=0.7.0=py39hb9d737c_1005 + - bzip2=1.0.8=h7f98852_4 + - c-ares=1.18.1=h7f98852_0 + - ca-certificates=2022.12.7=ha878542_0 + - cairo=1.16.0=ha61ee94_1014 + - certifi=2022.12.7=pyhd8ed1ab_0 + - certipy=0.1.3=py_0 + - cffi=1.15.1=py39he91dace_3 + - charset-normalizer=3.1.0=pyhd8ed1ab_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - configurable-http-proxy=4.5.4=he2f69ee_2 + - contourpy=1.0.7=py39h4b4f3f3_0 + - cryptography=40.0.1=py39h079d5ae_0 + - cycler=0.11.0=pyhd8ed1ab_0 + - cylc-flow=8.1.2=pyhb6b8b6f_1 + - cylc-flow-base=8.1.2=pyhd8ed1ab_1 + - cylc-uiserver=1.2.1=pyhd8ed1ab_0 + - cylc-uiserver-base=1.2.1=pyhd8ed1ab_0 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - empy=3.3.4=pyh9f0ad1d_1 + - entrypoints=0.4=pyhd8ed1ab_0 + - expat=2.5.0=hcb278e6_1 + - flit-core=3.8.0=pyhd8ed1ab_0 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=hab24e00_0 + - fontconfig=2.14.2=h14ed4e7_0 + - fonts-conda-ecosystem=1=0 + - 
fonts-conda-forge=1=0 + - fonttools=4.39.3=py39h72bdee0_0 + - freetype=2.12.1=hca18f0e_1 + - fribidi=1.0.10=h36c2ea0_0 + - gdk-pixbuf=2.42.10=h05c8ddd_0 + - gettext=0.21.1=h27087fc_0 + - gh=2.25.1=ha8f183a_0 + - giflib=5.2.1=h0b41bf4_3 + - graphene=2.1.9=pyhd8ed1ab_0 + - graphene-tornado=2.6.1=py_0 + - graphite2=1.3.13=h58526e2_1001 + - graphql-core=2.3.2=pyh9f0ad1d_0 + - graphql-relay=2.0.1=py_0 + - graphql-ws=0.4.4=pyhd8ed1ab_0 + - graphviz=7.1.0=h2e5815a_0 + - greenlet=2.0.2=py39h227be39_0 + - gtk2=2.24.33=h90689f9_2 + - gts=0.7.6=h64030ff_2 + - harfbuzz=6.0.0=h8e241bc_0 + - icu=70.1=h27087fc_0 + - idna=3.4=pyhd8ed1ab_0 + - importlib-metadata=6.1.0=pyha770c72_0 + - importlib-resources=5.12.0=pyhd8ed1ab_0 + - importlib_resources=5.12.0=pyhd8ed1ab_0 + - jinja2=3.0.3=pyhd8ed1ab_0 + - jpeg=9e=h0b41bf4_3 + - jsonschema=4.17.3=pyhd8ed1ab_0 + - jupyter_client=7.3.1=pyhd8ed1ab_0 + - jupyter_core=5.3.0=py39hf3d152e_0 + - jupyter_server=1.23.6=pyhd8ed1ab_0 + - jupyter_telemetry=0.1.0=pyhd8ed1ab_1 + - jupyterhub=3.1.1=pyh2a2186d_0 + - jupyterhub-base=3.1.1=pyh2a2186d_0 + - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0 + - keyutils=1.6.1=h166bdaf_0 + - kiwisolver=1.4.4=py39hf939315_1 + - krb5=1.20.1=h81ceb04_0 + - lcms2=2.15=hfd0df8a_0 + - ld_impl_linux-64=2.40=h41732ed_0 + - lerc=4.0.0=h27087fc_0 + - libblas=3.9.0=16_linux64_openblas + - libbrotlicommon=1.0.9=h166bdaf_8 + - libbrotlidec=1.0.9=h166bdaf_8 + - libbrotlienc=1.0.9=h166bdaf_8 + - libcblas=3.9.0=16_linux64_openblas + - libcurl=7.88.1=hdc1c0ab_1 + - libdeflate=1.17=h0b41bf4_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=h516909a_1 + - libexpat=2.5.0=hcb278e6_1 + - libffi=3.4.2=h7f98852_5 + - libgcc-ng=12.2.0=h65d4601_19 + - libgd=2.3.3=h5aea950_4 + - libgfortran-ng=12.2.0=h69a702a_19 + - libgfortran5=12.2.0=h337968e_19 + - libglib=2.74.1=h606061b_1 + - libgomp=12.2.0=h65d4601_19 + - libiconv=1.17=h166bdaf_0 + - liblapack=3.9.0=16_linux64_openblas + - libnghttp2=1.52.0=h61bc06f_0 + - libnsl=2.0.0=h7f98852_0 + - 
libopenblas=0.3.21=pthreads_h78a6416_3 + - libpng=1.6.39=h753d276_0 + - libprotobuf=3.21.12=h3eb15da_0 + - librsvg=2.54.4=h7abd40a_0 + - libsodium=1.0.18=h36c2ea0_1 + - libsqlite=3.40.0=h753d276_0 + - libssh2=1.10.0=hf14f497_3 + - libstdcxx-ng=12.2.0=h46fd767_19 + - libtiff=4.5.0=h6adf6a1_2 + - libtool=2.4.7=h27087fc_0 + - libuuid=2.38.1=h0b41bf4_0 + - libuv=1.44.2=h166bdaf_0 + - libwebp=1.2.4=h1daa5a0_1 + - libwebp-base=1.2.4=h166bdaf_0 + - libxcb=1.13=h7f98852_1004 + - libxml2=2.10.3=hca2bb57_4 + - libzlib=1.2.13=h166bdaf_4 + - mako=1.2.4=pyhd8ed1ab_0 + - markupsafe=2.1.2=py39h72bdee0_0 + - matplotlib-base=3.7.1=py39he190548_0 + - metomi-isodatetime=1!3.0.0=pyhd8ed1ab_0 + - mistune=2.0.5=pyhd8ed1ab_0 + - munkres=1.1.4=pyh9f0ad1d_0 + - nbclient=0.7.3=pyhd8ed1ab_0 + - nbconvert-core=7.3.0=pyhd8ed1ab_2 + - nbformat=5.8.0=pyhd8ed1ab_0 + - ncurses=6.3=h27087fc_1 + - nest-asyncio=1.5.6=pyhd8ed1ab_0 + - nodejs=18.15.0=h8d033a5_0 + - numpy=1.24.2=py39h7360e5f_0 + - oauthlib=3.2.2=pyhd8ed1ab_0 + - openjpeg=2.5.0=hfec8fc6_2 + - openssl=3.1.0=h0b41bf4_0 + - packaging=23.0=pyhd8ed1ab_0 + - pamela=1.0.0=py_0 + - pandas=1.5.3=py39h2ad29b5_1 + - pandocfilters=1.5.0=pyhd8ed1ab_0 + - pango=1.50.14=hd33c08f_0 + - pcre2=10.40=hc3806b6_0 + - pillow=9.4.0=py39h2320bf1_1 + - pip=23.0.1=pyhd8ed1ab_0 + - pixman=0.40.0=h36c2ea0_0 + - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0 + - platformdirs=3.2.0=pyhd8ed1ab_0 + - prometheus_client=0.16.0=pyhd8ed1ab_0 + - promise=2.3=py39hf3d152e_7 + - protobuf=4.21.12=py39h227be39_0 + - psutil=5.9.4=py39hb9d737c_0 + - pthread-stubs=0.4=h36c2ea0_1001 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pycparser=2.21=pyhd8ed1ab_0 + - pycurl=7.45.1=py39h9297c8b_3 + - pygments=2.14.0=pyhd8ed1ab_0 + - pyjwt=2.6.0=pyhd8ed1ab_0 + - pyopenssl=23.1.1=pyhd8ed1ab_0 + - pyparsing=3.0.9=pyhd8ed1ab_0 + - pyrsistent=0.19.3=py39h72bdee0_0 + - pysocks=1.7.1=pyha2e5f31_6 + - python=3.9.16=h2782a2a_0_cpython + - python-dateutil=2.8.2=pyhd8ed1ab_0 + - 
python-fastjsonschema=2.16.3=pyhd8ed1ab_0 + - python-json-logger=2.0.7=pyhd8ed1ab_0 + - python_abi=3.9=3_cp39 + - pytz=2023.3=pyhd8ed1ab_0 + - pyzmq=22.3.0=py39headdf64_2 + - readline=8.2=h8228510_1 + - requests=2.28.2=pyhd8ed1ab_1 + - ruamel.yaml=0.17.21=py39h72bdee0_3 + - ruamel.yaml.clib=0.2.7=py39h72bdee0_1 + - rx=1.6.1=py_0 + - send2trash=1.8.0=pyhd8ed1ab_0 + - setuptools=66.1.1=pyhd8ed1ab_0 + - six=1.16.0=pyh6c4a22f_0 + - sniffio=1.3.0=pyhd8ed1ab_0 + - soupsieve=2.3.2.post1=pyhd8ed1ab_0 + - sqlalchemy=2.0.9=py39h72bdee0_0 + - terminado=0.17.1=pyh41d4057_0 + - tinycss2=1.2.1=pyhd8ed1ab_0 + - tk=8.6.12=h27826a3_0 + - tomli=2.0.1=pyhd8ed1ab_0 + - tornado=6.2=py39hb9d737c_1 + - traitlets=5.9.0=pyhd8ed1ab_0 + - typing-extensions=4.5.0=hd8ed1ab_0 + - typing_extensions=4.5.0=pyha770c72_0 + - tzdata=2023c=h71feb2d_0 + - unicodedata2=15.0.0=py39hb9d737c_0 + - urllib3=1.26.15=pyhd8ed1ab_0 + - urwid=2.1.2=py39hb9d737c_7 + - webencodings=0.5.1=py_1 + - websocket-client=1.5.1=pyhd8ed1ab_0 + - werkzeug=0.12.2=py_1 + - wheel=0.40.0=pyhd8ed1ab_0 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.0.10=h7f98852_0 + - xorg-libsm=1.2.3=hd9c2040_1000 + - xorg-libx11=1.8.4=h0b41bf4_0 + - xorg-libxau=1.0.9=h7f98852_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxrender=0.9.10=h7f98852_1003 + - xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h0b41bf4_1003 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.2.6=h166bdaf_0 + - zeromq=4.3.4=h9c3ff4c_1 + - zipp=3.15.0=pyhd8ed1ab_0 + - zlib=1.2.13=h166bdaf_4 + - zstd=1.5.2=h3eb15da_6 + - pip: + - autora==0.0.0 + - imageio==2.27.0 + - joblib==1.2.0 + - markdown-it-py==2.2.0 + - mdurl==0.1.2 + - mpmath==1.3.0 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cudnn-cu11==8.5.0.96 + - python-graphviz==0.20.1 + - rich==13.3.4 + - scikit-learn==1.2.2 + - scipy==1.10.1 + - seaborn==0.12.2 + - shellingham==1.5.0.post1 + - 
sympy==1.11.1 + - threadpoolctl==3.1.0 + - torch==1.13.1 + - tqdm==4.65.0 + - typer==0.7.0 +prefix: /users/jholla10/anaconda/autora-cylc diff --git a/tests/cli/controller/custom-function-with-cylc-slurm/flow.cylc b/tests/cli/controller/custom-function-with-cylc-slurm/flow.cylc new file mode 100644 index 000000000..409a57540 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc-slurm/flow.cylc @@ -0,0 +1,28 @@ +[runtime] + [[setup]] + script = """ + cp $CYLC_WORKFLOW_RUN_DIR/controller.yml $CYLC_WORKFLOW_SHARE_DIR; + cp -R $CYLC_WORKFLOW_RUN_DIR/history $CYLC_WORKFLOW_SHARE_DIR/history + conda create -p $CYLC_WORKFLOW_SHARE_DIR/env --clone /users/jholla10/anaconda/autora-cylc + """ + [[experimentalist]] + script = $CYLC_WORKFLOW_SHARE_DIR/env/bin/python -m autora.controller "$CYLC_WORKFLOW_SHARE_DIR/controller.yml" "$CYLC_WORKFLOW_SHARE_DIR/history/." --step-name=experimentalist --verbose --debug + platform = oscar + [[experiment_runner]] + script = $CYLC_WORKFLOW_SHARE_DIR/env/bin/python -m autora.controller "$CYLC_WORKFLOW_SHARE_DIR/controller.yml" "$CYLC_WORKFLOW_SHARE_DIR/history/." --step-name=experiment_runner --verbose --debug + platform = oscar + [[theorist]] + script = $CYLC_WORKFLOW_SHARE_DIR/env/bin/python -m autora.controller "$CYLC_WORKFLOW_SHARE_DIR/controller.yml" "$CYLC_WORKFLOW_SHARE_DIR/history/."
--step-name=theorist --verbose --debug + platform = oscar + +[scheduling] + cycling mode = integer + initial cycle point = 1 + [[graph]] + R1 = """ + setup => experimentalist + """ + P1 = """ + experimentalist => experiment_runner => theorist + theorist[-P1] => experimentalist + """ diff --git a/tests/cli/controller/custom-function-with-cylc-slurm/global.cylc b/tests/cli/controller/custom-function-with-cylc-slurm/global.cylc new file mode 100644 index 000000000..715141f2b --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc-slurm/global.cylc @@ -0,0 +1,8 @@ +[platforms] + [[oscar]] + hosts = localhost + install target = localhost + job runner = slurm + retrieve job logs = True + cylc path = /users/jholla10/anaconda/cylc/bin/cylc + diff --git a/tests/cli/controller/custom-function-with-cylc-slurm/history/00000000-METADATA.yaml b/tests/cli/controller/custom-function-with-cylc-slurm/history/00000000-METADATA.yaml new file mode 100644 index 000000000..c7cfd0006 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc-slurm/history/00000000-METADATA.yaml @@ -0,0 +1,27 @@ +!!python/object:autora.controller.state.history.Result +data: !!python/object:autora.variable.VariableCollection + covariates: [] + dependent_variables: + - !!python/object:autora.variable.Variable + allowed_values: null + is_covariate: false + name: y + rescale: 1 + type: !!python/object/apply:autora.variable.ValueType + - real + units: '' + value_range: null + variable_label: '' + independent_variables: + - !!python/object:autora.variable.Variable + allowed_values: null + is_covariate: false + name: x + rescale: 1 + type: !!python/object/apply:autora.variable.ValueType + - real + units: '' + value_range: null + variable_label: '' +kind: !!python/object/apply:autora.controller.protocol.ResultKind +- METADATA diff --git a/tests/cli/controller/custom-function-with-cylc-slurm/history/00000001-PARAMS.yaml 
b/tests/cli/controller/custom-function-with-cylc-slurm/history/00000001-PARAMS.yaml new file mode 100644 index 000000000..3217172d8 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc-slurm/history/00000001-PARAMS.yaml @@ -0,0 +1,4 @@ +!!python/object:autora.controller.state.history.Result +data: {} +kind: !!python/object/apply:autora.controller.protocol.ResultKind +- PARAMS diff --git a/tests/cli/controller/custom-function-with-cylc-slurm/lib/python/func.py b/tests/cli/controller/custom-function-with-cylc-slurm/lib/python/func.py new file mode 100644 index 000000000..428e34513 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc-slurm/lib/python/func.py @@ -0,0 +1,16 @@ +import numpy as np +import pandas as pd + + +def plus_1(df: pd.DataFrame): + x = df[["x"]] + y = x + 1 + df["y"] = y + return df + + +def ten_random_samples(random_state: int = 180): + rng = np.random.default_rng(random_state) + values = rng.uniform(0, 10, 10) + df = pd.DataFrame({"x": values}) + return df diff --git a/tests/cli/controller/custom-function-with-cylc/README.md b/tests/cli/controller/custom-function-with-cylc/README.md new file mode 100644 index 000000000..20c8f36f9 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc/README.md @@ -0,0 +1,12 @@ +# Example of running the controller CLI under cylc + +Requires a conda environment called `autora-cylc` with the following dependencies: +- `autora` 3.0.0a0+ +- `cylc-flow` + +Run and show output from this directory using +```zsh +cylc install . 
+cylc play custom-function-with-cylc +cylc tui custom-function-with-cylc +``` diff --git a/tests/cli/controller/custom-function-with-cylc/controller.yml b/tests/cli/controller/custom-function-with-cylc/controller.yml new file mode 100644 index 000000000..feff9f3d1 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc/controller.yml @@ -0,0 +1,41 @@ +!!python/object:autora.controller.controller.Controller +_experiment_runner_callable: &id002 !!python/name:func.plus_1 '' +_experimentalist_pipeline: &id004 !!python/object:autora.experimentalist.pipeline.Pipeline + params: {} + steps: + - !!python/tuple + - ten_random_samples + - !!python/name:func.ten_random_samples '' +_theorist_estimator: &id006 !!python/object:sklearn.linear_model._base.LinearRegression + _sklearn_version: 1.2.2 + copy_X: true + fit_intercept: true + n_jobs: null + positive: false +executor_collection: + experiment_runner: !!python/object/apply:functools.partial + args: + - &id001 !!python/name:autora.controller.executor.experiment_runner_wrapper '' + state: !!python/tuple + - *id001 + - !!python/tuple [] + - callable: *id002 + - null + experimentalist: !!python/object/apply:functools.partial + args: + - &id003 !!python/name:autora.controller.executor.experimentalist_wrapper '' + state: !!python/tuple + - *id003 + - !!python/tuple [] + - pipeline: *id004 + - null + theorist: !!python/object/apply:functools.partial + args: + - &id005 !!python/name:autora.controller.executor.theorist_wrapper '' + state: !!python/tuple + - *id005 + - !!python/tuple [] + - estimator: *id006 + - null +monitor: null +planner: !!python/name:autora.controller.planner.last_result_kind_planner '' diff --git a/tests/cli/controller/custom-function-with-cylc/environment.yml b/tests/cli/controller/custom-function-with-cylc/environment.yml new file mode 100644 index 000000000..6cf1b2f85 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc/environment.yml @@ -0,0 +1,7 @@ +channels: + - conda-forge + - 
autoresearch + - defaults +dependencies: + - cylc-flow + - autora=3.0.0a0 diff --git a/tests/cli/controller/custom-function-with-cylc/flow.cylc b/tests/cli/controller/custom-function-with-cylc/flow.cylc new file mode 100644 index 000000000..1f2890105 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc/flow.cylc @@ -0,0 +1,25 @@ +[runtime] + [[setup]] + script = """ + cp $CYLC_WORKFLOW_RUN_DIR/controller.yml $CYLC_WORKFLOW_SHARE_DIR; + cp -R $CYLC_WORKFLOW_RUN_DIR/history $CYLC_WORKFLOW_SHARE_DIR/history + conda create -p $CYLC_WORKFLOW_SHARE_DIR/env --clone autora-cylc + """ + [[experimentalist]] + script = $CYLC_WORKFLOW_SHARE_DIR/env/bin/python -m autora.controller "$CYLC_WORKFLOW_SHARE_DIR/controller.yml" "$CYLC_WORKFLOW_SHARE_DIR/history/." --step-name=experimentalist --verbose --debug + [[experiment_runner]] + script = $CYLC_WORKFLOW_SHARE_DIR/env/bin/python -m autora.controller "$CYLC_WORKFLOW_SHARE_DIR/controller.yml" "$CYLC_WORKFLOW_SHARE_DIR/history/." --step-name=experiment_runner --verbose --debug + [[theorist]] + script = $CYLC_WORKFLOW_SHARE_DIR/env/bin/python -m autora.controller "$CYLC_WORKFLOW_SHARE_DIR/controller.yml" "$CYLC_WORKFLOW_SHARE_DIR/history/."
--step-name=theorist --verbose --debug + +[scheduling] + cycling mode = integer + initial cycle point = 1 + [[graph]] + R1 = """ + setup => experimentalist + """ + P1 = """ + experimentalist => experiment_runner => theorist + theorist[-P1] => experimentalist + """ diff --git a/tests/cli/controller/custom-function-with-cylc/history/00000000-METADATA.yaml b/tests/cli/controller/custom-function-with-cylc/history/00000000-METADATA.yaml new file mode 100644 index 000000000..c7cfd0006 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc/history/00000000-METADATA.yaml @@ -0,0 +1,27 @@ +!!python/object:autora.controller.state.history.Result +data: !!python/object:autora.variable.VariableCollection + covariates: [] + dependent_variables: + - !!python/object:autora.variable.Variable + allowed_values: null + is_covariate: false + name: y + rescale: 1 + type: !!python/object/apply:autora.variable.ValueType + - real + units: '' + value_range: null + variable_label: '' + independent_variables: + - !!python/object:autora.variable.Variable + allowed_values: null + is_covariate: false + name: x + rescale: 1 + type: !!python/object/apply:autora.variable.ValueType + - real + units: '' + value_range: null + variable_label: '' +kind: !!python/object/apply:autora.controller.protocol.ResultKind +- METADATA diff --git a/tests/cli/controller/custom-function-with-cylc/history/00000001-PARAMS.yaml b/tests/cli/controller/custom-function-with-cylc/history/00000001-PARAMS.yaml new file mode 100644 index 000000000..3217172d8 --- /dev/null +++ b/tests/cli/controller/custom-function-with-cylc/history/00000001-PARAMS.yaml @@ -0,0 +1,4 @@ +!!python/object:autora.controller.state.history.Result +data: {} +kind: !!python/object/apply:autora.controller.protocol.ResultKind +- PARAMS diff --git a/tests/cli/controller/custom-function-with-cylc/lib/python/func.py b/tests/cli/controller/custom-function-with-cylc/lib/python/func.py new file mode 100644 index 000000000..428e34513 --- /dev/null 
+++ b/tests/cli/controller/custom-function-with-cylc/lib/python/func.py @@ -0,0 +1,16 @@ +import numpy as np +import pandas as pd + + +def plus_1(df: pd.DataFrame): + x = df[["x"]] + y = x + 1 + df["y"] = y + return df + + +def ten_random_samples(random_state: int = 180): + rng = np.random.default_rng(random_state) + values = rng.uniform(0, 10, 10) + df = pd.DataFrame({"x": values}) + return df diff --git a/tests/cli/controller/custom-function/.gitignore b/tests/cli/controller/custom-function/.gitignore new file mode 100644 index 000000000..8139a023f --- /dev/null +++ b/tests/cli/controller/custom-function/.gitignore @@ -0,0 +1,4 @@ +history/* +!.gitignore +!history/00000000-METADATA.yaml +!history/00000001-PARAMS.yaml diff --git a/tests/cli/controller/custom-function/README.md b/tests/cli/controller/custom-function/README.md new file mode 100644 index 000000000..2d5a7fef6 --- /dev/null +++ b/tests/cli/controller/custom-function/README.md @@ -0,0 +1,11 @@ +running on command line using default next step: +```zsh +python -m autora.controller controller.yml history/. --verbose --debug +``` + +running a particular next step: +```zsh +python -m autora.controller controller.yml history/. --step-name=experimentalist --verbose --debug +python -m autora.controller controller.yml history/. --step-name=experiment_runner --verbose --debug +python -m autora.controller controller.yml history/.
--step-name=theorist --verbose --debug +``` diff --git a/tests/cli/controller/custom-function/controller.yml b/tests/cli/controller/custom-function/controller.yml new file mode 100644 index 000000000..feff9f3d1 --- /dev/null +++ b/tests/cli/controller/custom-function/controller.yml @@ -0,0 +1,41 @@ +!!python/object:autora.controller.controller.Controller +_experiment_runner_callable: &id002 !!python/name:func.plus_1 '' +_experimentalist_pipeline: &id004 !!python/object:autora.experimentalist.pipeline.Pipeline + params: {} + steps: + - !!python/tuple + - ten_random_samples + - !!python/name:func.ten_random_samples '' +_theorist_estimator: &id006 !!python/object:sklearn.linear_model._base.LinearRegression + _sklearn_version: 1.2.2 + copy_X: true + fit_intercept: true + n_jobs: null + positive: false +executor_collection: + experiment_runner: !!python/object/apply:functools.partial + args: + - &id001 !!python/name:autora.controller.executor.experiment_runner_wrapper '' + state: !!python/tuple + - *id001 + - !!python/tuple [] + - callable: *id002 + - null + experimentalist: !!python/object/apply:functools.partial + args: + - &id003 !!python/name:autora.controller.executor.experimentalist_wrapper '' + state: !!python/tuple + - *id003 + - !!python/tuple [] + - pipeline: *id004 + - null + theorist: !!python/object/apply:functools.partial + args: + - &id005 !!python/name:autora.controller.executor.theorist_wrapper '' + state: !!python/tuple + - *id005 + - !!python/tuple [] + - estimator: *id006 + - null +monitor: null +planner: !!python/name:autora.controller.planner.last_result_kind_planner '' diff --git a/tests/cli/controller/custom-function/func.py b/tests/cli/controller/custom-function/func.py new file mode 100644 index 000000000..428e34513 --- /dev/null +++ b/tests/cli/controller/custom-function/func.py @@ -0,0 +1,16 @@ +import numpy as np +import pandas as pd + + +def plus_1(df: pd.DataFrame): + x = df[["x"]] + y = x + 1 + df["y"] = y + return df + + +def 
ten_random_samples(random_state: int = 180): + rng = np.random.default_rng(random_state) + values = rng.uniform(0, 10, 10) + df = pd.DataFrame({"x": values}) + return df diff --git a/tests/cli/controller/custom-function/history/00000000-METADATA.yaml b/tests/cli/controller/custom-function/history/00000000-METADATA.yaml new file mode 100644 index 000000000..c7cfd0006 --- /dev/null +++ b/tests/cli/controller/custom-function/history/00000000-METADATA.yaml @@ -0,0 +1,27 @@ +!!python/object:autora.controller.state.history.Result +data: !!python/object:autora.variable.VariableCollection + covariates: [] + dependent_variables: + - !!python/object:autora.variable.Variable + allowed_values: null + is_covariate: false + name: y + rescale: 1 + type: !!python/object/apply:autora.variable.ValueType + - real + units: '' + value_range: null + variable_label: '' + independent_variables: + - !!python/object:autora.variable.Variable + allowed_values: null + is_covariate: false + name: x + rescale: 1 + type: !!python/object/apply:autora.variable.ValueType + - real + units: '' + value_range: null + variable_label: '' +kind: !!python/object/apply:autora.controller.protocol.ResultKind +- METADATA diff --git a/tests/cli/controller/custom-function/history/00000001-PARAMS.yaml b/tests/cli/controller/custom-function/history/00000001-PARAMS.yaml new file mode 100644 index 000000000..3217172d8 --- /dev/null +++ b/tests/cli/controller/custom-function/history/00000001-PARAMS.yaml @@ -0,0 +1,4 @@ +!!python/object:autora.controller.state.history.Result +data: {} +kind: !!python/object/apply:autora.controller.protocol.ResultKind +- PARAMS diff --git a/tests/cli/theorist/basic-usage/.gitignore b/tests/cli/theorist/basic-usage/.gitignore new file mode 100644 index 000000000..f89b52daf --- /dev/null +++ b/tests/cli/theorist/basic-usage/.gitignore @@ -0,0 +1 @@ +model.pickle diff --git a/tests/cli/theorist/basic-usage/README.md b/tests/cli/theorist/basic-usage/README.md new file mode 100644 index 
000000000..b8b34b9d2 --- /dev/null +++ b/tests/cli/theorist/basic-usage/README.md @@ -0,0 +1,4 @@ +Run using: +```zsh +python -m autora.theorist variables.yml regressor.yml parameters.yml data.csv model.pickle --verbose --overwrite +``` diff --git a/tests/cli/theorist/basic-usage/data.csv b/tests/cli/theorist/basic-usage/data.csv new file mode 100644 index 000000000..52c6adc05 --- /dev/null +++ b/tests/cli/theorist/basic-usage/data.csv @@ -0,0 +1,5 @@ +x1,x2,c1,y +1,1,7,2 +1,2,7,3 +2,2,7,4 +0,0,7,0 diff --git a/tests/cli/theorist/basic-usage/parameters.yml b/tests/cli/theorist/basic-usage/parameters.yml new file mode 100644 index 000000000..8f78d83ae --- /dev/null +++ b/tests/cli/theorist/basic-usage/parameters.yml @@ -0,0 +1 @@ +fit_intercept: True diff --git a/tests/cli/theorist/basic-usage/regressor.yml b/tests/cli/theorist/basic-usage/regressor.yml new file mode 100644 index 000000000..5f1917253 --- /dev/null +++ b/tests/cli/theorist/basic-usage/regressor.yml @@ -0,0 +1,6 @@ +!!python/object:sklearn.linear_model._base.LinearRegression + _sklearn_version: 1.2.2 + copy_X: true + fit_intercept: true + n_jobs: null + positive: false diff --git a/tests/cli/theorist/basic-usage/variables.yml b/tests/cli/theorist/basic-usage/variables.yml new file mode 100644 index 000000000..c5e11e9e0 --- /dev/null +++ b/tests/cli/theorist/basic-usage/variables.yml @@ -0,0 +1,10 @@ +!!python/object:autora.variable.VariableCollection +covariates: [] +dependent_variables: + - !!python/object:autora.variable.Variable + name: y +independent_variables: + - !!python/object:autora.variable.Variable + name: x1 + - !!python/object:autora.variable.Variable + name: x2