diff --git a/src/autora/state.py b/src/autora/state.py index fdb8f4be..3c7f0f79 100644 --- a/src/autora/state.py +++ b/src/autora/state.py @@ -46,10 +46,163 @@ def __add__(self: C, other: Union[Delta, Mapping]) -> C: class StateDict(UserDict): + """ + Base object for UserDict which uses the Delta mechanism. + + Examples: + We first define an empty state + >>> s_0 = StateDict() + + Then we can add different fields with different Delta behaviours + >>> s_0.add_field("l", "extend", list("abc")) + >>> s_0.add_field("m", "replace", list("xyz")) + >>> s_0.l + ['a', 'b', 'c'] + >>> s_0.m + ['x', 'y', 'z'] + + We can add Deltas to it. Here, 'l' will be extended: + >>> s_1 = s_0 + Delta(l=list("def")) + >>> s_1.l + ['a', 'b', 'c', 'd', 'e', 'f'] + + ... whereas here, 'm' will be replaced: + >>> s_2 = s_1 + Delta(m=list("uvw")) + >>> s_2.m + ['u', 'v', 'w'] + + We can also chain Deltas: + >>> s_3 = s_2 + Delta(l=list("ghi")) + Delta(m=list("rst")) + >>> s_3.l + ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i'] + + >>> s_3.m + ['r', 's', 't'] + + ... or update multiple fields with one Delta: + >>> s_4 = s_3 + Delta(l=list("jkl"), m=list("opq")) + >>> s_4.l + ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l'] + + >>> s_4.m + ['o', 'p', 'q'] + + If we try to add a nonexistent field, nothing happens: + >>> s_5 = s_4 + Delta(n="not a field") + >>> 'n' in s_5 + False + + The update function replaces the entry: + >>> s_5.update(l=list("mno")) + >>> s_5.l + ['m', 'n', 'o'] + + We can also define fields which `append` the last result: + >>> s_5.add_field('n', 'append', list('abc')) + >>> s_6 = s_5 + Delta(n='d') + >>> s_6.n + ['a', 'b', 'c', 'd'] + + The metadata key "converter" is used to coerce types (inspired by + [PEP 712](https://peps.python.org/pep-0712/)): + >>> s_coerce = StateDict() + >>> s_coerce.add_field('o') + >>> s_coerce.add_field('p', converter=list) + >>> (s_coerce + Delta(o="not a list")).o + 'not a list' + + >>> (s_coerce + Delta(p='not a list')).p + ['n', 'o', 't', ' ', 'a', ' ', 'l', 'i', 's', 't'] + + If the input data are of the correct type, they are returned unaltered: + >>> (s_coerce + Delta(p=["a", "list"])).p + ['a', 'list'] + + With a converter, inputs are converted to the type that is output by the converter: + >>> s_coerce.add_field("q", converter=pd.DataFrame) + + If the type is already correct, the object is passed to the converter, + but should be returned unchanged: + >>> (s_coerce + Delta(q=pd.DataFrame([("a",1,"alpha"), ("b",2,"beta")],\ +columns=list("xyz")))).q + x y z + 0 a 1 alpha + 1 b 2 beta + + If the type is not correct, the object is converted if possible. For a DataFrame, + we can convert records: + >>> (s_coerce + Delta(q=[("a",1,"alpha"), ("b",2,"beta")])).q + 0 1 2 + 0 a 1 alpha + 1 b 2 beta + + ... or an array: + >>> (s_coerce + Delta(q=np.linspace([1, 2], [10, 15], 3))).q + 0 1 + 0 1.0 2.0 + 1 5.5 8.5 + 2 10.0 15.0 + + ... or a dictionary: + >>> (s_coerce + Delta(q={"a": [1,2,3], "b": [4,5,6]})).q + a b + 0 1 4 + 1 2 5 + 2 3 6 + + ... or a list: + >>> (s_coerce + Delta(q=[11, 12, 13])).q + 0 + 0 11 + 1 12 + 2 13 + + ... but not, for instance, a string: + >>> (s_coerce + Delta(q="not compatible with pd.DataFrame")).q + Traceback (most recent call last): + ... + ValueError: DataFrame constructor not properly called! + + We can define aliases for different potential field names: + >>> s_alias = StateDict() + >>> s_alias.add_field("things", "extend", aliases={"thing": lambda m: [m]}) + + + In the "normal" case, the Delta object is expected to include a list of data in the + format which is used to extend the object: + >>> s_alias = s_alias + Delta(things=["1", "2"]) + >>> s_alias.things + ['1', '2'] + + However, say the standard return from a step in AER is a single `thing`, rather than a + sequence: + >>> (s_alias + Delta(thing="3")).things + ['1', '2', '3'] + + If a cycle function relies on the existence of `s.thing` as a property of your state + `s`, rather than accessing `s.things[-1]`, you could additionally define a `getter`. + If you define such getters, the second argument must be a callable, in which case the input + to said callable will be interpreted as the state itself. + >>> s_alias.set_alias_getter("thing", lambda x: x["things"][-1]) + + At this point, you can access both `s.things` and `s.thing` as required by your code. + The State only shows `things` in the string representation. It exposes `things` as an + attribute: + >>> s_alias.things + ['1', '2'] + + ... but also exposes `thing`, which always returns the last value. + >>> s_alias.thing + '2' + + """ + def __init__(self, data: Optional[Dict] = None): super().__init__(data) - def add_field(self, name, delta="replace", default=None, aliases=None): + def add_field( + self, name, delta="replace", default=None, aliases=None, converter=None + ): self.data[name] = default if "_metadata" not in self.data.keys(): self.data["_metadata"] = {} @@ -57,6 +210,7 @@ def add_field(self, name, delta="replace", default=None, aliases=None): self.data["_metadata"][name]["default"] = default self.data["_metadata"][name]["delta"] = delta self.data["_metadata"][name]["aliases"] = aliases + self.data["_metadata"][name]["converter"] = converter def set_delta(self, name, delta): if "_metadata" not in self.data.keys(): @@ -67,19 +221,44 @@ def set_delta(self, name, delta): self.data["_metadata"][name]["aliases"] = None self.data["_metadata"][name]["delta"] = delta + def set_converter(self, name, converter): + if "_metadata" not in self.data.keys(): + self.data["_metadata"] = {} + if name not in self.data["_metadata"].keys(): + self.data["_metadata"][name] = {} + self.data["_metadata"][name]["default"] = None + self.data["_metadata"][name]["aliases"] = None + self.data["_metadata"][name]["converter"] = converter + + def set_alias(self, name, setter, getter): + if "_metadata" not in self.data.keys(): + self.data["_metadata"] = {} + if name not in self.data["_metadata"].keys(): + self.data["_metadata"][name] = {} + self.data["_metadata"][name]["default"] = None + self.data["_metadata"][name]["aliases"] = setter + self.data[f"_alias_getter_{name}"] = lambda: getter(self) + + def set_alias_getter(self, name, getter): + self.data[f"_alias_getter_{name}"] = lambda: getter(self) + def __setitem__(self, key, value): - if key != "_metadata" and ( - "_metadata" not in self.data.keys() - or key not in self.data["_metadata"].keys() - ): - warnings.warn( - f"Adding field {key} without metadata. Using defaults." - "Use add_field to safely initialize a field" + if ( + key != "_metadata" + and not key.startswith("_alias_getter") + and ( + "_metadata" not in self.data.keys() + or key not in self.data["_metadata"].keys() ) + ): self.add_field(key) super().__setitem__(key, value) def __getattr__(self, key): + if f"_alias_getter_{key}" in self.data and isinstance( + self.data[f"_alias_getter_{key}"], Callable + ): + return self.data[f"_alias_getter_{key}"]() if key in self.data: return self.data[key] raise AttributeError(f"'StateDict' object has no attribute '{key}'") @@ -88,10 +267,12 @@ def __add__(self, other: Union[Delta, Mapping]): updates = dict() other_fields_unused = list(other.keys()) for self_key in self.data: # Access the data dictionary within UserDict - other_value = other[self_key] if self_key in other else None + if self_key == "_metadata" or self_key.startswith("_alias_getter"): + continue + other_value, other_key = self._get_value(self_key, other) if other_value is None: continue - other_fields_unused.remove(self_key) + other_fields_unused.remove(other_key) self_field_key = self_key self_value = self.data[ @@ -100,9 +281,7 @@ def __add__(self, other: Union[Delta, Mapping]): delta_behavior = self.data["_metadata"][self_field_key]["delta"] if ( - constructor := self.data["_metadata"][self_field_key].get( - "converter", None - ) + constructor := self.data["_metadata"][self_field_key]["converter"] ) is not None: coerced_other_value = constructor(other_value) else: @@ -128,6 +307,128 @@ def __add__(self, other: Union[Delta, Mapping]): ) # Create a new instance of the same class with updated data return new + def _get_value(self, k, other: Union[Delta, Mapping]): + """ + Given a `StateDicts`'s `key` k, get a value from `other` and report its name. + + Returns: a tuple (the value, the key associated with that value) + + Examples: + >>> s = StateDict() + >>> s.add_field('a') + >>> s.add_field('b', aliases={"ba": lambda b: [b]}) + >>> s.add_field('c', aliases={"ca": lambda x: x, "cb": lambda x: [x]}) + + For a field with no aliases, we retrieve values with the base name: + >>> s._get_value('a', Delta(a=1)) + (1, 'a') + + ... and only the base name: + >>> s._get_value('a', Delta(b=2)) # no match for b + (None, None) + + Any other names._get_valueare unimportant: + >>> s._get_value('a', Delta(b=2, a=1)) + (1, 'a') + + For fields with an alias, we retrieve values with the base name: + >>> s._get_value('b', Delta(b=[2])) + ([2], 'b') + + ... or for the alias name, transformed by the alias lambda function: + >>> s._get_value('b', Delta(ba=21)) + ([21], 'ba') + + We preferentially get the base name, and then any aliases: + >>> s._get_value('b', Delta(b=2, ba=21)) + (2, 'b') + + ... regardless of their order in the `Delta` object: + >>> s._get_value('b', Delta(ba=21, b=2)) + (2, 'b') + + Other names are ignored: + >>> s._get_value('b', Delta(a=1)) + (None, None) + + and the order of other names is unimportant: + >>> s._get_value('b', Delta(a=1, b=2)) + (2, 'b') + + For fields with multiple aliases, we retrieve values with the base name: + >>> s._get_value('c', Delta(c=[3])) + ([3], 'c') + + ... for any alias: + >>> s._get_value('c', Delta(ca=31)) + (31, 'ca') + + ... transformed by the alias lambda function : + >>> s._get_value('c', Delta(cb=32)) + ([32], 'cb') + + ... and ignoring any other names: + >>> s._get_value('c', Delta(a=1)) + (None, None) + + ... preferentially in the order base name, 1st alias, 2nd alias, ... nth alias: + >>> s._get_value('c', Delta(c=3, ca=31, cb=32)) + (3, 'c') + + >>> s._get_value('c', Delta(ca=31, cb=32)) + (31, 'ca') + + >>> s._get_value('c', Delta(cb=32)) + ([32], 'cb') + + >>> s._get_value('c', Delta()) + (None, None) + + This works with dict objects: + >>> s._get_value('a', dict(a=13)) + (13, 'a') + + ... with multiple keys: + >>> s._get_value('b', dict(a=13, b=24, c=35)) + (24, 'b') + + ... and with aliases: + >>> s._get_value('b', dict(ba=222)) + ([222], 'ba') + + This works with UserDicts: + >>> class MyDelta(UserDict): + ... pass + + >>> s._get_value('a', MyDelta(a=14)) + (14, 'a') + + ... with multiple keys: + >>> s._get_value('b', MyDelta(a=1, b=4, c=9)) + (4, 'b') + + ... and with aliases: + >>> s._get_value('b', MyDelta(ba=234)) + ([234], 'ba') + + """ + + aliases = self.data["_metadata"][k].get("aliases", {}) + + value, used_key = None, None + + if k in other.keys(): + value = other[k] + used_key = k + elif aliases: # ... is not an empty dict + for alias_key, wrapping_function in aliases.items(): + if alias_key in other: + value = wrapping_function(other[alias_key]) + used_key = alias_key + break # we only evaluate the first match + + return value, used_key + State = StateDict @@ -888,9 +1189,11 @@ def _f(state_: S, /, **kwargs) -> S: from_state = parameters_.intersection({i.name for i in fields(state_)}) arguments_from_state = {k: getattr(state_, k) for k in from_state} from_state_input_mapping = { - reversed_mapping.get(f.name, f.name): getattr(state_, f.name) - for f in fields(state_) - if reversed_mapping.get(f.name, f.name) in parameters_ + reversed_mapping.get(field.name, field.name): getattr( + state_, field.name + ) + for field in fields(state_) + if reversed_mapping.get(field.name, field.name) in parameters_ } arguments_from_state.update(from_state_input_mapping) elif isinstance(state_, UserDict): @@ -1253,8 +1556,11 @@ def on_state( This also works on the StandardState or other States that are defined as UserDicts: >>> add_six(StandardState(conditions=[1, 2, 3,4])).conditions - [7, 8, 9, 10] - + 0 + 0 7 + 1 8 + 2 9 + 3 10 """ def decorator(f): @@ -1453,10 +1759,22 @@ def __init__(self, data: Optional[Dict] = None, **kwargs): if data is None: data = { "_metadata": { - "variables": {"default": None, "delta": "replace"}, - "conditions": {"default": None, "delta": "replace"}, - "experiment_data": {"default": None, "delta": "extend"}, - "models": {"default": None, "delta": "extend"}, + "variables": { + "default": None, + "delta": "replace", + "converter": VariableCollection, + }, + "conditions": { + "default": None, + "delta": "replace", + "converter": pd.DataFrame, + }, + "experiment_data": { + "default": None, + "delta": "extend", + "converter": pd.DataFrame, + }, + "models": {"default": None, "delta": "extend", "converter": list}, }, "variables": None, "conditions": None,