From 5c15349c03d73ea2b5f67ff9a42340c96b651194 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Fri, 16 Feb 2024 16:40:33 -0800 Subject: [PATCH] Add better type deduction for more expressions: * Deduce type of walrus operator (`:=`) expressions * Deduce type of awaited expressions * Deduce type of `cast()` functions * Deduce `set` literals as type `set` * Deduce type for unary operations * Deduce type of index (`x[0]`) expressions * Deduce return type when calling lambda expressions --- refurb/checks/common.py | 110 +++++++++++++----- .../checks/readability/no_unnecessary_cast.py | 1 + test/data/err_123.py | 4 + test/data/err_123.txt | 2 + test/data/type_deduce.py | 40 +++++++ test/data/type_deduce.txt | 12 ++ 6 files changed, 137 insertions(+), 32 deletions(-) create mode 100644 test/data/type_deduce.py create mode 100644 test/data/type_deduce.txt diff --git a/refurb/checks/common.py b/refurb/checks/common.py index 186d7e7..fbb196d 100644 --- a/refurb/checks/common.py +++ b/refurb/checks/common.py @@ -4,10 +4,13 @@ from mypy.nodes import ( ArgKind, + AssignmentExpr, AssignmentStmt, + AwaitExpr, Block, BytesExpr, CallExpr, + CastExpr, ComparisonExpr, ComplexExpr, ConditionalExpr, @@ -36,7 +39,7 @@ StarExpr, Statement, StrExpr, - SymbolTableNode, + SymbolNode, TupleExpr, TypeAlias, TypeInfo, @@ -445,6 +448,12 @@ def _stringify(node: Node) -> str: case ExpressionStmt(expr=expr): return _stringify(expr) + case AwaitExpr(expr=expr): + return f"await {_stringify(expr)}" + + case AssignmentExpr(target=lhs, value=rhs): + return f"{_stringify(lhs)} := {_stringify(rhs)}" + raise ValueError @@ -463,7 +472,7 @@ def slice_expr_to_slice_call(expr: SliceExpr) -> str: TypeLike = type | str | None | object -def is_same_type(ty: Type | TypeInfo | None, *expected: TypeLike) -> bool: +def is_same_type(ty: Type | SymbolNode | None, *expected: TypeLike) -> bool: """ Check if the type `ty` matches any of the `expected` types. `ty` must be a Mypy type object, but the expected types can be any of the following: @@ -498,7 +507,7 @@ def is_same_type(ty: Type | TypeInfo | None, *expected: TypeLike) -> bool: } -def _is_same_type(ty: Type | TypeInfo | None, expected: TypeLike) -> bool: +def _is_same_type(ty: Type | SymbolNode | None, expected: TypeLike) -> bool: if ty is expected is None: return True @@ -520,14 +529,17 @@ def _is_same_type(ty: Type | TypeInfo | None, expected: TypeLike) -> bool: return False -def _get_builtin_mypy_type(name: str) -> Type | None: +def _get_builtin_mypy_type(name: str) -> Instance | None: if (sym := types.BUILTINS_MYPY_FILE.names.get(name)) and isinstance(sym.node, TypeInfo): return Instance(sym.node, []) return None # pragma: no cover -def get_mypy_type(node: Node) -> Type | None: +def get_mypy_type(node: Node) -> Type | SymbolNode | None: + # forward declaration to make Mypy happy + ty: Type | SymbolNode | None + match node: case StrExpr(): return _get_builtin_mypy_type("str") @@ -556,52 +568,86 @@ def get_mypy_type(node: Node) -> Type | None: case TupleExpr(): return _get_builtin_mypy_type("tuple") - case Var(type=ty): + case SetExpr(): + return _get_builtin_mypy_type("set") + + case Var(type=ty) | FuncDef(type=ty): return ty - case NameExpr(node=sym): - match sym: - case Var(type=ty) | Instance(type=ty): # type: ignore - return ty + case TypeInfo() | TypeAlias() | MypyFile(): + return node - case TypeAlias(target=ty): + case NameExpr(node=sym) if sym: + return get_mypy_type(sym) + + case MemberExpr(expr=lhs, name=name): + ty = get_mypy_type(lhs) + + if ( + isinstance(ty, MypyFile | TypeInfo) + and (member := ty.names.get(name)) + and member.node + ): + return get_mypy_type(member.node) + + if isinstance(ty, Instance) and (member := ty.type.get(name)) and member.node: + return get_mypy_type(member.node) + + case CallExpr(analyzed=CastExpr(type=ty)): + return ty + + case CallExpr(callee=callee): + match get_mypy_type(callee): + case CallableType(ret_type=ty): return ty - case FuncDef(type=CallableType(ret_type=ty)): + case TypeAlias(target=ty): return ty - case TypeInfo(): + case TypeInfo() as sym: return Instance(sym, []) - case MemberExpr(expr=lhs, name=name): - # TODO: don't special case this - match lhs: - case NameExpr(node=MypyFile(names=names)): - match names.get(name): - case SymbolTableNode(node=FuncDef(type=CallableType(ret_type=ty))): - return ty + case UnaryExpr(op="not"): + return _get_builtin_mypy_type("bool") - case SymbolTableNode(node=TypeInfo() as ty): # type: ignore - return Instance(ty, []) # type: ignore + case UnaryExpr(method_type=CallableType(ret_type=ty)): + return ty - lhs_type = get_mypy_type(lhs) + case OpExpr(method_type=CallableType(ret_type=ty)): + return ty - if isinstance(lhs_type, Instance): - sym = lhs_type.type.get(name) # type: ignore + case IndexExpr(method_type=CallableType(ret_type=ty)): + return ty - if sym and sym.node: # type: ignore - return get_mypy_type(sym.node) # type: ignore + case AwaitExpr(expr=expr): + ty = get_mypy_type(expr) - case CallExpr(callee=callee): - return get_mypy_type(callee) + # TODO: allow for any Awaitable[T] type + match ty: + case Instance(type=TypeInfo(fullname="typing.Coroutine"), args=[_, _, rtype]): + return rtype - case OpExpr(method_type=CallableType(ret_type=ty)): - return ty + case Instance(type=TypeInfo(fullname="asyncio.tasks.Task"), args=[rtype]): + return rtype + + case LambdaExpr(body=Block(body=[ReturnStmt(expr=expr)])) if expr: + if (ty := get_mypy_type(expr)) and isinstance(ty, Type): + return _build_placeholder_callable(ty) + + case AssignmentExpr(target=expr): + return get_mypy_type(expr) return None -def mypy_type_to_python_type(ty: Type | None) -> type | None: +def _build_placeholder_callable(rtype: Type) -> Type | None: + if function := _get_builtin_mypy_type("function"): + return CallableType([], [], [], ret_type=rtype, fallback=function) + + return None # pragma: no cover + + +def mypy_type_to_python_type(ty: Type | SymbolNode | None) -> type | None: match ty: # TODO: return annotated types if instance has args (ie, `list[int]`) case Instance(type=TypeInfo(fullname=fullname)): diff --git a/refurb/checks/readability/no_unnecessary_cast.py b/refurb/checks/readability/no_unnecessary_cast.py index fc28568..36e697c 100644 --- a/refurb/checks/readability/no_unnecessary_cast.py +++ b/refurb/checks/readability/no_unnecessary_cast.py @@ -53,6 +53,7 @@ class ErrorInfo(Error): "builtins.float": (float, ""), "builtins.int": (int, ""), "builtins.list": (list, ".copy()"), + "builtins.set": (set, ""), "builtins.str": (str, ""), "builtins.tuple": (tuple, ""), } diff --git a/test/data/err_123.py b/test/data/err_123.py index 6b69e37..ff466e9 100644 --- a/test/data/err_123.py +++ b/test/data/err_123.py @@ -39,6 +39,10 @@ def func() -> bool: _ = bool(func()) +s = {1} +_ = set(s) +_ = set({1}) + # these will not diff --git a/test/data/err_123.txt b/test/data/err_123.txt index 9ee2788..40f56f5 100644 --- a/test/data/err_123.txt +++ b/test/data/err_123.txt @@ -16,3 +16,5 @@ test/data/err_123.py:29:5 [FURB123]: Replace `list(f)` with `f.copy()` test/data/err_123.py:32:5 [FURB123]: Replace `str(g)` with `g` test/data/err_123.py:35:5 [FURB123]: Replace `tuple(t)` with `t` test/data/err_123.py:40:5 [FURB123]: Replace `bool(func())` with `func()` +test/data/err_123.py:43:5 [FURB123]: Replace `set(s)` with `s` +test/data/err_123.py:44:5 [FURB123]: Replace `set({1})` with `{1}` diff --git a/test/data/type_deduce.py b/test/data/type_deduce.py new file mode 100644 index 0000000..64811ad --- /dev/null +++ b/test/data/type_deduce.py @@ -0,0 +1,40 @@ +# These are a variety of checks to ensure Refurb is able to deduce types from +# complex expressions. + +_ = bool([True][0]) + + +async def async_wrapper(): + import asyncio + + async def return_bool() -> bool: + return True + + task = asyncio.create_task(return_bool()) + + _ = bool(await return_bool()) + _ = bool(await task) + + +lambda_return_bool = lambda: True +_ = bool(lambda_return_bool()) +_ = bool((lambda: True)()) # TODO: error message should include parens around lambda + +bool_value = True + +_ = bool(not bool_value) +_ = bool(not False) + +_ = int(-1) +_ = int(+1) +_ = int(~1) + +_ = bool(walrus := True) + +from typing import cast + +_ = bool(cast(bool, 123)) + + +# These types are not able to be deduced (yet) +_ = int(1 or 2) diff --git a/test/data/type_deduce.txt b/test/data/type_deduce.txt new file mode 100644 index 0000000..2c26bb6 --- /dev/null +++ b/test/data/type_deduce.txt @@ -0,0 +1,12 @@ +test/data/type_deduce.py:4:5 [FURB123]: Replace `bool([True][0])` with `[True][0]` +test/data/type_deduce.py:15:9 [FURB123]: Replace `bool(await return_bool())` with `await return_bool()` +test/data/type_deduce.py:16:9 [FURB123]: Replace `bool(await task)` with `await task` +test/data/type_deduce.py:20:5 [FURB123]: Replace `bool(lambda_return_bool())` with `lambda_return_bool()` +test/data/type_deduce.py:21:5 [FURB123]: Replace `bool(lambda: True())` with `lambda: True()` +test/data/type_deduce.py:25:5 [FURB123]: Replace `bool(not bool_value)` with `not bool_value` +test/data/type_deduce.py:26:5 [FURB123]: Replace `bool(not False)` with `not False` +test/data/type_deduce.py:28:5 [FURB123]: Replace `int(-1)` with `-1` +test/data/type_deduce.py:29:5 [FURB123]: Replace `int(+1)` with `+1` +test/data/type_deduce.py:30:5 [FURB123]: Replace `int(~1)` with `~1` +test/data/type_deduce.py:32:5 [FURB123]: Replace `bool(walrus := True)` with `walrus := True` +test/data/type_deduce.py:36:5 [FURB123]: Replace `bool(cast(bool, 123))` with `cast(bool, 123)`