Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix][TIR] Fix version conflict with typing for different Python versions (3.8.0-3.10.0) #13820

Merged
merged 2 commits into from
Jan 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions python/tvm/tir/schedule/_type_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def _is_none_type(type_: Any) -> bool:
return type_ is None or type_ is type(None)


def _get_subtypes(type_: Any) -> Any:
# TODO(@tvm-team): This is hot fix to support subtle difference between python versions
# Would be nice to find a better way if possible
if hasattr(typing, "_SpecialGenericAlias"):
if hasattr(typing, "get_args"):
subtypes = typing.get_args(type_) # type: ignore
else:
subtypes = type_.__args__
else:
subtypes = type_.__args__
return subtypes


if hasattr(typing, "_GenericAlias"):
# For python versions 3.7 onward, check the __origin__ attribute.

Expand Down Expand Up @@ -64,10 +77,7 @@ def dict_(type_: Any) -> Any:
@staticmethod
def tuple_(type_: Any) -> Optional[List[type]]:
if _Subtype._origin(type_) is tuple:
if hasattr(typing, "get_args"):
subtypes = typing.get_args(type_) # type: ignore
else:
subtypes = type_.__args__
subtypes = _get_subtypes(type_)
return subtypes
return None

Expand All @@ -76,32 +86,23 @@ def optional( # pylint: disable=missing-function-docstring
type_: Any,
) -> Optional[List[type]]:
if _Subtype._origin(type_) is Union:
if hasattr(typing, "get_args"):
subtypes = typing.get_args(type_) # type: ignore
else:
subtypes = type_.__args__
subtypes = _get_subtypes(type_)
if len(subtypes) == 2 and _is_none_type(subtypes[1]):
return [subtypes[0]]
return None

@staticmethod
def union(type_: Any) -> Optional[List[type]]: # pylint: disable=missing-function-docstring
if _Subtype._origin(type_) is Union:
if hasattr(typing, "get_args"):
subtypes = typing.get_args(type_) # type: ignore
else:
subtypes = type_.__args__
subtypes = _get_subtypes(type_)
if len(subtypes) != 2 or not _is_none_type(subtypes[1]):
return list(subtypes)
return None

@staticmethod
def callable(type_: Any) -> Optional[List[type]]:
if _Subtype._origin(type_) is collections.abc.Callable:
if hasattr(typing, "get_args") and not type_._special:
subtypes = typing.get_args(type_) # type: ignore
else:
subtypes = type_.__args__
subtypes = _get_subtypes(type_)
return subtypes
return None

Expand Down