Skip to content

Commit

Permalink
feat: add support for deltas and converters in the add_field
Browse files Browse the repository at this point in the history
  • Loading branch information
hollandjg committed Nov 8, 2023
1 parent 5e8aef0 commit 0814a58
Showing 1 changed file with 107 additions and 1 deletion.
108 changes: 107 additions & 1 deletion src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@
import logging
import warnings
from collections import UserDict
from dataclasses import dataclass, field, fields, is_dataclass, replace
from dataclasses import (
asdict,
dataclass,
field,
fields,
is_dataclass,
make_dataclass,
replace,
)
from enum import Enum
from functools import singledispatch, wraps
from typing import (
Expand All @@ -34,6 +42,8 @@
T = TypeVar("T")
C = TypeVar("C", covariant=True)

MISSING = ...


class DeltaAddable(Protocol[C]):
"""A class which a Delta or other Mapping can be added to, returning the same class"""
Expand Down Expand Up @@ -697,6 +707,102 @@ def update(self, **kwargs):
"""
return self + Delta(**kwargs)

def add_field(
self,
name: str,
value=MISSING,
type_=MISSING,
default=MISSING,
default_factory=MISSING,
metadata: Optional[Mapping] = None,
):
"""
Return a new copy of the dataclass with an additional field
Start with a StateDataClass; here is an empty one:
>>> s = StateDataClass()
>>> s
StateDataClass()
You can add a field with a new value:
>>> t = s.add_field("b", value=1)
>>> t
StateDataClass(b=1)
The original State is unchanged:
>>> s
StateDataClass()
You can add a field with a default value which is None:
>>> s.add_field("a", default=None)
StateDataClass(a=None)
... or a field with a default_factory (like a list):
>>> s.add_field("l", default_factory=list)
StateDataClass(l=[])
... but not both:
>>> s.add_field("l", default_factory=list, default=None)
Traceback (most recent call last):
...
ValueError: cannot specify both default and default_factory
A field with a default value can also have a different value in the instantiation:
>>> s.add_field("c", default=0, value=1)
StateDataClass(c=1)
... and here with a default_factory:
>>> s.add_field("d", default_factory=list, value=[1, 2, 3])
StateDataClass(d=[1, 2, 3])
You can specify the "type" of the field:
>>> s.add_field("e", type_=List[int], value=[1, 2, 3])
StateDataClass(e=[1, 2, 3])
... but this is only for documentation and will throw no error if the value is wrong
>>> s.add_field("f", type_=int, value="a string, not an int")
StateDataClass(f='a string, not an int')
You can add metadata like a delta:
>>> u = s.add_field("g", type_=List[int], default_factory=list, delta="extend")
>>> u = u + Delta(g=[3]) + Delta(g=[4])
StateDataClass(g=[3, 4])
>>> v = u.add_field("h", type_=int, default=None, delta="replace")
>>> v
StateDataClass(g=[3, 4], h=None)
>>> v + Delta(h=1) + Delta(h=2)
StateDataClass(g=[3, 4], h=2)
>>>
"""
_field_kwargs = {
k: v
for k, v in dict(default=default, default_factory=default_factory).items()
if v is not MISSING
}

_field = field(**_field_kwargs)
_dataclass_params = {
key: getattr(getattr(self, "__dataclass_params__"), key)
for key in ("frozen", "init", "repr", "eq", "order", "unsafe_hash")
}

new_class = make_dataclass(
cls_name=self.__class__.__name__,
fields=[(name, type_, _field)],
bases=(self.__class__,),
**_dataclass_params,
)
if value is not MISSING:
new_value = {name: value}
else:
new_value = {}
new = new_class(**new_value, **asdict(self))
return new


def _get_value(f, other: Union[Delta, Mapping]):
"""
Expand Down

0 comments on commit 0814a58

Please sign in to comment.