diff --git a/src/autora/state.py b/src/autora/state.py index 76baeb54..da276fa7 100644 --- a/src/autora/state.py +++ b/src/autora/state.py @@ -23,6 +23,7 @@ Dict, Generic, List, + Literal, Mapping, Optional, Protocol, @@ -714,7 +715,8 @@ def add_field( type_=MISSING, default=MISSING, default_factory=MISSING, - metadata: Optional[Mapping] = None, + delta: Literal["replace", "extend", "append"] = "replace", + converter: Optional[Callable] = None, ): """ Return a new copy of the dataclass with an additional field @@ -764,9 +766,18 @@ def add_field( >>> s.add_field("f", type_=int, value="a string, not an int") StateDataClass(f='a string, not an int') - You can add metadata like a delta: + By default, fields are replaced if modified by Deltas: + >>> s.add_field("h", value=1) + Delta(h=2) + StateDataClass(h=2) + + (This is the same as: + >>> s.add_field("h", value=1, delta="replace") + Delta(h=2) + StateDataClass(h=2) + + You can specify a different delta type: >>> u = s.add_field("g", type_=List[int], default_factory=list, delta="extend") >>> u = u + Delta(g=[3]) + Delta(g=[4]) + >>> u StateDataClass(g=[3, 4]) >>> v = u.add_field("h", type_=int, default=None, delta="replace") @@ -776,7 +787,18 @@ def add_field( >>> v + Delta(h=1) + Delta(h=2) StateDataClass(g=[3, 4], h=2) - >>> + >>> w = v.add_field("i", type_=List[int], default_factory=list, delta="append") + >>> w + Delta(i=3) + Delta(i=9) + Delta(i=27) + StateDataClass(g=[3, 4], h=None, i=[3, 9, 27]) + + You can specify a converter: + >>> x = s.add_field("df", type_=pd.DataFrame, default=None, converter=pd.DataFrame) + >>> x + Delta(df = {"a": [1, 2, 3], "b": list("abc")}) + StateDataClass(df= a b + 0 1 a + 1 2 b + 2 3 c) + """ _field_kwargs = { k: v @@ -784,6 +806,12 @@ def add_field( if v is not MISSING } + _field_kwargs["metadata"] = { + k: v + for k, v in dict(delta=delta, converter=converter).items() + if v is not None + } + _field = field(**_field_kwargs) _dataclass_params = { key: getattr(getattr(self, "__dataclass_params__"), key) @@ -796,11 +824,14 @@ def add_field( bases=(self.__class__,), **_dataclass_params, ) + if value is not MISSING: new_value = {name: value} else: new_value = {} + new = new_class(**new_value, **asdict(self)) + return new