Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix enforce overrides magic methods #105

Merged
merged 4 commits into from
Oct 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions overrides/enforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,40 @@ def __new__(mcls, name, bases, namespace, **kwargs):

cls = super().__new__(mcls, name, bases, namespace, **kwargs)
for name, value in namespace.items():
# Actually checking the direct parent should be enough,
# otherwise the error would have emerged during the parent class checking
if name.startswith("__"):
mcls._check_if_overrides_final_method(name, bases)
if not name.startswith("__"):
value = mcls._handle_special_value(value)
mcls._check_if_overrides_without_overrides_decorator(name, value, bases)
return cls

@staticmethod
def _check_if_overrides_without_overrides_decorator(name, value, bases):
is_override = getattr(value, "__override__", False)
for base in bases:
base_class_method = getattr(base, name, False)
if (
not base_class_method
or not callable(base_class_method)
or getattr(base_class_method, "__ignored__", False)
):
continue
value = mcls.handle_special_value(value)
is_override = getattr(value, "__override__", False)
for base in bases:
base_class_method = getattr(base, name, False)
if (
not base_class_method
or not callable(base_class_method)
or getattr(base_class_method, "__ignored__", False)
):
continue
assert (
is_override
), "Method %s overrides but does not have @overrides decorator" % (name)
# `__finalized__` is added by `@final` decorator
assert not getattr(base_class_method, "__finalized__", False), (
"Method %s is finalized in %s, it cannot be overridden"
% (base_class_method, base,)
if not is_override:
raise TypeError(
f"Method {name} overrides method from {base} but does not have @overrides decorator"
)

@staticmethod
def _check_if_overrides_final_method(name, bases):
for base in bases:
base_class_method = getattr(base, name, False)
# `__finalized__` is added by `@final` decorator
if getattr(base_class_method, "__finalized__", False):
raise TypeError(
f"Method {name} is finalized in {base}, it cannot be overridden"
)
return cls

@staticmethod
def handle_special_value(value):
def _handle_special_value(value):
if isinstance(value, classmethod) or isinstance(value, staticmethod):
value = value.__get__(None, dict)
elif isinstance(value, property):
Expand Down
6 changes: 4 additions & 2 deletions overrides/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def method(self):


def _overrides(
method: _WrappedMethod, check_signature: bool, check_at_runtime: bool,
method: _WrappedMethod,
check_signature: bool,
check_at_runtime: bool,
) -> _WrappedMethod:
setattr(method, "__override__", True)
global_vars = getattr(method, "__globals__", None)
Expand Down Expand Up @@ -125,7 +127,7 @@ def _validate_method(method, super_class, check_signature):
if hasattr(super_method, "__finalized__"):
finalized = getattr(super_method, "__finalized__")
if finalized:
raise TypeError(f"{method.__name__}: is finalized")
raise TypeError(f"{method.__name__}: is finalized in {super_class}")
if not method.__doc__:
method.__doc__ = super_method.__doc__
if (
Expand Down
8 changes: 6 additions & 2 deletions overrides/typing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ def _is_origin_subtype(left: OriginType, right: OriginType) -> bool:


NormalizedTypeArgs = typing.Union[
typing.Tuple[typing.Any, ...], typing.FrozenSet[NormalizedType], NormalizedType,
typing.Tuple[typing.Any, ...],
typing.FrozenSet[NormalizedType],
NormalizedType,
]


Expand Down Expand Up @@ -416,7 +418,9 @@ def _is_normal_subtype(


def issubtype(
left: Type, right: Type, forward_refs: typing.Optional[dict] = None,
left: Type,
right: Type,
forward_refs: typing.Optional[dict] = None,
) -> typing.Optional[bool]:
"""Check that the left argument is a subtype of the right.
For unions, check if the type arguments of the left is a subset of the right.
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
author_email=address,
url="https://github.com/mkorpela/overrides",
packages=find_packages(),
package_data={"overrides": ["*.pyi", "py.typed"],},
package_data={
"overrides": ["*.pyi", "py.typed"],
},
include_package_data=True,
install_requires=['typing;python_version<"3.5"'],
python_requires=">=3.6",
Expand Down
23 changes: 13 additions & 10 deletions tests/test_enforce__py38.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class Enforcing(EnforceOverrides):
def finality(self):
return "final"

@final
def __and__(self, other):
return True

def nonfinal1(self, param: int) -> str:
return "super1"

Expand Down Expand Up @@ -47,27 +51,26 @@ def nonfinal1(self, param: int) -> str:
self.assertEqual(sc.classVariableIsOk, "OK!")

def test_enforcing_when_finality_broken(self):
try:
with self.assertRaises(TypeError):

class BrokesFinality(Enforcing):
def finality(self):
return "NEVER HERE"

raise RuntimeError("Should not go here")
except AssertionError:
pass
def test_trying_to_override_final_magic_method(self):
with self.assertRaises(TypeError):

class FinalMagicOverrides(Enforcing):
def __and__(self, other):
return False

def test_enforcing_when_none_explicit_override(self):
try:
with self.assertRaises(TypeError):

class Overrider(Enforcing):
def nonfinal2(self):
return "NEVER HERE EITHER"

raise RuntimeError("Should not go here")
except AssertionError:
pass

def test_enforcing_when_property_overriden(self):
class PropertyOverrider(Enforcing):
@property
Expand Down Expand Up @@ -116,7 +119,7 @@ class MetaClassMethodOverrider(Enforcing):
def register(self):
pass

with self.assertRaises(AssertionError):
with self.assertRaises(TypeError):

class SubClass(MetaClassMethodOverrider):
def register(self):
Expand Down