Skip to content
This repository has been archived by the owner on Dec 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #43 from AutoResearch/feat/reorganize-serializers
Browse files Browse the repository at this point in the history
feature!: allow other serializers and set built-in "pickle" as the default
  • Loading branch information
hollandjg authored Nov 29, 2023
2 parents 184021d + 9318e55 commit bb1ed8a
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 58 deletions.
19 changes: 12 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@ license = {file = "LICENSE"}

dependencies = [
"autora-core>=4.0.0",
"scikit-learn",
"matplotlib",
"pandas",
"typer[all]",
"dill",
"pyyaml",
]

[project.optional-dependencies]
Expand All @@ -26,16 +21,26 @@ dev = [
"autora-workflow[test]",
]
docs = [
"autora-core[docs]>=4.0.0"
"autora-core[docs]>=4.0.0",
"scikit-learn",
"matplotlib",
"pandas",
]
test = [
"autora-core[test]>=4.0.0",
"hypothesis"
"autora-workflow[serializers]",
"hypothesis",
"scikit-learn",
"pandas",
]
cylc = [
"cylc-flow",
"cylc-uiserver"
]
serializers = [
"dill",
"pyyaml"
]

[project.urls]
homepage = "http://www.empiricalresearch.ai/"
Expand Down
82 changes: 82 additions & 0 deletions src/autora/serializer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import importlib
import logging
import pathlib
from collections import namedtuple
from enum import Enum
from typing import Callable, Dict, Literal, Optional, Tuple, Union

from autora.state import State

_logger = logging.getLogger(__name__)


class SerializersSupported(str, Enum):
"""Listing of allowed serializers."""

pickle = "pickle"
dill = "dill"
yaml = "yaml"


_SerializerDef = namedtuple(
"_SerializerDef", ["module", "load", "dump", "dumps", "file_mode"]
)
_serializer_dict: Dict[SerializersSupported, _SerializerDef] = {
SerializersSupported.pickle: _SerializerDef("pickle", "load", "dump", "dumps", "b"),
SerializersSupported.yaml: _SerializerDef(
"autora.serializer._yaml", "load", "dump", "dumps", ""
),
SerializersSupported.dill: _SerializerDef("dill", "load", "dump", "dumps", "b"),
}

default_serializer = SerializersSupported.pickle


def _get_serializer_mode(
serializer: SerializersSupported, interface: Literal["load", "dump", "dumps"]
) -> Tuple[Callable, str]:
serializer_def = _serializer_dict[serializer]
module = serializer_def.module
interface_function_name = getattr(serializer_def, interface)
_logger.debug(
f"_get_serializer_mode: loading {interface_function_name=} from" f" {module=}"
)
module = importlib.import_module(module)
function = getattr(module, interface_function_name)
file_mode = serializer_def.file_mode
return function, file_mode


def load_state(
path: Optional[pathlib.Path],
loader: SerializersSupported = default_serializer,
) -> Union[State, None]:
"""Load a State object from a path."""
if path is not None:
load, file_mode = _get_serializer_mode(loader, "load")
_logger.debug(f"load_state: loading from {path=}")
with open(path, f"r{file_mode}") as f:
state_ = load(f)
else:
_logger.debug(f"load_state: {path=} -> returning None")
state_ = None
return state_


def dump_state(
state_: State,
path: Optional[pathlib.Path],
dumper: SerializersSupported = default_serializer,
) -> None:
"""Write a State object to a path."""
if path is not None:
dump, file_mode = _get_serializer_mode(dumper, "dump")
_logger.debug(f"dump_state: dumping to {path=}")
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, f"w{file_mode}") as f:
dump(state_, f)
else:
dumps, _ = _get_serializer_mode(dumper, "dumps")
_logger.debug(f"dump_state: {path=} so writing to stdout")
print(dumps(state_))
return
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ def dump(data, file):
return


def dumps(data):
yaml.dumps(data, Dumper=yaml.Dumper)
return


