Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add alias for model in standardstate #88

Merged
merged 5 commits into from
Aug 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 226 additions & 10 deletions src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

_logger = logging.getLogger(__name__)


T = TypeVar("T")
C = TypeVar("C", covariant=True)

Expand Down Expand Up @@ -328,7 +327,12 @@ def _get_value(f, other: Union[Delta, Mapping]):
>>> from dataclasses import field, dataclass, fields
>>> @dataclass
... class Example:
... a: int = field()
... a: int = field() # base case
... b: List[int] = field(metadata={"aliases": {"ba": lambda b: [b]}}) # Single alias
... c: List[int] = field(metadata={"aliases": {
... "ca": lambda x: x, # pass the value unchanged
... "cb": lambda x: [x] # wrap the value in a list
... }}) # Multiple alias

For a field with no aliases, we retrieve values with the base name:
>>> f_a = fields(Example)[0]
Expand All @@ -343,19 +347,140 @@ def _get_value(f, other: Union[Delta, Mapping]):
>>> _get_value(f_a, Delta(b=2, a=1))
(1, 'a')

For fields with an alias, we retrieve values with the base name:
>>> f_b = fields(Example)[1]
>>> _get_value(f_b, Delta(b=[2]))
([2], 'b')

... or for the alias name, transformed by the alias lambda function:
>>> _get_value(f_b, Delta(ba=21))
([21], 'ba')

We preferentially get the base name, and then any aliases:
>>> _get_value(f_b, Delta(b=2, ba=21))
(2, 'b')

... , regardless of their order in the `Delta` object:
>>> _get_value(f_b, Delta(ba=21, b=2))
(2, 'b')

Other names are ignored:
>>> _get_value(f_b, Delta(a=1))
(None, None)

and the order of other names is unimportant:
>>> _get_value(f_b, Delta(a=1, b=2))
(2, 'b')

For fields with multiple aliases, we retrieve values with the base name:
>>> f_c = fields(Example)[2]
>>> _get_value(f_c, Delta(c=[3]))
([3], 'c')

... for any alias:
>>> _get_value(f_c, Delta(ca=31))
(31, 'ca')

... transformed by the alias lambda function:
>>> _get_value(f_c, Delta(cb=32))
([32], 'cb')

... and ignoring any other names:
>>> print(_get_value(f_c, Delta(a=1)))
(None, None)

... preferentially in the order base name, 1st alias, 2nd alias, ... nth alias:
>>> _get_value(f_c, Delta(c=3, ca=31, cb=32))
(3, 'c')

>>> _get_value(f_c, Delta(ca=31, cb=32))
(31, 'ca')

>>> _get_value(f_c, Delta(cb=32))
([32], 'cb')

>>> print(_get_value(f_c, Delta()))
(None, None)

This works with dict objects:
>>> _get_value(f_a, dict(a=13))
(13, 'a')

... with multiple keys:
>>> _get_value(f_b, dict(a=13, b=24, c=35))
(24, 'b')

... and with aliases:
>>> _get_value(f_b, dict(ba=222))
([222], 'ba')

This works with UserDicts:
>>> class MyDelta(UserDict):
... pass

>>> _get_value(f_a, MyDelta(a=14))
(14, 'a')

... with multiple keys:
>>> _get_value(f_b, MyDelta(a=1, b=4, c=9))
(4, 'b')

... and with aliases:
>>> _get_value(f_b, MyDelta(ba=234))
([234], 'ba')

"""

key = f.name
aliases = f.metadata.get("aliases", {})

value, used_key = None, None

if key in other.keys():
value = other[key]
used_key = key
elif aliases: # ... is not an empty dict
for alias_key, wrapping_function in aliases.items():
if alias_key in other:
value = wrapping_function(other[alias_key])
used_key = alias_key
break # we only evaluate the first match

return value, used_key


def _get_field_names_and_properties(s: State):
    """
    Get a list of property names, field names and field aliases from a State object

    Args:
        s: a State object

    Returns: a list holding the names of the properties found on the type of `s`
        (in the order produced by `dir`), followed by the field names and aliases
        returned by `_get_field_names_and_aliases`

    Examples:
        >>> from dataclasses import field
        >>> @dataclass(frozen=True)
        ... class SomeState(State):
        ...     l: List = field(default_factory=list)
        ...     m: List = field(default_factory=list)
        ...     @property
        ...     def both(self):
        ...         return self.l + self.m
        >>> _get_field_names_and_properties(SomeState())
        ['both', 'l', 'm']
    """
    result = _get_field_names_and_aliases(s)

    # These lookups do not change while scanning the attributes, so they are
    # computed once instead of once per attribute name.
    cls = type(s)
    object_attrs = frozenset(dir(object))
    known_names = frozenset(result)

    property_names = [
        attr
        for attr in dir(s)
        # Look the attribute up on the class: reading it from the instance would
        # evaluate the property instead of returning the property object.
        if isinstance(getattr(cls, attr, None), property)
        and attr not in object_attrs
        and attr not in known_names
    ]
    return property_names + result


def _get_field_names_and_aliases(s: State):
"""
Get a list of field names and their aliases from a State object
Expand All @@ -374,8 +499,23 @@ def _get_field_names_and_aliases(s: State):
>>> _get_field_names_and_aliases(SomeState())
['l', 'm']

