Skip to content

Commit

Permalink
feat: add support for deltas and converters in the add_field
Browse files Browse the repository at this point in the history
  • Loading branch information
hollandjg committed Nov 8, 2023
1 parent 0814a58 commit b017ca5
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Dict,
Generic,
List,
Literal,
Mapping,
Optional,
Protocol,
Expand Down Expand Up @@ -714,7 +715,8 @@ def add_field(
type_=MISSING,
default=MISSING,
default_factory=MISSING,
metadata: Optional[Mapping] = None,
delta: Literal["replace", "extend", "append"] = "replace",
converter: Optional[Callable] = None,
):
"""
Return a new copy of the dataclass with an additional field
Expand Down Expand Up @@ -764,9 +766,18 @@ def add_field(
>>> s.add_field("f", type_=int, value="a string, not an int")
StateDataClass(f='a string, not an int')
You can add metadata like a delta:
By default, fields are replaced if modified by Deltas:
>>> s.add_field("h", value=1) + Delta(h=2)
StateDataClass(h=2)
(This is the same as:
>>> s.add_field("h", value=1, delta="replace") + Delta(h=2)
StateDataClass(h=2)
You can specify a different delta type:
>>> u = s.add_field("g", type_=List[int], default_factory=list, delta="extend")
>>> u = u + Delta(g=[3]) + Delta(g=[4])
>>> u
StateDataClass(g=[3, 4])
>>> v = u.add_field("h", type_=int, default=None, delta="replace")
Expand All @@ -776,14 +787,31 @@ def add_field(
>>> v + Delta(h=1) + Delta(h=2)
StateDataClass(g=[3, 4], h=2)
>>>
>>> w = v.add_field("i", type_=List[int], default_factory=list, delta="append")
>>> w + Delta(i=3) + Delta(i=9) + Delta(i=27)
StateDataClass(g=[3, 4], h=None, i=[3, 9, 27])
You can specify a converter:
>>> x = s.add_field("df", type_=pd.DataFrame, default=None, converter=pd.DataFrame)
>>> x + Delta(df = {"a": [1, 2, 3], "b": list("abc")})
StateDataClass(df= a b
0 1 a
1 2 b
2 3 c)
"""
_field_kwargs = {
k: v
for k, v in dict(default=default, default_factory=default_factory).items()
if v is not MISSING
}

_field_kwargs["metadata"] = {
k: v
for k, v in dict(delta=delta, converter=converter).items()
if v is not None
}

_field = field(**_field_kwargs)
_dataclass_params = {
key: getattr(getattr(self, "__dataclass_params__"), key)
Expand All @@ -796,11 +824,14 @@ def add_field(
bases=(self.__class__,),
**_dataclass_params,
)

if value is not MISSING:
new_value = {name: value}
else:
new_value = {}

new = new_class(**new_value, **asdict(self))

return new


Expand Down

0 comments on commit b017ca5

Please sign in to comment.