def load(file):
result = yaml.load(file, Loader=yaml.Loader)
return result
60 changes: 29 additions & 31 deletions src/autora/workflow/__main__.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,59 @@
import importlib
import logging
import pathlib
from typing import Optional, Union
from typing import Optional

import dill
import typer
from typing_extensions import Annotated

from autora.state import State
from autora.serializer import (
SerializersSupported,
default_serializer,
dump_state,
load_state,
)

_logger = logging.getLogger(__name__)


def main(
fully_qualified_function_name: Annotated[
str, typer.Argument(help="Function to load")
str,
typer.Argument(
help="Fully qualified name of the function to load, like `module.function`"
),
],
in_path: Annotated[
Optional[pathlib.Path],
typer.Option(help="Path to a .dill file with the initial state"),
typer.Option(help="Path to a file with the initial state"),
] = None,
in_loader: Annotated[
SerializersSupported,
typer.Option(
help="(de)serializer to load the data",
),
] = default_serializer,
out_path: Annotated[
Optional[pathlib.Path],
typer.Option(help="Path to output the final state as a .dill file"),
typer.Option(help="Path to output the final state"),
] = None,
out_dumper: Annotated[
SerializersSupported,
typer.Option(
help="serializer to save the data",
),
] = default_serializer,
verbose: Annotated[bool, typer.Option(help="Turns on info logging level.")] = False,
debug: Annotated[bool, typer.Option(help="Turns on debug logging level.")] = False,
):
"""Run an arbitrary function on an optional input State object and save the output."""
_configure_logger(debug, verbose)
starting_state = _load_state(in_path)
starting_state = load_state(in_path, in_loader)
_logger.info(f"Starting State: {starting_state}")
function = _load_function(fully_qualified_function_name)
ending_state = function(starting_state)
_logger.info(f"Ending State: {ending_state}")
_dump_state(ending_state, out_path)
dump_state(ending_state, out_path, out_dumper)

return

Expand All @@ -47,18 +67,8 @@ def _configure_logger(debug, verbose):
_logger.info("using INFO logging level")


def _load_state(path: Optional[pathlib.Path]) -> Union[State, None]:
if path is not None:
_logger.debug(f"_load_state: loading from {path=}")
with open(path, "rb") as f:
state_ = dill.load(f)
else:
_logger.debug(f"_load_state: {path=} -> returning None")
state_ = None
return state_


def _load_function(fully_qualified_function_name: str):
"""Load a function by its fully qualified name, `module.function_name`"""
_logger.debug(f"_load_function: Loading function {fully_qualified_function_name}")
module_name, function_name = fully_qualified_function_name.rsplit(".", 1)
module = importlib.import_module(module_name)
Expand All @@ -67,17 +77,5 @@ def _load_function(fully_qualified_function_name: str):
return function


def _dump_state(state_: State, path: Optional[pathlib.Path]) -> None:
if path is not None:
_logger.debug(f"_dump_state: dumping to {path=}")
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "wb") as f:
dill.dump(state_, f)
else:
_logger.debug(f"_dump_state: {path=} so writing to stdout")
print(dill.dumps(state_))
return


if __name__ == "__main__":
typer.run(main)
20 changes: 0 additions & 20 deletions tests/test_load_dump_state.py

This file was deleted.

24 changes: 24 additions & 0 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pathlib
import tempfile
import uuid

from hypothesis import Verbosity, given, settings
from hypothesis import strategies as st

from autora.serializer import SerializersSupported, dump_state, load_state
from autora.state import StandardState


@given(
st.builds(StandardState, st.text(), st.text(), st.text(), st.lists(st.integers())),
st.sampled_from(SerializersSupported),
)
@settings(verbosity=Verbosity.verbose)
def test_load_inverts_dump(s, serializer):
"""Test that each serializer can be used to serialize and deserialize a state object."""
with tempfile.TemporaryDirectory() as dir:
path = pathlib.Path(dir, f"{str(uuid.uuid4())}")
print(path, s)

dump_state(s, path, dumper=serializer)
assert load_state(path, loader=serializer) == s
Loading

0 comments on commit bb1ed8a

Please sign in to comment.