>>> @dataclass(frozen=True)
... class SomeStateWithAliases(State):
... l: List = field(default_factory=list, metadata={"aliases": {"l1": None, "l2": None}})
... m: List = field(default_factory=list, metadata={"aliases": {"m1": None}})
>>> _get_field_names_and_aliases(SomeStateWithAliases())
['l', 'l1', 'l2', 'm', 'm1']

"""
result = [f.name for f in fields(s)]
result = []

for f in fields(s):
name = f.name
result.append(name)

aliases = f.metadata.get("aliases", {})
result.extend(aliases)

return result


Expand Down Expand Up @@ -692,19 +832,21 @@ def inputs_from_state(f, input_mapping: Dict = {}):
)

@wraps(f)
def _f(state_: S, /, **kwargs) -> S:
def _f(state_: State, /, **kwargs) -> State:
# Get the parameters needed which are available from the state_.
# All others must be provided as kwargs or default values on f.
assert is_dataclass(state_) or isinstance(state_, UserDict)
if is_dataclass(state_):
from_state = parameters_.intersection({i.name for i in fields(state_)})
from_state = parameters_.intersection(
_get_field_names_and_properties(state_)
)
arguments_from_state = {k: getattr(state_, k) for k in from_state}
from_state_input_mapping = {
reversed_mapping.get(field.name, field.name): getattr(
state_, field.name
reversed_mapping.get(field_name, field_name): getattr(
state_, field_name
)
for field in fields(state_)
if reversed_mapping.get(field.name, field.name) in parameters_
for field_name in _get_field_names_and_properties(state_)
if reversed_mapping.get(field_name, field_name) in parameters_
}
arguments_from_state.update(from_state_input_mapping)
elif isinstance(state_, UserDict):
Expand Down Expand Up @@ -1219,6 +1361,68 @@ class StandardState(State):
>>> (s + dm1 + dm2).models
[DummyClassifier(constant=1), DummyClassifier(constant=2), DummyClassifier(constant=3)]

The last model is available under the `model` property:
>>> (s + dm1 + dm2).model
DummyClassifier(constant=3)

If there is no model, `None` is returned:
>>> print(s.model)
None

`models` can also be updated using a Delta with a single `model`:
>>> dm3 = Delta(model=DummyClassifier(constant=4))
>>> (s + dm1 + dm3).model
DummyClassifier(constant=4)


We can use properties X, y, iv_names and dv_names as 'getters' ...
>>> x_v = Variable('x')
>>> y_v = Variable('y')
>>> variables = VariableCollection(independent_variables=[x_v], dependent_variables=[y_v])
>>> e_data = pd.DataFrame({'x': [1, 2, 3], 'y': [2, 4, 6]})
>>> s = StandardState(variables=variables, experiment_data=e_data)
>>> @inputs_from_state
... def show_X(X):
... return X
>>> show_X(s)
x
0 1
1 2
2 3

... but nothing happens if we use them as `setters`:
>>> @on_state
... def add_to_X(X):
... res = X.copy()
... res['x'] += 1
... return Delta(X=res)
>>> s = add_to_X(s)
>>> s.X
x
0 1
1 2
2 3

However, if the property has a dedicated setter, we can still use it as a getter:
>>> s.model is None
True

>>> from sklearn.linear_model import LinearRegression
>>> @on_state
... def add_model(_model):
... return Delta(model=_model)
>>> s = add_model(s, _model=LinearRegression())
>>> s.models
[LinearRegression()]

>>> s.model
LinearRegression()






"""

variables: Optional[VariableCollection] = field(
Expand All @@ -1232,7 +1436,7 @@ class StandardState(State):
)
models: List[BaseEstimator] = field(
default_factory=list,
metadata={"delta": "extend"},
metadata={"delta": "extend", "aliases": {"model": lambda model: [model]}},
)

@property
Expand Down Expand Up @@ -1329,6 +1533,18 @@ def y(self) -> pd.DataFrame:
return pd.DataFrame()
return self.experiment_data[self.dv_names]

@property
def model(self):
    """Return the last entry of ``self.models``, or ``None`` when it is empty."""
    # `models` is the backing list; this property exposes only its final element.
    if len(self.models) != 0:
        return self.models[-1]
    return None

@model.setter
def model(self, value):
    """Append ``value`` to ``self.models`` instead of replacing the list."""
    backing_list = self.models
    backing_list.append(value)


X = TypeVar("X")
Y = TypeVar("Y")
Expand Down
Loading