Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(controller): 8 - add command line interface for running on OSCAR #347

Closed
wants to merge 39 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
fd66870
feat: add typer
hollandjg Apr 11, 2023
67c095a
docs: add import for class plus debug flags
hollandjg Apr 13, 2023
e0c8f36
chore: remove broken plot_utils for darts
hollandjg Apr 13, 2023
b4509e9
feat: add example running theorist
hollandjg Apr 13, 2023
93b8d2e
feat: add example using parameters
hollandjg Apr 13, 2023
48699b3
docs: updating docstrings
hollandjg Apr 13, 2023
6d3b641
refactor: reorder file to match execution order
hollandjg Apr 13, 2023
a7be18b
refactor: reorder file to match execution order
hollandjg Apr 13, 2023
3b23df3
feat: add ability to pre-set parameters of regressor
hollandjg Apr 13, 2023
d2f8d51
feat: rename output to model.pickle
hollandjg Apr 13, 2023
ec5e9a0
feat: add target run mode to controller
hollandjg Apr 14, 2023
d0f96d6
revert: remove run_fn
hollandjg Apr 14, 2023
6926c82
chore: remove duplicated ExecutorName
hollandjg Apr 14, 2023
3505613
chore: tighten types for HistorySerializer
hollandjg Apr 14, 2023
9d741a3
chore: remove duplicated State type
hollandjg Apr 14, 2023
9a7e765
chore: remove excess protocol
hollandjg Apr 14, 2023
f9628f6
chore: update type on serializer
hollandjg Apr 14, 2023
b5cfd49
chore: add example which has a real function
hollandjg Apr 14, 2023
d6c41da
feat: make step_name optional (use the controller default if supplied)
hollandjg Apr 14, 2023
011b0d7
feat: make step_name optional (use the controller default if supplied)
hollandjg Apr 14, 2023
11cd58a
feat: make step_name optional (use the controller default if supplied)
hollandjg Apr 14, 2023
96259b6
refactor: reorder file
hollandjg Apr 14, 2023
3b95527
refactor: separate out planner update function
hollandjg Apr 14, 2023
69190cf
feat: add support for more types in executor
hollandjg Apr 14, 2023
33c67ec
test: update random_state for random sample pooler
hollandjg Apr 14, 2023
f065249
docs: add warning about current path in controller
hollandjg Apr 14, 2023
b0afbb7
docs: add readme for custom function execution
hollandjg Apr 14, 2023
26f26b1
docs: add example of running under cylc
hollandjg Apr 14, 2023
6649e1a
docs: add readme
hollandjg Apr 14, 2023
5df55db
deps: add pyyaml dependency
hollandjg Apr 14, 2023
a3fd8e5
deps: add func file
hollandjg Apr 14, 2023
899398b
deps: add func file
hollandjg Apr 14, 2023
ba75e45
docs: add working slurm example
hollandjg Apr 14, 2023
29da43c
try using variables for repeated values
hollandjg Apr 14, 2023
639757b
Revert "try using variables for repeated values"
hollandjg Apr 14, 2023
df78a9d
docs: add example README for theorist
hollandjg Apr 14, 2023
6a8a0e8
docs: add example README for theorist
hollandjg Apr 14, 2023
ceb92e0
Merge remote-tracking branch 'origin/feat/controller-cli' into feat/c…
hollandjg Apr 14, 2023
0893f09
Update autora/controller/__main__.py
hollandjg Apr 20, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .idea/runConfigurations/autora_theorist_basic_usage.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 60 additions & 0 deletions autora/controller/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import logging
import os
import pathlib
from typing import Optional

import typer
import yaml

from autora.controller import Controller
from autora.theorist.__main__ import _configure_logger

_logger = logging.getLogger(__name__)


def main(
manager: pathlib.Path=typer.Argument(..., help="Manager path"),
directory: pathlib.Path=typer.Argument(..., help="Directory path"),
step_name: Optional[str] = typer.Argument(None, help="Name of step"),
verbose: bool = typer.Option(False, help="Turns on info logging level."),
debug: bool = typer.Option(False, help="Turns on debug logging level."),
):
_logger.debug("initializing")
_configure_logger(debug, verbose)
controller_ = _load_manager(manager)

