Skip to content

Commit

Permalink
[Bugfix][TIR] Fix version conflict with typing for different Python v…
Browse files Browse the repository at this point in the history
…ersions (3.8.0-3.10.0) (#13820)

* hotfix

* fix lint
  • Loading branch information
sunggg authored Jan 22, 2023
1 parent ac9fb98 commit cc7def0
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions python/tvm/tir/schedule/_type_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def _is_none_type(type_: Any) -> bool:
return type_ is None or type_ is type(None)


def _get_subtypes(type_: Any) -> Any:
# TODO(@tvm-team): This is hot fix to support subtle difference between python versions
# Would be nice to find a better way if possible
if hasattr(typing, "_SpecialGenericAlias"):
if hasattr(typing, "get_args"):
subtypes = typing.get_args(type_) # type: ignore
else:
subtypes = type_.__args__
else:
subtypes = type_.__args__
return subtypes


if hasattr(typing, "_GenericAlias"):
# For python versions 3.7 onward, check the __origin__ attribute.

Expand Down Expand Up @@ -64,10 +77,7 @@ def dict_(type_: Any) -> Any:
@staticmethod
def tuple_(type_: Any) -> Optional[List[type]]:
if _Subtype._origin(type_) is tuple:
if hasattr(typing, "get_args"):
subtypes = typing.get_args(type_) # type: ignore
else:
subtypes = type_.__args__
subtypes = _get_subtypes(type_)
return subtypes
return None

Expand All @@ -76,32 +86,23 @@ def optional( # pylint: disable=missing-function-docstring
type_: Any,
) -> Optional[List[type]]:
if _Subtype._origin(type_) is Union:
if hasattr(typing, "get_args"):
subtypes = typing.get_args(type_) # type: ignore
else:
subtypes = type_.__args__
subtypes = _get_subtypes(type_)
if len(subtypes) == 2 and _is_none_type(subtypes[1]):
return [subtypes[0]]
return None

@staticmethod
def union(type_: Any) -> Optional[List[type]]: # pylint: disable=missing-function-docstring
if _Subtype._origin(type_) is Union:
if hasattr(typing, "get_args"):
subtypes = typing.get_args(type_) # type: ignore
else:
subtypes = type_.__args__
subtypes = _get_subtypes(type_)
if len(subtypes) != 2 or not _is_none_type(subtypes[1]):
return list(subtypes)
return None

@staticmethod
def callable(type_: Any) -> Optional[List[type]]:
if _Subtype._origin(type_) is collections.abc.Callable:
if hasattr(typing, "get_args") and not type_._special:
subtypes = typing.get_args(type_) # type: ignore
else:
subtypes = type_.__args__
subtypes = _get_subtypes(type_)
return subtypes
return None

Expand Down

0 comments on commit cc7def0

Please sign in to comment.