Skip to content

Commit

Permalink
python 3.12 bytecode
Browse files Browse the repository at this point in the history
  • Loading branch information
mkorpela committed Jan 20, 2024
1 parent aa5b794 commit 444a7a5
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 27 deletions.
22 changes: 15 additions & 7 deletions overrides/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import functools
import inspect
import sys
from types import FunctionType
from types import FrameType, FunctionType
from typing import Callable, List, Optional, Tuple, TypeVar, Union, overload

__VERSION__ = "7.5.0"
Expand Down Expand Up @@ -150,7 +150,9 @@ def method(self):


def _overrides(
method: _WrappedMethod, check_signature: bool, check_at_runtime: bool,
method: _WrappedMethod,
check_signature: bool,
check_at_runtime: bool,
) -> _WrappedMethod:
setattr(method, "__override__", True)
global_vars = getattr(method, "__globals__", None)
Expand Down Expand Up @@ -196,7 +198,7 @@ def _get_base_classes(frame, namespace):
]


def _get_base_class_names(frame) -> List[List[str]]:
def _get_base_class_names(frame: FrameType) -> List[List[str]]:
"""Get baseclass names from the code object"""
extends: List[Tuple[str, str]] = []
add_last_step = True
Expand All @@ -208,15 +210,19 @@ def _get_base_class_names(frame) -> List[List[str]]:
if not add_last_step:
extends = []
add_last_step = True
if instruction.opname == "LOAD_NAME":

# Combine LOAD_NAME and LOAD_GLOBAL as they have similar functionality
if instruction.opname in ["LOAD_NAME", "LOAD_GLOBAL"]:
extends.append(("name", instruction.argval))
elif instruction.opname == "LOAD_ATTR":

elif instruction.opname == "LOAD_ATTR" and extends and extends[-1][0] == "name":
extends.append(("attr", instruction.argval))
elif instruction.opname == "LOAD_GLOBAL":
extends.append(("name", instruction.argval))

# Reset on other instructions
else:
add_last_step = False

# Extracting class names
items: List[List[str]] = []
previous_item: List[str] = []
for t, s in extends:
Expand All @@ -226,8 +232,10 @@ def _get_base_class_names(frame) -> List[List[str]]:
previous_item = [s]
else:
previous_item += [s]

if previous_item:
items.append(previous_item)

return items


Expand Down
44 changes: 25 additions & 19 deletions overrides/typing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,12 @@
}

if UnionType:

def is_union(element: object) -> bool:
return element is typing.Union or element is UnionType

else:

def is_union(element: object) -> bool:
return element is typing.Union

Expand Down Expand Up @@ -185,7 +187,7 @@ def get_args(type_) -> typing.Tuple:
res = _getter(type_)
elif hasattr(typing.List, "_special"): # python 3.7
if (
isinstance(type_, GenericClass) and not type_._special # type: ignore
isinstance(type_, GenericClass) and not type_._special # type: ignore
): # backport for python 3.8
res = type_.__args__ # type: ignore
if get_origin(type_) is collections.abc.Callable and res[0] is not Ellipsis:
Expand Down Expand Up @@ -283,9 +285,9 @@ def _is_origin_subtype(left: OriginType, right: OriginType) -> bool:
return True

if (
left is not None
and left in STATIC_SUBTYPE_MAPPING
and right == STATIC_SUBTYPE_MAPPING[left]
left is not None
and left in STATIC_SUBTYPE_MAPPING
and right == STATIC_SUBTYPE_MAPPING[left]
):
return True

Expand All @@ -301,14 +303,16 @@ def _is_origin_subtype(left: OriginType, right: OriginType) -> bool:


NormalizedTypeArgs = typing.Union[
typing.Tuple[typing.Any, ...], typing.FrozenSet[NormalizedType], NormalizedType,
typing.Tuple[typing.Any, ...],
typing.FrozenSet[NormalizedType],
NormalizedType,
]


def _is_origin_subtype_args(
left: "NormalizedTypeArgs",
right: "NormalizedTypeArgs",
forward_refs: typing.Optional[typing.Mapping[str, type]],
left: "NormalizedTypeArgs",
right: "NormalizedTypeArgs",
forward_refs: typing.Optional[typing.Mapping[str, type]],
) -> typing.Optional[bool]:
if isinstance(left, frozenset):
if not isinstance(right, frozenset):
Expand All @@ -325,18 +329,18 @@ def _is_origin_subtype_args(
)

if isinstance(left, collections.abc.Sequence) and not isinstance(
left, NormalizedType
left, NormalizedType
):
if not isinstance(right, collections.abc.Sequence) or isinstance(
right, NormalizedType
right, NormalizedType
):
return False

if (
left
and left[-1].origin is not Ellipsis
and right
and right[-1].origin is Ellipsis
left
and left[-1].origin is not Ellipsis
and right
and right[-1].origin is Ellipsis
):
# Tuple[type, type] <> Tuple[type, ...]
return all(_is_origin_subtype_args(l, right[0], forward_refs) for l in left)
Expand All @@ -358,9 +362,9 @@ def _is_origin_subtype_args(


def _is_normal_subtype(
left: NormalizedType,
right: NormalizedType,
forward_refs: typing.Optional[typing.Mapping[str, type]],
left: NormalizedType,
right: NormalizedType,
forward_refs: typing.Optional[typing.Mapping[str, type]],
) -> typing.Optional[bool]:
if isinstance(left.origin, ForwardRef):
left = normalize(eval_forward_ref(left.origin, forward_refs=forward_refs))
Expand Down Expand Up @@ -392,7 +396,7 @@ def _is_normal_subtype(

# TypeVar
if isinstance(left.origin, typing.TypeVar) and isinstance(
right.origin, typing.TypeVar
right.origin, typing.TypeVar
):
if left.origin is right.origin:
return True
Expand Down Expand Up @@ -425,7 +429,9 @@ def _is_normal_subtype(


def issubtype(
left: Type, right: Type, forward_refs: typing.Optional[dict] = None,
left: Type,
right: Type,
forward_refs: typing.Optional[dict] = None,
) -> typing.Optional[bool]:
"""Check that the left argument is a subtype of the right.
For unions, check if the type arguments of the left is a subset of the right.
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
author_email=address,
url="https://github.com/mkorpela/overrides",
packages=find_packages(),
package_data={"overrides": ["*.pyi", "py.typed"],},
package_data={
"overrides": ["*.pyi", "py.typed"],
},
include_package_data=True,
install_requires=['typing;python_version<"3.5"'],
python_requires=">=3.6",
Expand Down
1 change: 1 addition & 0 deletions tests/test_new_union__py3_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def f(self) -> int:
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires Python3.10 or higher")
def test_should_not_allow_increasing_type():
with pytest.raises(TypeError):

class C(A):
@override
def f(self) -> int | str | list[str]:
Expand Down

0 comments on commit 444a7a5

Please sign in to comment.