_logger.debug(f"loading manager state from {directory=}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does the = after the directory do?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that prints "directory=<directory name>" which otherwise you'd have to print like: f"directory={directory}"

controller_.load(directory)

if step_name is not None:
controller_ = _set_next_step_name(controller_, step_name)

_logger.info("running next step")
next(controller_)

_logger.debug(f"last result: {controller_.state.history[-1]}")

_logger.info("writing out results")
controller_.dump(directory)

return


def _load_manager(path: pathlib.Path) -> Controller:
    """
    Load a Controller object from a YAML manager file.

    Args:
        path: location of the YAML file describing the controller. A relative path is
            resolved against the current working directory, which is why the working
            directory is included in the debug log message.

    Returns:
        The Controller instance constructed by the YAML loader.

    Raises:
        TypeError: if the YAML document does not deserialize to a Controller.
    """
    _logger.debug(f"_load_manager: loading from {path=} (currently in {os.getcwd()})")
    # Be explicit about the encoding rather than relying on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        # SECURITY: yaml.Loader can construct arbitrary Python objects. That is what allows
        # the file to describe a Controller, but it also means a manager file can execute
        # code on load. Only load manager files from trusted sources.
        controller_ = yaml.load(f, yaml.Loader)
    # An explicit raise (rather than `assert`) keeps this check active under `python -O`.
    if not isinstance(controller_, Controller):
        raise TypeError(f"controller type {type(controller_)=} unsupported")
    return controller_


def _set_next_step_name(controller: Controller, step_name: str):
    """
    Force the controller's next step by replacing its planner.

    The planner is swapped for a function which ignores its argument and always
    returns `step_name`. The controller is modified in place and also returned.

    Args:
        controller: the controller whose `planner` attribute is overwritten.
        step_name: the value the replacement planner returns.

    Returns:
        The same controller object which was passed in.
    """

    def _fixed_planner(_):
        return step_name

    _logger.info(f"setting next {step_name=}")
    controller.planner = _fixed_planner
    return controller


# Only build and run the command line interface when this module is executed as a script
# (this file is a package `__main__.py`), not when it is imported.
if __name__ == "__main__":
    typer.run(main)
14 changes: 6 additions & 8 deletions autora/controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
from __future__ import annotations

import logging
from typing import Callable, Mapping, Optional, TypeVar

_logger = logging.getLogger(__name__)
from typing import Callable, Generic, Mapping, Optional

from autora.controller.protocol import State

State = TypeVar("State")
ExecutorName = TypeVar("ExecutorName", bound=str)
_logger = logging.getLogger(__name__)


class BaseController:
class BaseController(Generic[State]):
"""
Runs an experimentalist, theorist and experiment runner in a loop.

Expand All @@ -36,8 +34,8 @@ class BaseController:
def __init__(
self,
state: State,
planner: Callable[[State], ExecutorName],
executor_collection: Mapping[ExecutorName, Callable[[State], State]],
planner: Callable[[State], str],
executor_collection: Mapping[str, Callable[[State], State]],
monitor: Optional[Callable[[State], None]] = None,
):
"""
Expand Down
21 changes: 17 additions & 4 deletions autora/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@
from __future__ import annotations

import logging
import pathlib
from typing import Callable, Dict, Optional

from sklearn.base import BaseEstimator

from autora.controller.base import BaseController, ExecutorName
from autora.controller.base import BaseController
from autora.controller.executor import make_online_executor_collection
from autora.controller.planner import last_result_kind_planner
from autora.controller.serializer import HistorySerializer
from autora.controller.state import History
from autora.experimentalist.pipeline import Pipeline
from autora.variable import VariableCollection

_logger = logging.getLogger(__name__)


class Controller(BaseController):
class Controller(BaseController[History]):
"""
Runs an experimentalist, experiment runner, and theorist in order.

Expand All @@ -33,13 +35,13 @@ class Controller(BaseController):

def __init__(
self,
metadata: Optional[VariableCollection],
metadata: Optional[VariableCollection] = None,
theorist: Optional[BaseEstimator] = None,
experimentalist: Optional[Pipeline] = None,
experiment_runner: Optional[Callable] = None,
params: Optional[Dict] = None,
monitor: Optional[Callable[[History], None]] = None,
planner: Callable[[History], ExecutorName] = last_result_kind_planner,
planner: Callable[[History], str] = last_result_kind_planner,
):
"""
Args:
Expand Down Expand Up @@ -102,3 +104,14 @@ def __init__(
def seed(self, **kwargs):
    """
    Update the controller state with each keyword argument in turn.

    Each `name=value` pair is passed on its own to `self.state.update`, and the object
    returned by that call becomes the new `self.state` before the next pair is applied.
    """
    for name in kwargs:
        self.state = self.state.update(**{name: kwargs[name]})

def load(self, directory: pathlib.Path):
    """
    Replace the controller's state with the one read back by a HistorySerializer.

    Args:
        directory: passed to the HistorySerializer constructor; the result of the
            serializer's `load()` becomes `self.state`.
    """
    self.state = HistorySerializer(directory).load()

def dump(self, directory: pathlib.Path):
    """
    Hand the controller's current state to a HistorySerializer for writing out.

    Args:
        directory: passed to the HistorySerializer constructor; `self.state` is then
            given to the serializer's `dump()`.
    """
    HistorySerializer(directory).dump(self.state)
73 changes: 58 additions & 15 deletions autora/controller/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

import copy
import logging
import pprint
from functools import partial
from types import MappingProxyType
from typing import Callable, Dict, Iterable, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator

from autora.controller.protocol import SupportsControllerState
Expand All @@ -27,15 +29,27 @@ def experimentalist_wrapper(
params_ = resolve_state_params(params, state)
new_conditions = pipeline(**params_)

assert isinstance(new_conditions, Iterable)
# If the pipeline gives us an iterable, we need to make it into a concrete array.
# We can't move this logic to the Pipeline, because the pipeline doesn't know whether
# it's within another pipeline and whether it should convert the iterable to a
# concrete array.
new_conditions_values = list(new_conditions)
new_conditions_array = np.array(new_conditions_values)
if isinstance(new_conditions, pd.DataFrame):
new_conditions_array = new_conditions
elif isinstance(new_conditions, np.ndarray):
_logger.warning(
f"{new_conditions=} is an ndarray, so variable confusion is a possibility"
)
new_conditions_array = new_conditions
elif isinstance(new_conditions, np.recarray):
new_conditions_array = new_conditions
elif isinstance(new_conditions, Iterable):
# If the pipeline gives us an iterable, we need to make it into a concrete array.
# We can't move this logic to the Pipeline, because the pipeline doesn't know whether
# it's within another pipeline and whether it should convert the iterable to a
# concrete array.
new_conditions_values = list(new_conditions)
new_conditions_array = np.array(new_conditions_values)
else:
raise NotImplementedError(
f"Can't handle experimentalist output {new_conditions=}"
)

assert isinstance(new_conditions_array, np.ndarray) # Check the object is bounded
new_state = state.update(conditions=[new_conditions_array])
return new_state

Expand All @@ -46,8 +60,15 @@ def experiment_runner_wrapper(
"""Interface for running the experiment runner callable."""
params_ = resolve_state_params(params, state)
x = state.conditions[-1]
y = callable(x, **params_)
new_observations = np.column_stack([x, y])
output = callable(x, **params_)

if isinstance(x, pd.DataFrame):
new_observations = output
elif isinstance(x, np.ndarray):
new_observations = np.column_stack([x, output])
else:
raise NotImplementedError(f"type {x=} not supported")

new_state = state.update(observations=[new_observations])
return new_state

Expand All @@ -59,13 +80,35 @@ def theorist_wrapper(
params_ = resolve_state_params(params, state)
metadata = state.metadata
observations = state.observations
all_observations = np.row_stack(observations)
n_xs = len(metadata.independent_variables)
x, y = all_observations[:, :n_xs], all_observations[:, n_xs:]
if y.shape[1] == 1:
y = y.ravel()

if isinstance(observations[-1], pd.DataFrame):
all_observations = pd.concat(observations)
iv_names = [iv.name for iv in metadata.independent_variables]
dv_names = [dv.name for dv in metadata.dependent_variables]
x, y = all_observations[iv_names], all_observations[dv_names]
elif isinstance(observations[-1], np.ndarray):
all_observations = np.row_stack(observations)
n_xs = len(metadata.independent_variables)
x, y = all_observations[:, :n_xs], all_observations[:, n_xs:]
if y.shape[1] == 1:
y = y.ravel()
else:
raise NotImplementedError(f"type {observations[-1]=} not supported")

new_theorist = copy.deepcopy(estimator)
new_theorist.fit(x, y, **params_)

try:
_logger.debug(
f"fitted {new_theorist=}\nnew_theorist.__dict__:"
f"\n{pprint.pformat(new_theorist.__dict__)}"
)
except AttributeError:
_logger.debug(
f"fitted {new_theorist=} "
f"new_theorist has no __dict__ attribute, so no results are shown"
)

new_state = state.update(theories=[new_theorist])
return new_state

Expand Down
4 changes: 1 addition & 3 deletions autora/controller/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ def __call__(self, __state: State, __params: Dict) -> State:
...


ExecutorName = TypeVar("ExecutorName", bound=str)

ExecutorCollection = Mapping[ExecutorName, Executor]
ExecutorCollection = Mapping[str, Executor]


@runtime_checkable
Expand Down
27 changes: 17 additions & 10 deletions autora/controller/serializer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import pickle
import tempfile
from abc import abstractmethod
from pathlib import Path
from typing import Mapping, NamedTuple, Optional, Type, Union
from typing import Generic, Mapping, NamedTuple, Union

import numpy as np

from autora.controller.protocol import (
ResultKind,
SupportsControllerStateHistory,
SupportsLoadDump,
)
from autora.controller.protocol import ResultKind, State, SupportsLoadDump
from autora.controller.serializer import yaml_ as YAMLSerializer
from autora.controller.state import History

Expand All @@ -25,7 +22,17 @@ class _LoadSpec(NamedTuple):
mode: str


class HistorySerializer:
class StateSerializer(Generic[State]):
    """
    Interface for objects which can load and dump a controller state of type `State`.

    Both methods are marked abstract; concrete behaviour is supplied by subclasses.
    """

    @abstractmethod
    def load(self) -> State:
        """Return a `State` object; how it is obtained is defined by the subclass."""

    @abstractmethod
    def dump(self, ___state: State):
        """Accept a `State` object to serialize; the destination is defined by the subclass."""


class HistorySerializer(StateSerializer[History]):
"""Serializes and deserializes History objects."""

def __init__(
Expand All @@ -51,7 +58,7 @@ def __init__(
".pickle": _LoadSpec(pickle, "rb"),
}

def dump(self, data_collection: SupportsControllerStateHistory):
def dump(self, data_collection: History):
"""

Args:
Expand Down Expand Up @@ -129,7 +136,7 @@ def dump(self, data_collection: SupportsControllerStateHistory):
with open(Path(path, filename), mode) as f:
serializer.dump(container, f)

def load(self, cls: Type[SupportsControllerStateHistory] = History):
def load(self) -> History:
"""

Examples:
Expand Down Expand Up @@ -185,7 +192,7 @@ def load(self, cls: Type[SupportsControllerStateHistory] = History):
loaded_object = serializer.load(f)
data.append(loaded_object)

data_collection = cls(history=data)
data_collection = History(history=data)

return data_collection

Expand Down
Loading