diff --git a/src/autora/state.py b/src/autora/state.py index f959f92f..e6bf16f3 100644 --- a/src/autora/state.py +++ b/src/autora/state.py @@ -327,7 +327,12 @@ def _get_value(f, other: Union[Delta, Mapping]): >>> from dataclasses import field, dataclass, fields >>> @dataclass ... class Example: - ... a: int = field() + ... a: int = field() # base case + ... b: List[int] = field(metadata={"aliases": {"ba": lambda b: [b]}}) # Single alias + ... c: List[int] = field(metadata={"aliases": { + ... "ca": lambda x: x, # pass the value unchanged + ... "cb": lambda x: [x] # wrap the value in a list + ... }}) # Multiple alias For a field with no aliases, we retrieve values with the base name: >>> f_a = fields(Example)[0] @@ -342,15 +347,104 @@ def _get_value(f, other: Union[Delta, Mapping]): >>> _get_value(f_a, Delta(b=2, a=1)) (1, 'a') + For fields with an alias, we retrieve values with the base name: + >>> f_b = fields(Example)[1] + >>> _get_value(f_b, Delta(b=[2])) + ([2], 'b') + + ... or for the alias name, transformed by the alias lambda function: + >>> _get_value(f_b, Delta(ba=21)) + ([21], 'ba') + + We preferentially get the base name, and then any aliases: + >>> _get_value(f_b, Delta(b=2, ba=21)) + (2, 'b') + + ... , regardless of their order in the `Delta` object: + >>> _get_value(f_b, Delta(ba=21, b=2)) + (2, 'b') + + Other names are ignored: + >>> _get_value(f_b, Delta(a=1)) + (None, None) + + and the order of other names is unimportant: + >>> _get_value(f_b, Delta(a=1, b=2)) + (2, 'b') + + For fields with multiple aliases, we retrieve values with the base name: + >>> f_c = fields(Example)[2] + >>> _get_value(f_c, Delta(c=[3])) + ([3], 'c') + + ... for any alias: + >>> _get_value(f_c, Delta(ca=31)) + (31, 'ca') + + ... transformed by the alias lambda function : + >>> _get_value(f_c, Delta(cb=32)) + ([32], 'cb') + + ... and ignoring any other names: + >>> print(_get_value(f_c, Delta(a=1))) + (None, None) + + ... 
preferentially in the order base name, 1st alias, 2nd alias, ... nth alias: + >>> _get_value(f_c, Delta(c=3, ca=31, cb=32)) + (3, 'c') + + >>> _get_value(f_c, Delta(ca=31, cb=32)) + (31, 'ca') + + >>> _get_value(f_c, Delta(cb=32)) + ([32], 'cb') + + >>> print(_get_value(f_c, Delta())) + (None, None) + + This works with dict objects: + >>> _get_value(f_a, dict(a=13)) + (13, 'a') + + ... with multiple keys: + >>> _get_value(f_b, dict(a=13, b=24, c=35)) + (24, 'b') + + ... and with aliases: + >>> _get_value(f_b, dict(ba=222)) + ([222], 'ba') + + This works with UserDicts: + >>> class MyDelta(UserDict): + ... pass + + >>> _get_value(f_a, MyDelta(a=14)) + (14, 'a') + + ... with multiple keys: + >>> _get_value(f_b, MyDelta(a=1, b=4, c=9)) + (4, 'b') + + ... and with aliases: + >>> _get_value(f_b, MyDelta(ba=234)) + ([234], 'ba') + """ key = f.name + aliases = f.metadata.get("aliases", {}) value, used_key = None, None if key in other.keys(): value = other[key] used_key = key + elif aliases: # ... is not an empty dict + for alias_key, wrapping_function in aliases.items(): + if alias_key in other: + value = wrapping_function(other[alias_key]) + used_key = alias_key + break # we only evaluate the first match return value, used_key @@ -405,8 +499,23 @@ def _get_field_names_and_aliases(s: State): >>> _get_field_names_and_aliases(SomeState()) ['l', 'm'] + >>> @dataclass(frozen=True) + ... class SomeStateWithAliases(State): + ... l: List = field(default_factory=list, metadata={"aliases": {"l1": None, "l2": None}}) + ... 
m: List = field(default_factory=list, metadata={"aliases": {"m1": None}}) + >>> _get_field_names_and_aliases(SomeStateWithAliases()) + ['l', 'l1', 'l2', 'm', 'm1'] + """ - result = [f.name for f in fields(s)] + result = [] + + for f in fields(s): + name = f.name + result.append(name) + + aliases = f.metadata.get("aliases", {}) + result.extend(aliases) + return result @@ -1252,6 +1361,20 @@ class StandardState(State): >>> (s + dm1 + dm2).models [DummyClassifier(constant=1), DummyClassifier(constant=2), DummyClassifier(constant=3)] + The last model is available under the `model` property: + >>> (s + dm1 + dm2).model + DummyClassifier(constant=3) + + If there is no model, `None` is returned: + >>> print(s.model) + None + + `models` can also be updated using a Delta with a single `model`: + >>> dm3 = Delta(model=DummyClassifier(constant=4)) + >>> (s + dm1 + dm3).model + DummyClassifier(constant=4) + + We can use properties X, y, iv_names and dv_names as 'getters' ... >>> x_v = Variable('x') >>> y_v = Variable('y') @@ -1280,6 +1403,24 @@ class StandardState(State): 1 2 2 3 + However, if the property has a dedicated setter, we can still use it as a getter: + >>> s.model is None + True + + >>> from sklearn.linear_model import LinearRegression + >>> @on_state + ... def add_model(_model): + ... 
return Delta(model=_model) + >>> s = add_model(s, _model=LinearRegression()) + >>> s.models + [LinearRegression()] + + >>> s.model + LinearRegression() + + + + """ @@ -1295,7 +1436,7 @@ class StandardState(State): ) models: List[BaseEstimator] = field( default_factory=list, - metadata={"delta": "extend"}, + metadata={"delta": "extend", "aliases": {"model": lambda model: [model]}}, ) @property @@ -1392,6 +1533,18 @@ def y(self) -> pd.DataFrame: return pd.DataFrame() return self.experiment_data[self.dv_names] + @property + def model(self): + if len(self.models) == 0: + return None + # The property to access the backing field + return self.models[-1] + + @model.setter + def model(self, value): + # Control the setting behavior + self.models.append(value) + X = TypeVar("X") Y = TypeVar("Y")