diff --git a/src/ducktools/classbuilder/__init__.py b/src/ducktools/classbuilder/__init__.py index 045500c..1387031 100644 --- a/src/ducktools/classbuilder/__init__.py +++ b/src/ducktools/classbuilder/__init__.py @@ -19,20 +19,34 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. + +# In this module there are some internal bits of circular logic. +# +# 'Field' needs to exist in order to be used in gatherers, but is itself a +# partially constructed class. These constructed attributes are placed on +# 'Field' post construction. +# +# The 'SlotMakerMeta' metaclass generates 'Field' instances to go in __slots__ +# but is also the metaclass used to construct 'Field'. +# Field itself sidesteps this by defining __slots__ to avoid that branch. + import sys -from .annotations import get_annotations, is_classvar +from .annotations import get_ns_annotations, is_classvar __version__ = "v0.6.0" # Change this name if you make heavy modifications INTERNALS_DICT = "__classbuilder_internals__" +META_GATHERER_NAME = "_meta_gatherer" # If testing, make Field classes frozen to make sure attributes are not # overwritten. When running this is a performance penalty so it is not required. _UNDER_TESTING = "pytest" in sys.modules +_MPT = type(type.__dict__) + def get_fields(cls, *, local=False): """ @@ -77,6 +91,17 @@ def __repr__(self): NOTHING = _NothingType() +# KW_ONLY sentinel 'type' to use to indicate all subsequent attributes are +# keyword only +# noinspection PyPep8Naming +class _KW_ONLY_TYPE: + def __repr__(self): + return "" + + +KW_ONLY = _KW_ONLY_TYPE() + + class MethodMaker: """ The descriptor class to place where methods should be generated. @@ -124,29 +149,51 @@ def cls_init_maker(cls): flags = get_flags(cls) arglist = [] + kw_only_arglist = [] assignments = [] globs = {} - if flags.get("kw_only", False): - arglist.append("*") + kw_only_flag = flags.get("kw_only", False) for k, v in fields.items(): - if v.default is not null: - globs[f"_{k}_default"] = v.default - arg = f"{k}=_{k}_default" - assignment = f"self.{k} = {k}" - elif v.default_factory is not null: - globs[f"_{k}_factory"] = v.default_factory - arg = f"{k}=None" - assignment = f"self.{k} = _{k}_factory() if {k} is None else {k}" - else: - arg = f"{k}" - assignment = f"self.{k} = {k}" + if v.init: + if v.default is not null: + globs[f"_{k}_default"] = v.default + arg = f"{k}=_{k}_default" + assignment = f"self.{k} = {k}" + elif v.default_factory is not null: + globs[f"_{k}_factory"] = v.default_factory + arg = f"{k}=None" + assignment = f"self.{k} = _{k}_factory() if {k} is None else {k}" + else: + arg = f"{k}" + assignment = f"self.{k} = {k}" - arglist.append(arg) - assignments.append(assignment) + if kw_only_flag or v.kw_only: + kw_only_arglist.append(arg) + else: + arglist.append(arg) + + assignments.append(assignment) + else: + if v.default is not null: + globs[f"_{k}_default"] = v.default + assignment = f"self.{k} = _{k}_default" + assignments.append(assignment) + elif v.default_factory is not null: + globs[f"_{k}_factory"] = v.default_factory + assignment = f"self.{k} = _{k}_factory()" + assignments.append(assignment) + + pos_args = ", ".join(arglist) + kw_args = ", ".join(kw_only_arglist) + if pos_args and kw_args: + args = f"{pos_args}, *, {kw_args}" + elif kw_args: + args = f"*, {kw_args}" + else: + args = pos_args - args = ", ".join(arglist) assigns = "\n ".join(assignments) if assignments else "pass\n" code = ( f"def __init__(self, {args}):\n" @@ -166,23 +213,75 @@ def cls_init_maker(cls): init_generator = get_init_generator() -def repr_generator(cls): - fields = get_fields(cls) - content = ", ".join( - f"{name}={{self.{name}!r}}" - for name, attrib in fields.items() - ) - code = ( - f"def __repr__(self):\n" - f" return f'{{type(self).__qualname__}}({content})'\n" - ) - globs = {} - return code, globs +def get_repr_generator(recursion_safe=False, eval_safe=False): + """ + + :param recursion_safe: use reprlib.recursive_repr + :param eval_safe: if the repr is known not to eval correctly, + generate a repr which will intentionally + not evaluate. + :return: + """ + def cls_repr_generator(cls): + fields = get_fields(cls) + + globs = {} + will_eval = True + valid_names = [] + + for name, fld in fields.items(): + if fld.repr: + valid_names.append(name) + + if will_eval and (fld.init ^ fld.repr): + will_eval = False + + content = ", ".join( + f"{name}={{self.{name}!r}}" + for name in valid_names + ) + + if recursion_safe: + import reprlib + globs["_recursive_repr"] = reprlib.recursive_repr() + recursion_func = "@_recursive_repr\n" + else: + recursion_func = "" + + if eval_safe and will_eval is False: + if content: + code = ( + f"{recursion_func}" + f"def __repr__(self):\n" + f" return f''\n" + ) + else: + code = ( + f"{recursion_func}" + f"def __repr__(self):\n" + f" return f''\n" + ) + else: + code = ( + f"{recursion_func}" + f"def __repr__(self):\n" + f" return f'{{type(self).__qualname__}}({content})'\n" + ) + + return code, globs + return cls_repr_generator + + +repr_generator = get_repr_generator() def eq_generator(cls): class_comparison = "self.__class__ is other.__class__" - field_names = get_fields(cls) + field_names = [ + name + for name, attrib in get_fields(cls).items() + if attrib.compare + ] if field_names: selfvals = ",".join(f"self.{name}" for name in field_names) @@ -316,26 +415,58 @@ def builder(cls=None, /, *, gatherer, methods, flags=None): return cls +# Slot gathering tools +# Subclass of dict to be identifiable by isinstance checks +# For anything more complicated this could be made into a Mapping +class SlotFields(dict): + """ + A plain dict subclass. + + For declaring slotfields there are no additional features required + other than recognising that this is intended to be used as a class + generating dict and isn't a regular dictionary that ended up in + `__slots__`. + + This should be replaced on `__slots__` after fields have been gathered. + """ + def __repr__(self): + return f"SlotFields({super().__repr__()})" + + # Tool to convert annotations to slots as a metaclass class SlotMakerMeta(type): """ - Metaclass to convert annotations to slots. + Metaclass to convert annotations or Field(...) attributes to slots. Will not convert `ClassVar` hinted values. """ def __new__(cls, name, bases, ns, slots=True, **kwargs): + # This should only run if slots=True is declared + # and __slots__ have not already been defined + if slots and "__slots__" not in ns: + # Check if a different gatherer has been set in any base classes + # Default to unified gatherer + gatherer = ns.get(META_GATHERER_NAME, None) + if not gatherer: + for base in bases: + if g := getattr(base, META_GATHERER_NAME, None): + gatherer = g + break + + if not gatherer: + gatherer = unified_gatherer + + # Obtain slots from annotations or attributes + cls_fields, cls_modifications = gatherer(ns) + for k, v in cls_modifications.items(): + if v is NOTHING: + ns.pop(k) + else: + ns[k] = v + + # Place slots *after* everything else to be safe + ns["__slots__"] = SlotFields(cls_fields) - # Obtain slots from annotations - if "__slots__" not in ns and slots: - cls_annotations = get_annotations(ns) - cls_slots = SlotFields({ - k: ns.pop(k, NOTHING) - for k, v in cls_annotations.items() - if not is_classvar(v) - }) - ns["__slots__"] = cls_slots - - # Make new slotted class new_cls = super().__new__(cls, name, bases, ns, **kwargs) return new_cls @@ -354,16 +485,30 @@ class Field(metaclass=SlotMakerMeta): Note: When run under `pytest`, Field instances are Frozen. When subclassing, passing `frozen=True` will make your subclass frozen. + + :param default: Standard default value to be used for attributes with this field. + :param default_factory: A zero-argument function to be called to generate a + default value, useful for mutable obects like lists. + :param type: The type of the attribute to be assigned by this field. + :param doc: The documentation for the attribute that appears when calling + help(...) on the class. (Only in slotted classes). + :param init: Include in the class __init__ parameters. + :param repr: Include in the class __repr__. + :param compare: Include in the class __eq__. + :param kw_only: Make this a keyword only parameter in __init__. """ - __slots__ = { - "default": "Standard default value to be used for attributes with" - "this field.", - "default_factory": "A 0 argument function to be called to generate " - "a default value, useful for mutable objects like " - "lists.", - "type": "The type of the attribute to be assigned by this field.", - "doc": "The documentation that appears when calling help(...) on the class." - } + # If this base class did not define __slots__ the metaclass would break it. + # This will be replaced by the builder. + __slots__ = SlotFields( + default=NOTHING, + default_factory=NOTHING, + type=NOTHING, + doc=None, + init=True, + repr=True, + compare=True, + kw_only=False, + ) # noinspection PyShadowingBuiltins def __init__( @@ -373,12 +518,26 @@ def __init__( default_factory=NOTHING, type=NOTHING, doc=None, + init=True, + repr=True, + compare=True, + kw_only=False, ): + # The init function for 'Field' cannot be generated + # as 'Field' needs to exist first. + # repr and comparison functions are generated as these + # do not need to exist to create initial Fields. + self.default = default self.default_factory = default_factory self.type = type self.doc = doc + self.init = init + self.repr = repr + self.compare = compare + self.kw_only = kw_only + self.validate_field() def __init_subclass__(cls, frozen=False): @@ -388,15 +547,21 @@ def __init_subclass__(cls, frozen=False): builder( cls, - gatherer=slot_gatherer, + gatherer=unified_gatherer, methods=field_methods, flags={"slotted": True, "kw_only": True} ) def validate_field(self): + cls_name = self.__class__.__name__ if self.default is not NOTHING and self.default_factory is not NOTHING: raise AttributeError( - "Cannot define both a default value and a default factory." + f"{cls_name} cannot define both a default value and a default factory." + ) + + if self.kw_only and not self.init: + raise AttributeError( + f"{cls_name} cannot be keyword only if it is not in init." ) @classmethod @@ -416,66 +581,6 @@ def from_field(cls, fld, /, **kwargs): return cls(**argument_dict) -class GatheredFields: - __slots__ = ("fields", "modifications") - - def __init__(self, fields, modifications): - self.fields = fields - self.modifications = modifications - - def __call__(self, cls): - return self.fields, self.modifications - - -# Use the builder to generate __repr__ and __eq__ methods -# for both Field and GatheredFields -_field_internal = { - "default": Field(default=NOTHING), - "default_factory": Field(default=NOTHING), - "type": Field(default=NOTHING), - "doc": Field(default=None), -} - -_gathered_field_internal = { - "fields": Field(default=NOTHING), - "modifications": Field(default=NOTHING), -} - -_field_methods = {repr_maker, eq_maker} -if _UNDER_TESTING: - _field_methods.update({frozen_setattr_maker, frozen_delattr_maker}) - -builder( - Field, - gatherer=GatheredFields(_field_internal, {}), - methods=_field_methods, - flags={"slotted": True, "kw_only": True}, -) - -builder( - GatheredFields, - gatherer=GatheredFields(_gathered_field_internal, {}), - methods={repr_maker, eq_maker}, - flags={"slotted": True, "kw_only": False}, -) - - -# Slot gathering tools -# Subclass of dict to be identifiable by isinstance checks -# For anything more complicated this could be made into a Mapping -class SlotFields(dict): - """ - A plain dict subclass. - - For declaring slotfields there are no additional features required - other than recognising that this is intended to be used as a class - generating dict and isn't a regular dictionary that ended up in - `__slots__`. - - This should be replaced on `__slots__` after fields have been gathered. - """ - - def make_slot_gatherer(field_type=Field): """ Create a new annotation gatherer that will work with `Field` instances @@ -485,16 +590,17 @@ def make_slot_gatherer(field_type=Field): :return: A slot gatherer that will check for and generate Fields of the type field_type. """ - def field_slot_gatherer(cls): + def field_slot_gatherer(cls_or_ns): """ Gather field information for class generation based on __slots__ - :param cls: Class to gather field information from + :param cls_or_ns: Class to gather field information from (or class namespace) :return: dict of field_name: Field(...) """ + cls_dict = cls_or_ns if isinstance(cls_or_ns, (_MPT, dict)) else cls_or_ns.__dict__ try: - cls_slots = cls.__dict__["__slots__"] + cls_slots = cls_dict["__slots__"] except KeyError: raise AttributeError( "__slots__ must be defined as an instance of SlotFields " @@ -509,9 +615,7 @@ def field_slot_gatherer(cls): # Don't want to mutate original annotations so make a copy if it exists # Looking at the dict is a Python3.9 or earlier requirement - cls_annotations = { - **cls.__dict__.get("__annotations__", {}) - } + cls_annotations = get_ns_annotations(cls_dict) cls_fields = {} slot_replacement = {} @@ -548,51 +652,6 @@ def field_slot_gatherer(cls): return field_slot_gatherer -slot_gatherer = make_slot_gatherer() - - -def check_argument_order(cls): - """ - Raise a SyntaxError if the argument order will be invalid for a generated - `__init__` function. - - :param cls: class being built - """ - fields = get_fields(cls) - used_default = False - for k, v in fields.items(): - if v.default is NOTHING and v.default_factory is NOTHING: - if used_default: - raise SyntaxError( - f"non-default argument {k!r} follows default argument" - ) - else: - used_default = True - - -# Class Decorators -def slotclass(cls=None, /, *, methods=default_methods, syntax_check=True): - """ - Example of class builder in action using __slots__ to find fields. - - :param cls: Class to be analysed and modified - :param methods: MethodMakers to be added to the class - :param syntax_check: check there are no arguments without defaults - after arguments with defaults. - :return: Modified class - """ - if not cls: - return lambda cls_: slotclass(cls_, methods=methods, syntax_check=syntax_check) - - cls = builder(cls, gatherer=slot_gatherer, methods=methods, flags={"slotted": True}) - - if syntax_check: - check_argument_order(cls) - - return cls - - -# Annotation based class tools def make_annotation_gatherer( field_type=Field, leave_default_values=True, @@ -606,35 +665,49 @@ def make_annotation_gatherer( default values in place as class variables. :return: An annotation gatherer with these settings. """ - def field_annotation_gatherer(cls): + def field_annotation_gatherer(cls_or_ns): + cls_dict = cls_or_ns if isinstance(cls_or_ns, (_MPT, dict)) else cls_or_ns.__dict__ cls_fields: dict[str, field_type] = {} modifications = {} - cls_annotations = get_annotations(cls.__dict__) + cls_annotations = get_ns_annotations(cls_dict) + cls_slots = cls_dict.get("__slots__", {}) + + kw_flag = False for k, v in cls_annotations.items(): # Ignore ClassVar if is_classvar(v): continue - attrib = getattr(cls, k, NOTHING) + if v is KW_ONLY: + if kw_flag: + raise SyntaxError("KW_ONLY sentinel may only appear once.") + kw_flag = True + continue + + attrib = cls_dict.get(k, NOTHING) if attrib is not NOTHING: if isinstance(attrib, field_type): - attrib = field_type.from_field(attrib, type=v) + kw_only = attrib.kw_only or kw_flag + + attrib = field_type.from_field(attrib, type=v, kw_only=kw_only) if attrib.default is not NOTHING and leave_default_values: modifications[k] = attrib.default else: # NOTHING sentinel indicates a value should be removed modifications[k] = NOTHING - else: - attrib = field_type(default=attrib, type=v) + elif k not in cls_slots: + attrib = field_type(default=attrib, type=v, kw_only=kw_flag) if not leave_default_values: modifications[k] = NOTHING + else: + attrib = field_type(type=v, kw_only=kw_flag) else: - attrib = field_type(type=v) + attrib = field_type(type=v, kw_only=kw_flag) cls_fields[k] = attrib @@ -643,8 +716,140 @@ def field_annotation_gatherer(cls): return field_annotation_gatherer +def make_field_gatherer( + field_type=Field, + leave_default_values=True, +): + def field_attribute_gatherer(cls_or_ns): + cls_dict = cls_or_ns if isinstance(cls_or_ns, (_MPT, dict)) else cls_or_ns.__dict__ + cls_attributes = { + k: v + for k, v in cls_dict.items() + if isinstance(v, field_type) + } + cls_annotations = get_ns_annotations(cls_dict) + + cls_modifications = {} + + for name in cls_attributes.keys(): + attrib = cls_attributes[name] + if leave_default_values: + cls_modifications[name] = attrib.default + else: + cls_modifications[name] = NOTHING + + if (anno := cls_annotations.get(name, NOTHING)) is not NOTHING: + cls_attributes[name] = field_type.from_field(attrib, type=anno) + + return cls_attributes, cls_modifications + return field_attribute_gatherer + + +def make_unified_gatherer( + field_type=Field, + leave_default_values=True, +): + """ + Create a gatherer that will work via first slots, then + Field(...) class attributes and finally annotations if + no unannotated Field(...) attributes are present. + + :param field_type: The field class to use for gathering + :param leave_default_values: leave default values in place + :return: gatherer function + """ + slot_g = make_slot_gatherer(field_type) + anno_g = make_annotation_gatherer(field_type, leave_default_values) + attrib_g = make_field_gatherer(field_type, leave_default_values) + + def field_unified_gatherer(cls_or_ns): + cls_dict = cls_or_ns if isinstance(cls_or_ns, (_MPT, dict)) else cls_or_ns.__dict__ + cls_slots = cls_dict.get("__slots__") + + if isinstance(cls_slots, SlotFields): + return slot_g(cls_dict) + + # To choose between annotation and attribute gatherers + # compare sets of names. + # Don't bother evaluating string annotations, as we only need names + cls_annotations = get_ns_annotations(cls_dict, eval_str=False) + cls_attributes = { + k: v for k, v in cls_dict.items() if isinstance(v, field_type) + } + + cls_annotation_names = cls_annotations.keys() + cls_attribute_names = cls_attributes.keys() + + if set(cls_annotation_names).issuperset(set(cls_attribute_names)): + # All `Field` values have annotations, so use annotation gatherer + return anno_g(cls_dict) + + return attrib_g(cls_dict) + return field_unified_gatherer + + +slot_gatherer = make_slot_gatherer() annotation_gatherer = make_annotation_gatherer() +unified_gatherer = make_unified_gatherer(field_type=Field, leave_default_values=False) + + +# Now the gatherers have been defined, add __repr__ and __eq__ to Field. +_field_methods = {repr_maker, eq_maker} +if _UNDER_TESTING: + _field_methods.update({frozen_setattr_maker, frozen_delattr_maker}) + +builder( + Field, + gatherer=slot_gatherer, + methods=_field_methods, + flags={"slotted": True, "kw_only": True}, +) + + +def check_argument_order(cls): + """ + Raise a SyntaxError if the argument order will be invalid for a generated + `__init__` function. + + :param cls: class being built + """ + fields = get_fields(cls) + used_default = False + for k, v in fields.items(): + if v.kw_only or (not v.init): + continue + + if v.default is NOTHING and v.default_factory is NOTHING: + if used_default: + raise SyntaxError( + f"non-default argument {k!r} follows default argument" + ) + else: + used_default = True + + +# Class Decorators +def slotclass(cls=None, /, *, methods=default_methods, syntax_check=True): + """ + Example of class builder in action using __slots__ to find fields. + + :param cls: Class to be analysed and modified + :param methods: MethodMakers to be added to the class + :param syntax_check: check there are no arguments without defaults + after arguments with defaults. + :return: Modified class + """ + if not cls: + return lambda cls_: slotclass(cls_, methods=methods, syntax_check=syntax_check) + + cls = builder(cls, gatherer=slot_gatherer, methods=methods, flags={"slotted": True}) + + if syntax_check: + check_argument_order(cls) + + return cls + class AnnotationClass(metaclass=SlotMakerMeta): def __init_subclass__(cls, methods=default_methods, **kwargs): @@ -658,3 +863,17 @@ def __init_subclass__(cls, methods=default_methods, **kwargs): builder(cls, gatherer=gatherer, methods=methods, flags={"slotted": slots}) check_argument_order(cls) super().__init_subclass__(**kwargs) + + +@slotclass +class GatheredFields: + """ + A helper gatherer for fields that have been gathered externally. + """ + __slots__ = SlotFields( + fields=Field(), + modifications=Field(), + ) + + def __call__(self, cls): + return self.fields, self.modifications diff --git a/src/ducktools/classbuilder/__init__.pyi b/src/ducktools/classbuilder/__init__.pyi index c76ab2f..c1d81f2 100644 --- a/src/ducktools/classbuilder/__init__.pyi +++ b/src/ducktools/classbuilder/__init__.pyi @@ -1,11 +1,15 @@ import typing + from collections.abc import Callable +from types import MappingProxyType from typing_extensions import dataclass_transform _py_type = type | str # Alias for type hint values +_CopiableMappings = dict[str, typing.Any] | MappingProxyType[str, typing.Any] __version__: str INTERNALS_DICT: str +META_GATHERER_NAME: str def get_fields(cls: type, *, local: bool = False) -> dict[str, Field]: ... @@ -14,9 +18,14 @@ def get_flags(cls:type) -> dict[str, bool]: ... def _get_inst_fields(inst: typing.Any) -> dict[str, typing.Any]: ... class _NothingType: - ... + def __repr__(self) -> str: ... NOTHING: _NothingType +# noinspection PyPep8Naming +class _KW_ONLY_TYPE: + def __repr__(self) -> str: ... + +KW_ONLY: _KW_ONLY_TYPE # Stub Only _codegen_type = Callable[[type], tuple[str, dict[str, typing.Any]]] @@ -33,6 +42,11 @@ def get_init_generator( ) -> Callable[[type], tuple[str, dict[str, typing.Any]]]: ... def init_generator(cls: type) -> tuple[str, dict[str, typing.Any]]: ... + +def get_repr_generator( + recursion_safe: bool = False, + eval_safe: bool = False +) -> Callable[[type], tuple[str, dict[str, typing.Any]]]: ... def repr_generator(cls: type) -> tuple[str, dict[str, typing.Any]]: ... def eq_generator(cls: type) -> tuple[str, dict[str, typing.Any]]: ... @@ -70,6 +84,10 @@ def builder( ) -> Callable[[type[_T]], type[_T]]: ... +class SlotFields(dict): + ... + + class SlotMakerMeta(type): def __new__( cls: type[_T], @@ -86,6 +104,10 @@ class Field(metaclass=SlotMakerMeta): default_factory: _NothingType | typing.Any type: _NothingType | _py_type doc: None | str + init: bool + repr: bool + compare: bool + kw_only: bool __slots__: dict[str, str] __classbuilder_internals__: dict @@ -97,6 +119,10 @@ class Field(metaclass=SlotMakerMeta): default_factory: _NothingType | typing.Any = NOTHING, type: _NothingType | _py_type = NOTHING, doc: None | str = None, + init: bool = True, + repr: bool = True, + compare: bool = True, + kw_only: bool = False, ) -> None: ... def __init_subclass__(cls, frozen: bool = False): ... @@ -107,41 +133,64 @@ class Field(metaclass=SlotMakerMeta): def from_field(cls, fld: Field, /, **kwargs: typing.Any) -> Field: ... -class GatheredFields: - __slots__ = ("fields", "modifications") +# type[Field] doesn't work due to metaclass +# This is not really precise enough because isinstance is used +_ReturnsField = Callable[..., Field] +_FieldType = typing.TypeVar("_FieldType", bound=Field) - fields: dict[str, Field] - modifications: dict[str, typing.Any] - __classbuilder_internals__: dict +@typing.overload +def make_slot_gatherer( + field_type: type[_FieldType] +) -> Callable[[type | _CopiableMappings], tuple[dict[str, _FieldType], dict[str, typing.Any]]]: ... - def __init__( - self, - fields: dict[str, Field], - modifications: dict[str, typing.Any] - ) -> None: ... +@typing.overload +def make_slot_gatherer( + field_type: _ReturnsField = Field +) -> Callable[[type | _CopiableMappings], tuple[dict[str, Field], dict[str, typing.Any]]]: ... - def __repr__(self) -> str: ... - def __eq__(self, other) -> bool: ... - def __call__(self, cls: type) -> tuple[dict[str, Field], dict[str, typing.Any]]: ... +@typing.overload +def make_annotation_gatherer( + field_type: type[_FieldType], + leave_default_values: bool = True, +) -> Callable[[type | _CopiableMappings], tuple[dict[str, _FieldType], dict[str, typing.Any]]]: ... +@typing.overload +def make_annotation_gatherer( + field_type: _ReturnsField = Field, + leave_default_values: bool = True, +) -> Callable[[type | _CopiableMappings], tuple[dict[str, Field], dict[str, typing.Any]]]: ... -class SlotFields(dict): - ... +@typing.overload +def make_field_gatherer( + field_type: type[_FieldType], + leave_default_values: bool = True, +) -> Callable[[type | _CopiableMappings], tuple[dict[str, _FieldType], dict[str, typing.Any]]]: ... -_FieldType = typing.TypeVar("_FieldType", bound=Field) +@typing.overload +def make_field_gatherer( + field_type: _ReturnsField = Field, + leave_default_values: bool = True, +) -> Callable[[type | _CopiableMappings], tuple[dict[str, Field], dict[str, typing.Any]]]: ... @typing.overload -def make_slot_gatherer( - field_type: type[_FieldType] -) -> Callable[[type], tuple[dict[str, _FieldType], dict[str, typing.Any]]]: ... +def make_unified_gatherer( + field_type: type[_FieldType], + leave_default_values: bool = True, +) -> Callable[[type | _CopiableMappings], tuple[dict[str, _FieldType], dict[str, typing.Any]]]: ... @typing.overload -def make_slot_gatherer( - field_type: SlotMakerMeta = Field -) -> Callable[[type], tuple[dict[str, Field], dict[str, typing.Any]]]: ... +def make_unified_gatherer( + field_type: _ReturnsField = Field, + leave_default_values: bool = True, +) -> Callable[[type | _CopiableMappings], tuple[dict[str, Field], dict[str, typing.Any]]]: ... + + +def slot_gatherer(cls_or_ns: type | _CopiableMappings) -> tuple[dict[str, Field], dict[str, typing.Any]]: ... +def annotation_gatherer(cls_or_ns: type | _CopiableMappings) -> tuple[dict[str, Field], dict[str, typing.Any]]: ... + +def unified_gatherer(cls_or_ns: type | _CopiableMappings) -> tuple[dict[str, Field], dict[str, typing.Any]]: ... -def slot_gatherer(cls: type) -> tuple[dict[str, Field], dict[str, typing.Any]]: ... def check_argument_order(cls: type) -> None: ... @@ -163,20 +212,6 @@ def slotclass( syntax_check: bool = True ) -> Callable[[type[_T]], type[_T]]: ... -@typing.overload -def make_annotation_gatherer( - field_type: type[_FieldType], - leave_default_values: bool = True, -) -> Callable[[type], tuple[dict[str, _FieldType], dict[str, typing.Any]]]: ... - -@typing.overload -def make_annotation_gatherer( - field_type: SlotMakerMeta = Field, - leave_default_values: bool = True, -) -> Callable[[type], tuple[dict[str, Field], dict[str, typing.Any]]]: ... - -def annotation_gatherer(cls: type) -> tuple[dict[str, Field], dict[str, typing.Any]]: ... - @dataclass_transform(field_specifiers=(Field,)) class AnnotationClass(metaclass=SlotMakerMeta): @@ -185,3 +220,21 @@ class AnnotationClass(metaclass=SlotMakerMeta): methods: frozenset[MethodMaker] | set[MethodMaker] = default_methods, **kwargs, ) -> None: ... + +class GatheredFields: + __slots__: dict[str, None] + + fields: dict[str, Field] + modifications: dict[str, typing.Any] + + __classbuilder_internals__: dict + + def __init__( + self, + fields: dict[str, Field], + modifications: dict[str, typing.Any] + ) -> None: ... + + def __repr__(self) -> str: ... + def __eq__(self, other) -> bool: ... + def __call__(self, cls: type) -> tuple[dict[str, Field], dict[str, typing.Any]]: ... diff --git a/src/ducktools/classbuilder/annotations.py b/src/ducktools/classbuilder/annotations.py index 9046ce4..8e4478f 100644 --- a/src/ducktools/classbuilder/annotations.py +++ b/src/ducktools/classbuilder/annotations.py @@ -21,35 +21,99 @@ # SOFTWARE. import sys +import builtins -def eval_hint(hint, obj_globals=None, obj_locals=None): +class _StringGlobs(dict): + """ + Based on the fake globals dictionary used for annotations + from 3.14. This allows us to evaluate containers which + include forward references. + + It's just a dictionary that returns the key if the key + is not found. + """ + def __missing__(self, key): + return key + + def __repr__(self): + cls_name = self.__class__.__name__ + dict_repr = super().__repr__() + return f"{cls_name}({dict_repr})" + + +def eval_hint(hint, context=None, *, recursion_limit=5): """ Attempt to evaluate a string type hint in the given - context. If this fails, return the original string. + context. + + If this raises an exception, return the last string. + + If the recursion limit is hit or a previous value returns + on evaluation, return the original hint string. + + Example:: + import builtins + from typing import ClassVar + + from ducktools.classbuilder.annotations import eval_hint + + foo = "foo" + + context = {**vars(builtins), **globals(), **locals()} + eval_hint("foo", context) # returns 'foo' + + eval_hint("ClassVar[str]", context) # returns typing.ClassVar[str] + eval_hint("ClassVar[forwardref]", context) # returns typing.ClassVar[ForwardRef('forwardref')] :param hint: The existing type hint - :param obj_globals: global context - :param obj_locals: local context + :param context: merged context + :param recursion_limit: maximum number of evaluation loops before + returning the original string. :return: evaluated hint, or string if it could not evaluate """ + if context is not None: + context = _StringGlobs(context) + + original_hint = hint + seen = set() + i = 0 while isinstance(hint, str): + seen.add(hint) + # noinspection PyBroadException try: - hint = eval(hint, obj_globals, obj_locals) + hint = eval(hint, context) except Exception: break + + if hint in seen or i >= recursion_limit: + hint = original_hint + break + + i += 1 + return hint -def get_annotations(ns): +def get_ns_annotations(ns, eval_str=True): """ - Given an class namespace, attempt to retrieve the + Given a class namespace, attempt to retrieve the annotations dictionary and evaluate strings. + Note: This only evaluates in the context of module level globals + and values in the class namespace. Non-local variables will not + be evaluated. + :param ns: Class namespace (eg cls.__dict__) + :param eval_str: Attempt to evaluate string annotations (default to True) :return: dictionary of evaluated annotations """ + raw_annotations = ns.get("__annotations__", {}) + + if not eval_str: + return raw_annotations.copy() + try: obj_modulename = ns["__module__"] except KeyError: @@ -58,16 +122,21 @@ def get_annotations(ns): obj_module = sys.modules.get(obj_modulename, None) if obj_module: - obj_globals = obj_module.__dict__.copy() + obj_globals = vars(obj_module) else: obj_globals = {} - obj_locals = ns.copy() + # Type parameters should be usable in hints without breaking + # This is for Python 3.12+ + type_params = { + repr(param): param + for param in ns.get("__type_params__", ()) + } - raw_annotations = ns.get("__annotations__", {}) + context = {**vars(builtins), **obj_globals, **type_params, **ns} return { - k: eval_hint(v, obj_globals, obj_locals) + k: eval_hint(v, context) for k, v in raw_annotations.items() } diff --git a/src/ducktools/classbuilder/annotations.pyi b/src/ducktools/classbuilder/annotations.pyi index 7b99ddc..7ccbf0a 100644 --- a/src/ducktools/classbuilder/annotations.pyi +++ b/src/ducktools/classbuilder/annotations.pyi @@ -1,16 +1,25 @@ import typing import types - +_T = typing.TypeVar("_T") _CopiableMappings = dict[str, typing.Any] | types.MappingProxyType[str, typing.Any] +class _StringGlobs: + def __missing__(self, key: _T) -> _T: ... + + def eval_hint( hint: type | str, - obj_globals: None | dict[str, typing.Any] = None, - obj_locals: None | dict[str, typing.Any] = None, + context: None | dict[str, typing.Any] = None, + *, + recursion_limit: int = 5 ) -> type | str: ... -def get_annotations(ns: _CopiableMappings) -> dict[str, typing.Any]: ... + +def get_ns_annotations( + ns: _CopiableMappings, + eval_str: bool = True, +) -> dict[str, typing.Any]: ... def is_classvar( hint: object, diff --git a/src/ducktools/classbuilder/prefab.py b/src/ducktools/classbuilder/prefab.py index b3354b0..cbfccdb 100644 --- a/src/ducktools/classbuilder/prefab.py +++ b/src/ducktools/classbuilder/prefab.py @@ -26,12 +26,13 @@ Includes pre and post init functions along with other methods. """ from . import ( - INTERNALS_DICT, NOTHING, - Field, MethodMaker, SlotFields, GatheredFields, - builder, get_flags, get_fields, make_slot_gatherer, - frozen_setattr_maker, frozen_delattr_maker + INTERNALS_DICT, NOTHING, SlotFields, KW_ONLY, + Field, MethodMaker, GatheredFields, SlotMakerMeta, + builder, get_flags, get_fields, + make_unified_gatherer, + frozen_setattr_maker, frozen_delattr_maker, eq_maker, + get_repr_generator, ) -from .annotations import is_classvar, get_annotations PREFAB_FIELDS = "PREFAB_FIELDS" PREFAB_INIT_FUNC = "__prefab_init__" @@ -39,17 +40,6 @@ POST_INIT_FUNC = "__prefab_post_init__" -# KW_ONLY sentinel 'type' to use to indicate all subsequent attributes are -# keyword only -# noinspection PyPep8Naming -class _KW_ONLY_TYPE: - def __repr__(self): - return "" - - -KW_ONLY = _KW_ONLY_TYPE() - - class PrefabError(Exception): pass @@ -221,88 +211,6 @@ def __init__(cls: "type") -> "tuple[str, dict]": return MethodMaker(init_name, __init__) -def get_repr_maker(*, recursion_safe=False): - def __repr__(cls: "type") -> "tuple[str, dict]": - attributes = get_attributes(cls) - - globs = {} - - will_eval = True - valid_names = [] - for name, attrib in attributes.items(): - if attrib.repr and not attrib.exclude_field: - valid_names.append(name) - - # If the init fields don't match the repr, or some fields are excluded - # generate a repr that clearly will not evaluate - if will_eval and (attrib.exclude_field or (attrib.init ^ attrib.repr)): - will_eval = False - - content = ", ".join( - f"{name}={{self.{name}!r}}" - for name in valid_names - ) - - if recursion_safe: - import reprlib - globs["_recursive_repr"] = reprlib.recursive_repr() - recursion_func = "@_recursive_repr\n" - else: - recursion_func = "" - - if will_eval: - code = ( - f"{recursion_func}" - f"def __repr__(self):\n" - f" return f'{{type(self).__qualname__}}({content})'\n" - ) - else: - if content: - code = ( - f"{recursion_func}" - f"def __repr__(self):\n" - f" return f''\n" - ) - else: - code = ( - f"{recursion_func}" - f"def __repr__(self):\n" - f" return f''\n" - ) - - return code, globs - - return MethodMaker("__repr__", __repr__) - - -def get_eq_maker(): - def __eq__(cls: "type") -> "tuple[str, dict]": - class_comparison = "self.__class__ is other.__class__" - attribs = get_attributes(cls) - field_names = [ - name - for name, attrib in attribs.items() - if attrib.compare and not attrib.exclude_field - ] - - if field_names: - selfvals = ",".join(f"self.{name}" for name in field_names) - othervals = ",".join(f"other.{name}" for name in field_names) - instance_comparison = f"({selfvals},) == ({othervals},)" - else: - instance_comparison = "True" - - code = ( - f"def __eq__(self, other):\n" - f" return {instance_comparison} if {class_comparison} else NotImplemented\n" - ) - globs = {} - - return code, globs - - return MethodMaker("__eq__", __eq__) - - def get_iter_maker(): def __iter__(cls: "type") -> "tuple[str, dict]": fields = get_attributes(cls) @@ -344,9 +252,14 @@ def as_dict_gen(cls: "type") -> "tuple[str, dict]": init_maker = get_init_maker() prefab_init_maker = get_init_maker(init_name=PREFAB_INIT_FUNC) -repr_maker = get_repr_maker() -recursive_repr_maker = get_repr_maker(recursion_safe=True) -eq_maker = get_eq_maker() +repr_maker = MethodMaker( + "__repr__", + get_repr_generator(recursion_safe=False, eval_safe=True) +) +recursive_repr_maker = MethodMaker( + "__repr__", + get_repr_generator(recursion_safe=True, eval_safe=True) +) iter_maker = get_iter_maker() asdict_maker = get_asdict_maker() @@ -372,19 +285,21 @@ class Attribute(Field): :param doc: Parameter documentation for slotted classes :param type: Type of this attribute (for slotted classes) """ - init: bool = Field(default=True, doc="Include in the class __init__ parameters") - repr: bool = Field(default=True, doc="Include in the class __repr__") - compare: bool = Field(default=True, doc="Include in the class __eq__") - iter: bool = Field(default=True, doc="Include in the class __iter__ if generated.") - kw_only: bool = Field(default=False, doc="Make this a keyword only parameter in __init__") - serialize: bool = Field(default=True, doc="Serialize this attribute") - exclude_field: bool = Field(default=False, doc="Exclude this field from multiple methods") + iter: bool = True + serialize: bool = True + exclude_field: bool = False def validate_field(self): super().validate_field() - if self.kw_only and not self.init: + + exclude_attribs = { + self.repr, self.compare, self.iter, self.serialize + } + if self.exclude_field and any(exclude_attribs): raise PrefabError( - "Attribute cannot be keyword only if it is not in init." + "Excluded fields must have repr, compare, iter, serialize " + "set to False." + "This is automatically handled by using the `attribute` helper." ) @@ -424,6 +339,12 @@ def attribute( :return: Attribute generated with these parameters. """ + if exclude_field: + repr = False + compare = False + iter = False + serialize = False + return Attribute( default=default, default_factory=default_factory, @@ -439,84 +360,7 @@ def attribute( ) -slot_prefab_gatherer = make_slot_gatherer(Attribute) - - -# Gatherer for classes built on attributes or annotations -def attribute_gatherer(cls): - cls_annotations = get_annotations(cls.__dict__) - cls_annotation_names = cls_annotations.keys() - - cls_slots = cls.__dict__.get("__slots__", {}) - - cls_attributes = { - k: v for k, v in vars(cls).items() if isinstance(v, Attribute) - } - - cls_attribute_names = cls_attributes.keys() - - cls_modifications = {} - - if set(cls_annotation_names).issuperset(set(cls_attribute_names)): - # replace the classes' attributes dict with one with the correct - # order from the annotations. - kw_flag = False - new_attributes = {} - for name, value in cls_annotations.items(): - # Ignore ClassVar hints - if is_classvar(value): - continue - - # Look for the KW_ONLY annotation - if value is KW_ONLY: - if kw_flag: - raise PrefabError( - "Class can not be defined as keyword only twice" - ) - kw_flag = True - else: - # Copy attributes that are already defined to the new dict - # generate Attribute() values for those that are not defined. - - # Extra parameters to pass to each Attribute - extras = { - "type": cls_annotations[name] - } - if kw_flag: - extras["kw_only"] = True - - # If a field name is also declared in slots it can't have a real - # default value and the attr will be the slot descriptor. - if hasattr(cls, name) and name not in cls_slots: - if name in cls_attribute_names: - attrib = Attribute.from_field( - cls_attributes[name], - **extras, - ) - else: - attribute_default = getattr(cls, name) - attrib = attribute(default=attribute_default, **extras) - - # Clear the attribute from the class after it has been used - # in the definition. - cls_modifications[name] = NOTHING - else: - attrib = attribute(**extras) - - new_attributes[name] = attrib - - cls_attributes = new_attributes - else: - for name in cls_attributes.keys(): - attrib = cls_attributes[name] - cls_modifications[name] = NOTHING - - # Some items can still be annotated. - if name in cls_annotations: - new_attrib = Attribute.from_field(attrib, type=cls_annotations[name]) - cls_attributes[name] = new_attrib - - return cls_attributes, cls_modifications +prefab_gatherer = make_unified_gatherer(Attribute, False) # Class Builders @@ -562,16 +406,13 @@ def _make_prefab( ) slots = cls_dict.get("__slots__") + + slotted = False if slots is None else True + if gathered_fields is None: - if isinstance(slots, SlotFields): - gatherer = slot_prefab_gatherer - slotted = True - else: - gatherer = attribute_gatherer - slotted = False + gatherer = prefab_gatherer else: gatherer = gathered_fields - slotted = False if slots is None else True methods = set() @@ -712,6 +553,37 @@ def _make_prefab( return cls +class Prefab(metaclass=SlotMakerMeta): + _meta_gatherer = prefab_gatherer + __slots__ = {} + + # noinspection PyShadowingBuiltins + def __init_subclass__( + cls, + init=True, + repr=True, + eq=True, + iter=False, + match_args=True, + kw_only=False, + frozen=False, + dict_method=False, + recursive_repr=False, + ): + _make_prefab( + cls, + init=init, + repr=repr, + eq=eq, + iter=iter, + match_args=match_args, + kw_only=kw_only, + frozen=frozen, + dict_method=dict_method, + recursive_repr=recursive_repr, + ) + + # noinspection PyShadowingBuiltins def prefab( cls=None, diff --git a/src/ducktools/classbuilder/prefab.pyi b/src/ducktools/classbuilder/prefab.pyi index 57ce843..eeb02a5 100644 --- a/src/ducktools/classbuilder/prefab.pyi +++ b/src/ducktools/classbuilder/prefab.pyi @@ -1,11 +1,15 @@ import typing +from types import MappingProxyType from typing_extensions import dataclass_transform from collections.abc import Callable from . import ( INTERNALS_DICT, NOTHING, - Field, MethodMaker, SlotFields as SlotFields, + KW_ONLY as KW_ONLY, + Field, MethodMaker, + SlotFields as SlotFields, + SlotMakerMeta, builder, get_flags, get_fields, make_slot_gatherer ) @@ -17,12 +21,7 @@ PREFAB_INIT_FUNC: str PRE_INIT_FUNC: str POST_INIT_FUNC: str - -# noinspection PyPep8Naming -class _KW_ONLY_TYPE: - def __repr__(self) -> str: ... - -KW_ONLY: _KW_ONLY_TYPE +_CopiableMappings = dict[str, typing.Any] | MappingProxyType[str, typing.Any] class PrefabError(Exception): ... @@ -30,10 +29,6 @@ def get_attributes(cls: type) -> dict[str, Attribute]: ... def get_init_maker(*, init_name: str="__init__") -> MethodMaker: ... -def get_repr_maker(*, recursion_safe: bool = False) -> MethodMaker: ... - -def get_eq_maker() -> MethodMaker: ... - def get_iter_maker() -> MethodMaker: ... def get_asdict_maker() -> MethodMaker: ... @@ -50,11 +45,7 @@ asdict_maker: MethodMaker class Attribute(Field): __slots__: dict - init: bool - repr: bool - compare: bool iter: bool - kw_only: bool serialize: bool exclude_field: bool @@ -93,9 +84,7 @@ def attribute( exclude_field: bool = False, ) -> Attribute: ... -def slot_prefab_gatherer(cls: type) -> tuple[dict[str, Attribute], dict[str, typing.Any]]: ... - -def attribute_gatherer(cls: type) -> tuple[dict[str, Attribute], dict[str, typing.Any]]: ... +def prefab_gatherer(cls_or_ns: type | MappingProxyType) -> tuple[dict[str, Attribute], dict[str, typing.Any]]: ... def _make_prefab( cls: type, @@ -114,6 +103,23 @@ def _make_prefab( _T = typing.TypeVar("_T") +# noinspection PyUnresolvedReferences +@dataclass_transform(field_specifiers=(Attribute, attribute)) +class Prefab(metaclass=SlotMakerMeta): + _meta_gatherer: Callable[[type | _CopiableMappings], tuple[dict[str, Field], dict[str, typing.Any]]] + def __init_subclass__( + cls, + init: bool = True, + repr: bool = True, + eq: bool = True, + iter: bool = False, + match_args: bool = True, + kw_only: bool = False, + frozen: bool = False, + dict_method: bool = False, + recursive_repr: bool = False, + ) -> None: ... + # For some reason PyCharm can't see 'attribute'?!? # noinspection PyUnresolvedReferences diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6074ec8 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import sys + +collect_ignore = [] + +if sys.version_info < (3, 14): + minor_ver = sys.version_info.minor + + collect_ignore.extend( + f"py3{i+1}_tests" for i in range(minor_ver, 14) + ) + +print(collect_ignore) diff --git a/tests/prefab/dynamic/test_construction.py b/tests/prefab/dynamic/test_construction.py index a7ad4a4..d91e07d 100644 --- a/tests/prefab/dynamic/test_construction.py +++ b/tests/prefab/dynamic/test_construction.py @@ -1,7 +1,7 @@ import pytest from ducktools.classbuilder import get_fields -from ducktools.classbuilder.annotations import get_annotations +from ducktools.classbuilder.annotations import get_ns_annotations from ducktools.classbuilder.prefab import build_prefab, prefab, attribute, PrefabError @@ -109,7 +109,7 @@ def test_build_slotted(): assert inst.y == 0 assert SlottedClass.__slots__ == {'x': "x co-ordinate", 'y': "y co-ordinate"} - assert get_annotations(SlottedClass.__dict__) == {'x': float, 'y': float} + assert get_ns_annotations(SlottedClass.__dict__) == {'x': float, 'y': float} # Test slots are functioning with pytest.raises(AttributeError): diff --git a/tests/prefab/dynamic/test_slotted_class.py b/tests/prefab/dynamic/test_slotted_class.py index 1a8e8f8..0c8ebf6 100644 --- a/tests/prefab/dynamic/test_slotted_class.py +++ b/tests/prefab/dynamic/test_slotted_class.py @@ -1,6 +1,6 @@ import pytest -from ducktools.classbuilder.annotations import get_annotations +from ducktools.classbuilder.annotations import get_ns_annotations from ducktools.classbuilder.prefab import prefab, attribute, SlotFields @@ -13,7 +13,7 @@ class SlottedPrefab: ) assert SlottedPrefab.__slots__ == {"x": None, "y": "Digits of pi"} - assert get_annotations(SlottedPrefab.__dict__) == {"y": float} + assert get_ns_annotations(SlottedPrefab.__dict__) == {"y": float} ex = SlottedPrefab() diff --git a/tests/prefab/shared/test_creation.py b/tests/prefab/shared/test_creation.py index b1f3d0d..04b2327 100644 --- a/tests/prefab/shared/test_creation.py +++ b/tests/prefab/shared/test_creation.py @@ -1,7 +1,7 @@ """Tests for errors raised on class creation""" import sys -from ducktools.classbuilder.annotations import get_annotations +from ducktools.classbuilder.annotations import get_ns_annotations from ducktools.classbuilder.prefab import PrefabError import pytest @@ -46,12 +46,12 @@ def test_removed_defaults(self): removed_attributes = ["x", "y", "z"] for attrib in removed_attributes: assert attrib not in getattr(OnlyHints, "__dict__") - assert attrib in get_annotations(OnlyHints.__dict__) + assert attrib in get_ns_annotations(OnlyHints.__dict__) def test_removed_only_used_defaults(self): from creation import MixedHints - annotations = get_annotations(MixedHints.__dict__) + annotations = get_ns_annotations(MixedHints.__dict__) assert "x" in annotations assert "y" in annotations @@ -145,7 +145,7 @@ def test_skipped_annotated_classvars(self): class TestExceptions: def test_kw_not_in_init(self): - with pytest.raises(PrefabError) as e_info: + with pytest.raises(AttributeError) as e_info: from fails.creation_1 import Construct assert ( @@ -171,7 +171,7 @@ def test_default_value_and_factory_error(self): assert ( e_info.value.args[0] - == "Cannot define both a default value and a default factory." + == "Attribute cannot define both a default value and a default factory." ) @@ -185,7 +185,7 @@ def test_splitvardef(self, classname): cls = getattr(creation, classname) - assert get_annotations(cls.__dict__)["x"] == str + assert get_ns_annotations(cls.__dict__)["x"] == str inst = cls() assert inst.x == "test" @@ -211,7 +211,7 @@ def test_horriblemess(self): assert inst.x == "true_test" assert repr(inst) == "HorribleMess(x='true_test', y='test_2')" - assert get_annotations(cls.__dict__) == {"x": str, "y": str} + assert get_ns_annotations(cls.__dict__) == {"x": str, "y": str} def test_call_mistaken(): @@ -235,7 +235,7 @@ def test_non_init_works_no_default(self): x.x = 12 - assert repr(x) == "" + assert repr(x) == "" def test_non_init_doesnt_break_syntax(self): # No syntax error if an attribute with a default is defined @@ -243,4 +243,4 @@ def test_non_init_doesnt_break_syntax(self): from creation import PositionalNotAfterKW x = PositionalNotAfterKW(1, 2) - assert repr(x) == "" + assert repr(x) == "" diff --git a/tests/prefab/shared/test_dunders.py b/tests/prefab/shared/test_dunders.py index 8cce404..f84ce7f 100644 --- a/tests/prefab/shared/test_dunders.py +++ b/tests/prefab/shared/test_dunders.py @@ -15,7 +15,7 @@ def test_repr(): def test_repr_exclude(): from dunders import CoordinateNoXRepr - expected_repr = "" + expected_repr = "" assert repr(CoordinateNoXRepr(1, 2)) == expected_repr diff --git a/tests/prefab/shared/test_init.py b/tests/prefab/shared/test_init.py index 7e01971..1e23372 100644 --- a/tests/prefab/shared/test_init.py +++ b/tests/prefab/shared/test_init.py @@ -142,8 +142,8 @@ def test_exclude_field(): assert x.x == "EXCLUDED_FIELD" assert y.x == "STILL_EXCLUDED" - assert repr(x) == "" - assert repr(y) == "" + assert repr(x) == "" + assert repr(y) == "" assert x == y diff --git a/tests/prefab/shared/test_kw_only.py b/tests/prefab/shared/test_kw_only.py index 7a34d20..943322c 100644 --- a/tests/prefab/shared/test_kw_only.py +++ b/tests/prefab/shared/test_kw_only.py @@ -1,6 +1,6 @@ import pytest -from ducktools.classbuilder.annotations import get_annotations +from ducktools.classbuilder.annotations import get_ns_annotations def test_kw_only_basic(): from kw_only import KWBasic @@ -65,7 +65,7 @@ def test_kw_only_prefab_argument_overrides(): def test_kw_flag_no_defaults(): from kw_only import KWFlagNoDefaults - annotations = get_annotations(KWFlagNoDefaults.__dict__) + annotations = get_ns_annotations(KWFlagNoDefaults.__dict__) assert "_" in annotations diff --git a/tests/prefab/shared/test_repr.py b/tests/prefab/shared/test_repr.py index 277aeb5..f26aa54 100644 --- a/tests/prefab/shared/test_repr.py +++ b/tests/prefab/shared/test_repr.py @@ -9,28 +9,28 @@ def test_basic_repr_no_fields(): from repr_func import NoReprAttributes x = NoReprAttributes() - assert repr(x) == "" + assert repr(x) == "" def test_one_attribute_no_repr(): from repr_func import OneAttributeNoRepr x = OneAttributeNoRepr() - assert repr(x) == "" + assert repr(x) == "" def test_one_attribute_no_init(): from repr_func import OneAttributeNoInit x = OneAttributeNoInit() - assert repr(x) == "" + assert repr(x) == "" def test_one_attribute_exclude_field(): from repr_func import OneAttributeExcludeField x = OneAttributeExcludeField() - assert repr(x) == "" + assert repr(x) == "" def test_regular_one_arg(): diff --git a/tests/py312_tests/test_generic_annotations.py b/tests/py312_tests/test_generic_annotations.py new file mode 100644 index 0000000..dddc0f3 --- /dev/null +++ b/tests/py312_tests/test_generic_annotations.py @@ -0,0 +1,15 @@ +# This syntax only exists in Python 3.12 or later. +from ducktools.classbuilder.annotations import get_ns_annotations + + +def test_312_generic(): + class X[T]: + test_var = T # Need access outside of class to test + + x: list[T] + y: "list[T]" + + assert get_ns_annotations(vars(X)) == { + "x": list[X.test_var], + "y": list[X.test_var], + } diff --git a/tests/test_annotated.py b/tests/test_annotated.py index 997a477..f904e07 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -11,7 +11,7 @@ from ducktools.classbuilder.annotations import ( is_classvar, - get_annotations, + get_ns_annotations, ) CV = ClassVar @@ -82,7 +82,7 @@ class ExampleAnnotated: h: Annotated[CV[str], ''] = "h" annos, modifications = gatherer(ExampleAnnotated) - annotations = get_annotations(vars(ExampleAnnotated)) + annotations = get_ns_annotations(vars(ExampleAnnotated)) assert annos["blank_field"] == NewField(type=str) diff --git a/tests/test_annotations_module.py b/tests/test_annotations_module.py new file mode 100644 index 0000000..a40bf2f --- /dev/null +++ b/tests/test_annotations_module.py @@ -0,0 +1,118 @@ +import builtins + +from ducktools.classbuilder.annotations import ( + _StringGlobs, + eval_hint, + get_ns_annotations, + is_classvar, +) +from typing import List, ClassVar +from typing_extensions import Annotated + + +def test_string_globs(): + context = _StringGlobs({'str': str}) + assert context['str'] == str + assert context['forwardref'] == 'forwardref' + + assert repr(context) == f"_StringGlobs({{'str': {str!r}}})" + + +class TestEvalHint: + def test_basic(self): + assert eval_hint('str') == str + assert eval_hint("'str'") == str + + assert eval_hint('forwardref') == 'forwardref' + + def test_container(self): + context = _StringGlobs({ + **vars(builtins), + **globals(), + **locals() + }) + + assert eval_hint("List[str]", context) == List[str] + assert eval_hint("ClassVar[str]", context) == ClassVar[str] + + assert eval_hint("List[forwardref]", context) == List["forwardref"] + assert eval_hint("ClassVar[forwardref]", context) == ClassVar["forwardref"] + + def test_loop(self): + # Check the 'seen' test prevents an infinite loop + + alt_str = str + bleh = "bleh" + + context = _StringGlobs({ + **vars(builtins), + **globals(), + **locals() + }) + + assert eval_hint("alt_str", context) == str + assert eval_hint("bleh", context) == "bleh" + + def test_evil_hint(self): + # Nobody should evaluate anything that does this, but it shouldn't break + # On every evaluation this function generates a new string + # This hits the (low) recursion limit and returns the original string + class EvilLookup: + counter = 0 + + def __getattr__(self, key): + EvilLookup.counter += 1 + return f"EvilLookup().loop{self.counter}" + + evil_value = EvilLookup() + + context = _StringGlobs({ + **vars(builtins), + **globals(), + **locals() + }) + + assert eval_hint("evil_value.loop", context) == "evil_value.loop" + + +def test_ns_annotations(): + CV = ClassVar + + class AnnotatedClass: + a: str + b: "str" + c: List[str] + d: "List[str]" + e: ClassVar[str] + f: "ClassVar[str]" + g: "ClassVar[forwardref]" + h: "Annotated[ClassVar[str], '']" + i: "Annotated[ClassVar[forwardref], '']" + j: "CV[str]" # Limitation, can't see closure variables. + + annos = get_ns_annotations(vars(AnnotatedClass)) + + assert annos == { + 'a': str, + 'b': str, + 'c': List[str], + 'd': List[str], + 'e': ClassVar[str], + 'f': ClassVar[str], + 'g': ClassVar['forwardref'], + 'h': Annotated[ClassVar[str], ''], + 'i': Annotated[ClassVar['forwardref'], ''], + 'j': "CV[str]", + } + + +def test_is_classvar(): + assert is_classvar(ClassVar) + assert is_classvar(ClassVar[str]) + assert is_classvar(ClassVar['forwardref']) + + assert is_classvar(Annotated[ClassVar[str], '']) + assert is_classvar(Annotated[ClassVar['forwardref'], '']) + + assert not is_classvar(str) + assert not is_classvar(Annotated[str, '']) diff --git a/tests/test_core.py b/tests/test_core.py index 3678905..4c3e763 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -16,7 +16,7 @@ slotclass, GatheredFields, ) -from ducktools.classbuilder.annotations import get_annotations +from ducktools.classbuilder.annotations import get_ns_annotations def test_get_fields_flags(): @@ -113,13 +113,20 @@ def test_repr_field(): f4 = Field(default=True, type=bool) f5 = Field(default=True, doc="True or False") + repr_ending = "init=True, repr=True, compare=True, kw_only=False" + nothing_repr = repr(NOTHING) - f1_repr = f"Field(default=True, default_factory={nothing_repr}, type={nothing_repr}, doc=None)" - f2_repr = f"Field(default=False, default_factory={nothing_repr}, type={nothing_repr}, doc=None)" - f3_repr = f"Field(default={nothing_repr}, default_factory=, type={nothing_repr}, doc=None)" - f4_repr = f"Field(default=True, default_factory={nothing_repr}, type=, doc=None)" - f5_repr = f"Field(default=True, default_factory={nothing_repr}, type={nothing_repr}, doc='True or False')" + f1_repr = (f"Field(default=True, default_factory={nothing_repr}, " + f"type={nothing_repr}, doc=None, {repr_ending})") + f2_repr = (f"Field(default=False, default_factory={nothing_repr}, " + f"type={nothing_repr}, doc=None, {repr_ending})") + f3_repr = (f"Field(default={nothing_repr}, default_factory=, " + f"type={nothing_repr}, doc=None, {repr_ending})") + f4_repr = (f"Field(default=True, default_factory={nothing_repr}, " + f"type=, doc=None, {repr_ending})") + f5_repr = (f"Field(default=True, default_factory={nothing_repr}, " + f"type={nothing_repr}, doc='True or False', {repr_ending})") assert repr(f1) == f1_repr assert repr(f2) == f2_repr @@ -172,7 +179,7 @@ class SlotsExample: assert slots == fields assert modifications["__slots__"] == {"a": None, "b": None, "c": "a list", "d": None} assert modifications["__annotations__"] == {"a": int, "d": str} - assert get_annotations(SlotsExample.__dict__) == {"a": int} # Original annotations dict unmodified + assert get_ns_annotations(SlotsExample.__dict__) == {"a": int} # Original annotations dict unmodified def test_slot_gatherer_failure(): @@ -449,11 +456,15 @@ class Ex: assert repr(flds).endswith( "GatheredFields(" - "fields={'x': Field(default=1, default_factory=, type=, doc=None)}, " + "fields={'x': Field(" + "default=1, default_factory=, type=, doc=None, " + "init=True, repr=True, compare=True, kw_only=False" + ")}, " "modifications={'x': }" ")" ) + def test_signature(): # This used to fail @slotclass diff --git a/tests/test_field_flags.py b/tests/test_field_flags.py new file mode 100644 index 0000000..4017fd7 --- /dev/null +++ b/tests/test_field_flags.py @@ -0,0 +1,48 @@ +from ducktools.classbuilder import Field, SlotFields, slotclass +import inspect + + +def test_init_false_field(): + @slotclass + class Example: + __slots__ = SlotFields( + x=Field(default="x", init=False), + y=Field(default="y") + ) + + sig = inspect.signature(Example) + assert 'x' not in sig.parameters + assert 'y' in sig.parameters + assert sig.parameters["y"].default == "y" + + ex = Example() + assert ex.x == "x" + assert ex.y == "y" + + +def test_repr_false_field(): + @slotclass + class Example: + __slots__ = SlotFields( + x=Field(default="x", repr=False), + y=Field(default="y"), + ) + + ex = Example() + assert repr(ex).endswith("Example(y='y')") + + +def test_compare_false_field(): + @slotclass + class Example: + __slots__ = SlotFields( + x=Field(default="x", compare=False), + y=Field(default="y"), + ) + + ex = Example() + ex2 = Example(x="z") + ex3 = Example(y="z") + + assert ex == ex2 + assert ex != ex3 diff --git a/tests/test_slotter.py b/tests/test_slotter.py index 75c4e4d..e55bbdf 100644 --- a/tests/test_slotter.py +++ b/tests/test_slotter.py @@ -1,6 +1,6 @@ -from typing import ClassVar +from typing import ClassVar, List from typing_extensions import Annotated -from ducktools.classbuilder import SlotFields, NOTHING, SlotMakerMeta +from ducktools.classbuilder import Field, SlotFields, NOTHING, SlotMakerMeta import pytest @@ -8,7 +8,7 @@ def test_slots_created(): class ExampleAnnotated(metaclass=SlotMakerMeta): a: str = "a" - b: "list[str]" = "b" # Yes this is the wrong type, I know. + b: "List[str]" = "b" # Yes this is the wrong type, I know. c: Annotated[str, ""] = "c" d: ClassVar[str] = "d" @@ -19,7 +19,13 @@ class ExampleAnnotated(metaclass=SlotMakerMeta): assert hasattr(ExampleAnnotated, "__slots__") slots = ExampleAnnotated.__slots__ # noqa - assert slots == SlotFields({char: char for char in "abc"}) + expected_slots = SlotFields({ + "a": Field(default="a", type=str), + "b": Field(default="b", type=List[str]), + "c": Field(default="c", type=Annotated[str, ""]) + }) + + assert slots == expected_slots def test_slots_correct_subclass(): @@ -31,8 +37,12 @@ class ExampleBase(metaclass=SlotMakerMeta): class ExampleChild(ExampleBase): d: str = "d" - assert ExampleBase.__slots__ == SlotFields(a=NOTHING, b="b", c="c") # noqa - assert ExampleChild.__slots__ == SlotFields(d="d") # noqa + assert ExampleBase.__slots__ == SlotFields( # noqa + a=Field(type=str), + b=Field(default="b", type=str), + c=Field(default="c", type=str), + ) + assert ExampleChild.__slots__ == SlotFields(d=Field(default="d", type=str)) # noqa inst = ExampleChild() @@ -43,3 +53,17 @@ class ExampleChild(ExampleBase): with pytest.raises(AttributeError): inst.e = "e" + + +def test_slots_attribute(): + # In the case where an unannotated field is declared, ignore + # annotations without field values. + class ExampleBase(metaclass=SlotMakerMeta): + x: str = "x" + y: str = Field(default="y") + z = Field(default="z") + + assert ExampleBase.__slots__ == SlotFields( # noqa + y=Field(default="y", type=str), + z=Field(default="z"), + )