diff --git a/overrides/enforce.py b/overrides/enforce.py index 236c878..824b5b6 100644 --- a/overrides/enforce.py +++ b/overrides/enforce.py @@ -30,11 +30,9 @@ def __new__(mcls, name, bases, namespace, **kwargs): is_override ), "Method %s overrides but does not have @overrides decorator" % (name) # `__finalized__` is added by `@final` decorator - assert not getattr( - base_class_method, "__finalized__", False - ), "Method %s is finalized in %s, it cannot be overridden" % ( - base_class_method, - base, + assert not getattr(base_class_method, "__finalized__", False), ( + "Method %s is finalized in %s, it cannot be overridden" + % (base_class_method, base,) ) return cls diff --git a/overrides/overrides.py b/overrides/overrides.py index 2dbd42d..7cd862e 100644 --- a/overrides/overrides.py +++ b/overrides/overrides.py @@ -95,9 +95,7 @@ def method(self): def _overrides( - method: _WrappedMethod, - check_signature: bool, - check_at_runtime: bool, + method: _WrappedMethod, check_signature: bool, check_at_runtime: bool, ) -> _WrappedMethod: setattr(method, "__override__", True) global_vars = getattr(method, "__globals__", None) diff --git a/overrides/signature.py b/overrides/signature.py index 5844918..83331ef 100644 --- a/overrides/signature.py +++ b/overrides/signature.py @@ -263,15 +263,21 @@ def is_param_defined_in_sub( def ensure_no_extra_args_in_sub( - super_sig, sub_sig, check_first_parameter: bool, method_name: str + super_sig: inspect.Signature, + sub_sig: inspect.Signature, + check_first_parameter: bool, + method_name: str, ) -> None: - super_var_args = any( - p.kind == Parameter.VAR_POSITIONAL for p in super_sig.parameters.values() - ) - super_var_kwargs = any( - p.kind == Parameter.VAR_KEYWORD for p in super_sig.parameters.values() - ) + super_params = super_sig.parameters.values() + super_var_args = any(p.kind == Parameter.VAR_POSITIONAL for p in super_params) + super_var_kwargs = any(p.kind == Parameter.VAR_KEYWORD for p in super_params) for sub_index, (name, sub_param) in enumerate(sub_sig.parameters.items()): + if ( + sub_param.kind == Parameter.POSITIONAL_ONLY + and len(super_params) > sub_index + and list(super_params)[sub_index].kind == Parameter.POSITIONAL_ONLY + ): + continue if ( name not in super_sig.parameters and sub_param.default == Parameter.empty diff --git a/overrides/typing_utils.py b/overrides/typing_utils.py index 9c78688..2d9f919 100644 --- a/overrides/typing_utils.py +++ b/overrides/typing_utils.py @@ -289,10 +289,9 @@ def _is_origin_subtype(left: OriginType, right: OriginType) -> bool: return left == right + NormalizedTypeArgs = typing.Union[ - typing.Tuple[typing.Any, ...], - typing.FrozenSet[NormalizedType], - NormalizedType, + typing.Tuple[typing.Any, ...], typing.FrozenSet[NormalizedType], NormalizedType, ] @@ -417,9 +416,7 @@ def _is_normal_subtype( def issubtype( - left: Type, - right: Type, - forward_refs: typing.Optional[dict] = None, + left: Type, right: Type, forward_refs: typing.Optional[dict] = None, ) -> typing.Optional[bool]: """Check that the left argument is a subtype of the right. For unions, check if the type arguments of the left is a subset of the right. diff --git a/setup.py b/setup.py index dfd69fd..a11fc7d 100644 --- a/setup.py +++ b/setup.py @@ -24,9 +24,7 @@ author_email=address, url="https://github.com/mkorpela/overrides", packages=find_packages(), - package_data={ - "overrides": ["*.pyi", "py.typed"], - }, + package_data={"overrides": ["*.pyi", "py.typed"],}, include_package_data=True, install_requires=['typing;python_version<"3.5"'], python_requires=">=3.6", diff --git a/tests/test_final.py b/tests/test_final.py index 66c29c2..eb76a0f 100644 --- a/tests/test_final.py +++ b/tests/test_final.py @@ -16,6 +16,7 @@ def some_finalized_method(self): class SomeFinalClass: pass + class SubClass(SuperClass): @overrides def some_method(self): diff --git a/tests/test_named_and_positional__py38.py b/tests/test_named_and_positional__py38.py index 4578da7..1b0b938 100644 --- a/tests/test_named_and_positional__py38.py +++ b/tests/test_named_and_positional__py38.py @@ -8,6 +8,9 @@ class A(EnforceOverrides): def methoda(self, x=0): print(x) + def methodb(self, x: int, /, y: str) -> str: + return y * x + class Other: def foo(self): @@ -86,6 +89,33 @@ def methoda(self, x=0, /): pass +def test_can_override_positional_only(): + class PositionalOnly1(A): + @overrides + def methodb(self, x: int, /, y: str) -> str: + return "OK" + + +def test_can_override_positional_only_with_new_name(): + class PositionalOnly2(A): + @overrides + def methodb(self, new_name_is_ok: int, /, y: str) -> str: + return "OK2" + + +def test_can_not_override_positional_only_with_new_type(): + try: + + class PositionalOnly3(A): + @overrides + def methodb(self, x: str, /, y: str) -> str: + return "NOPE" + + raise AssertionError("Should not go here") + except TypeError: + pass + + def test_can_not_override_with_keyword_only(): try: diff --git a/tests/test_overrides.py b/tests/test_overrides.py index 219b90b..0ce55c0 100644 --- a/tests/test_overrides.py +++ b/tests/test_overrides.py @@ -37,9 +37,11 @@ def some_method(self): class SomeClass: """Super Inner Class Docs""" + def check(self): return 0 + class SubClass(SuperClass): @overrides def some_method(self): @@ -88,6 +90,7 @@ class SomeClass: def check(self): return 1 + class OverridesTests(unittest.TestCase): def test_overrides_passes_for_same_package_superclass(self): sub = SubClass() @@ -167,6 +170,7 @@ def test_overrides_builtin_method_incorrect_signature(self): expected_error = no_error() with expected_error: + class SubclassOfInt(int): @overrides def bit_length(self, _):