diff --git a/src/autora/state.py b/src/autora/state.py index 75e11838..e6bf16f3 100644 --- a/src/autora/state.py +++ b/src/autora/state.py @@ -32,7 +32,6 @@ _logger = logging.getLogger(__name__) - T = TypeVar("T") C = TypeVar("C", covariant=True) @@ -328,7 +327,12 @@ def _get_value(f, other: Union[Delta, Mapping]): >>> from dataclasses import field, dataclass, fields >>> @dataclass ... class Example: - ... a: int = field() + ... a: int = field() # base case + ... b: List[int] = field(metadata={"aliases": {"ba": lambda b: [b]}}) # Single alias + ... c: List[int] = field(metadata={"aliases": { + ... "ca": lambda x: x, # pass the value unchanged + ... "cb": lambda x: [x] # wrap the value in a list + ... }}) # Multiple alias For a field with no aliases, we retrieve values with the base name: >>> f_a = fields(Example)[0] @@ -343,19 +347,140 @@ def _get_value(f, other: Union[Delta, Mapping]): >>> _get_value(f_a, Delta(b=2, a=1)) (1, 'a') + For fields with an alias, we retrieve values with the base name: + >>> f_b = fields(Example)[1] + >>> _get_value(f_b, Delta(b=[2])) + ([2], 'b') + + ... or for the alias name, transformed by the alias lambda function: + >>> _get_value(f_b, Delta(ba=21)) + ([21], 'ba') + + We preferentially get the base name, and then any aliases: + >>> _get_value(f_b, Delta(b=2, ba=21)) + (2, 'b') + + ... , regardless of their order in the `Delta` object: + >>> _get_value(f_b, Delta(ba=21, b=2)) + (2, 'b') + + Other names are ignored: + >>> _get_value(f_b, Delta(a=1)) + (None, None) + + and the order of other names is unimportant: + >>> _get_value(f_b, Delta(a=1, b=2)) + (2, 'b') + + For fields with multiple aliases, we retrieve values with the base name: + >>> f_c = fields(Example)[2] + >>> _get_value(f_c, Delta(c=[3])) + ([3], 'c') + + ... for any alias: + >>> _get_value(f_c, Delta(ca=31)) + (31, 'ca') + + ... transformed by the alias lambda function : + >>> _get_value(f_c, Delta(cb=32)) + ([32], 'cb') + + ... 
and ignoring any other names: + >>> print(_get_value(f_c, Delta(a=1))) + (None, None) + + ... preferentially in the order base name, 1st alias, 2nd alias, ... nth alias: + >>> _get_value(f_c, Delta(c=3, ca=31, cb=32)) + (3, 'c') + + >>> _get_value(f_c, Delta(ca=31, cb=32)) + (31, 'ca') + + >>> _get_value(f_c, Delta(cb=32)) + ([32], 'cb') + + >>> print(_get_value(f_c, Delta())) + (None, None) + + This works with dict objects: + >>> _get_value(f_a, dict(a=13)) + (13, 'a') + + ... with multiple keys: + >>> _get_value(f_b, dict(a=13, b=24, c=35)) + (24, 'b') + + ... and with aliases: + >>> _get_value(f_b, dict(ba=222)) + ([222], 'ba') + + This works with UserDicts: + >>> class MyDelta(UserDict): + ... pass + + >>> _get_value(f_a, MyDelta(a=14)) + (14, 'a') + + ... with multiple keys: + >>> _get_value(f_b, MyDelta(a=1, b=4, c=9)) + (4, 'b') + + ... and with aliases: + >>> _get_value(f_b, MyDelta(ba=234)) + ([234], 'ba') + """ key = f.name + aliases = f.metadata.get("aliases", {}) value, used_key = None, None if key in other.keys(): value = other[key] used_key = key + elif aliases: # ... is not an empty dict + for alias_key, wrapping_function in aliases.items(): + if alias_key in other: + value = wrapping_function(other[alias_key]) + used_key = alias_key + break # we only evaluate the first match return value, used_key +def _get_field_names_and_properties(s: State): + """ + Get a list of property names, field names and their aliases from a State object + + Args: + s: a State object + + Returns: a list of property names, followed by the field names and their aliases, on `s` + + Examples: + >>> from dataclasses import field + >>> @dataclass(frozen=True) + ... class SomeState(State): + ... l: List = field(default_factory=list) + ... m: List = field(default_factory=list) + ... @property + ... def both(self): + ... 
return self.l + self.m + >>> _get_field_names_and_properties(SomeState()) + ['both', 'l', 'm'] + """ + result = _get_field_names_and_aliases(s) + property_names = [ + attr + for attr in dir(s) + if isinstance(getattr(type(s), attr, None), property) + and attr not in dir(object) + and attr not in result + ] + return property_names + result + + def _get_field_names_and_aliases(s: State): """ Get a list of field names and their aliases from a State object @@ -374,8 +499,23 @@ def _get_field_names_and_aliases(s: State): >>> _get_field_names_and_aliases(SomeState()) ['l', 'm'] + >>> @dataclass(frozen=True) + ... class SomeStateWithAliases(State): + ... l: List = field(default_factory=list, metadata={"aliases": {"l1": None, "l2": None}}) + ... m: List = field(default_factory=list, metadata={"aliases": {"m1": None}}) + >>> _get_field_names_and_aliases(SomeStateWithAliases()) + ['l', 'l1', 'l2', 'm', 'm1'] + """ - result = [f.name for f in fields(s)] + result = [] + + for f in fields(s): + name = f.name + result.append(name) + + aliases = f.metadata.get("aliases", {}) + result.extend(aliases) + return result @@ -692,19 +832,21 @@ def inputs_from_state(f, input_mapping: Dict = {}): ) @wraps(f) - def _f(state_: S, /, **kwargs) -> S: + def _f(state_: State, /, **kwargs) -> State: # Get the parameters needed which are available from the state_. # All others must be provided as kwargs or default values on f. 
assert is_dataclass(state_) or isinstance(state_, UserDict) if is_dataclass(state_): - from_state = parameters_.intersection({i.name for i in fields(state_)}) + from_state = parameters_.intersection( + _get_field_names_and_properties(state_) + ) arguments_from_state = {k: getattr(state_, k) for k in from_state} from_state_input_mapping = { - reversed_mapping.get(field.name, field.name): getattr( - state_, field.name + reversed_mapping.get(field_name, field_name): getattr( + state_, field_name ) - for field in fields(state_) - if reversed_mapping.get(field.name, field.name) in parameters_ + for field_name in _get_field_names_and_properties(state_) + if reversed_mapping.get(field_name, field_name) in parameters_ } arguments_from_state.update(from_state_input_mapping) elif isinstance(state_, UserDict): @@ -1219,6 +1361,68 @@ class StandardState(State): >>> (s + dm1 + dm2).models [DummyClassifier(constant=1), DummyClassifier(constant=2), DummyClassifier(constant=3)] + The last model is available under the `model` property: + >>> (s + dm1 + dm2).model + DummyClassifier(constant=3) + + If there is no model, `None` is returned: + >>> print(s.model) + None + + `models` can also be updated using a Delta with a single `model`: + >>> dm3 = Delta(model=DummyClassifier(constant=4)) + >>> (s + dm1 + dm3).model + DummyClassifier(constant=4) + + + We can use properties X, y, iv_names and dv_names as 'getters' ... + >>> x_v = Variable('x') + >>> y_v = Variable('y') + >>> variables = VariableCollection(independent_variables=[x_v], dependent_variables=[y_v]) + >>> e_data = pd.DataFrame({'x': [1, 2, 3], 'y': [2, 4, 6]}) + >>> s = StandardState(variables=variables, experiment_data=e_data) + >>> @inputs_from_state + ... def show_X(X): + ... return X + >>> show_X(s) + x + 0 1 + 1 2 + 2 3 + + ... but nothing happens if we use them as `setters`: + >>> @on_state + ... def add_to_X(X): + ... res = X.copy() + ... res['x'] += 1 + ... 
return Delta(X=res) + >>> s = add_to_X(s) + >>> s.X + x + 0 1 + 1 2 + 2 3 + + However, if the property has a dedicated setter, we can still use it as a setter: + >>> s.model is None + True + + >>> from sklearn.linear_model import LinearRegression + >>> @on_state + ... def add_model(_model): + ... return Delta(model=_model) + >>> s = add_model(s, _model=LinearRegression()) + >>> s.models + [LinearRegression()] + + >>> s.model + LinearRegression() + + + + + + """ variables: Optional[VariableCollection] = field( @@ -1232,7 +1436,7 @@ class StandardState(State): ) models: List[BaseEstimator] = field( default_factory=list, - metadata={"delta": "extend"}, + metadata={"delta": "extend", "aliases": {"model": lambda model: [model]}}, ) @property @@ -1329,6 +1533,18 @@ def y(self) -> pd.DataFrame: return pd.DataFrame() return self.experiment_data[self.dv_names] + @property + def model(self): + if len(self.models) == 0: + return None + # The property to access the backing field + return self.models[-1] + + @model.setter + def model(self, value): + # Control the setting behavior + self.models.append(value) + X = TypeVar("X") Y = TypeVar("Y")