From 6e8458b2d69b345653c199a04b670feb486e15b1 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 10 Feb 2024 18:51:06 -0800 Subject: [PATCH] Add better type deduction to `use-set-discard` --- refurb/checks/builtin/set_discard.py | 33 +++++++++++++--------------- refurb/checks/common.py | 10 ++++++--- test/data/err_132.py | 14 ++++++++++++ test/data/err_132.txt | 3 ++- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/refurb/checks/builtin/set_discard.py b/refurb/checks/builtin/set_discard.py index 52a7d94..ec02448 100644 --- a/refurb/checks/builtin/set_discard.py +++ b/refurb/checks/builtin/set_discard.py @@ -1,17 +1,8 @@ from dataclasses import dataclass -from mypy.nodes import ( - Block, - CallExpr, - ComparisonExpr, - ExpressionStmt, - IfStmt, - MemberExpr, - NameExpr, - Var, -) +from mypy.nodes import Block, CallExpr, ComparisonExpr, ExpressionStmt, IfStmt, MemberExpr -from refurb.checks.common import is_equivalent, is_same_type +from refurb.checks.common import get_mypy_type, is_equivalent, is_same_type, stringify from refurb.error import Error @@ -41,7 +32,6 @@ class ErrorInfo(Error): name = "use-set-discard" code = 132 - msg: str = "Replace `if x in s: s.remove(x)` with `s.discard(x)`" categories = ("readability", "set") @@ -54,15 +44,22 @@ def check(node: IfStmt, errors: list[Error]) -> None: body=[ ExpressionStmt( expr=CallExpr( - callee=MemberExpr( - expr=NameExpr(node=Var(type=ty)) as expr, - name="remove", - ), + callee=MemberExpr(expr=expr, name="remove"), args=[arg], ) ) ] ) ], - ) if is_equivalent(lhs, arg) and is_equivalent(rhs, expr) and is_same_type(ty, set): - errors.append(ErrorInfo.from_node(node)) + else_body=None, + ) if ( + is_equivalent(lhs, arg) + and is_equivalent(rhs, expr) + and is_same_type(get_mypy_type(expr), set) + ): + old = stringify(node) + new = f"{stringify(expr)}.discard({stringify(arg)})" + + msg = f"Replace `{old}` with `{new}`" + + errors.append(ErrorInfo.from_node(node, msg)) diff --git a/refurb/checks/common.py b/refurb/checks/common.py index 7f109e1..13fadb6 100644 --- a/refurb/checks/common.py +++ b/refurb/checks/common.py @@ -15,6 +15,7 @@ DictExpr, DictionaryComprehension, Expression, + ExpressionStmt, FloatExpr, ForStmt, FuncDef, @@ -423,14 +424,17 @@ def _stringify(node: Node) -> str: case AssignmentStmt(lvalues=[lhs], rvalue=rhs): return f"{stringify(lhs)} = {stringify(rhs)}" - case IfStmt(expr=[expr], body=[Block(body=[body])], else_body=None): - return f"if {_stringify(expr)}: {_stringify(body)}" + case IfStmt(expr=[expr], body=[Block(body=[stmt])], else_body=None): + return f"if {_stringify(expr)}: {_stringify(stmt)}" case ConditionalExpr(if_expr=if_true, cond=cond, else_expr=if_false): return f"{_stringify(if_true)} if {_stringify(cond)} else {_stringify(if_false)}" case DelStmt(expr=expr): - return f"del {stringify(expr)}" + return f"del {_stringify(expr)}" + + case ExpressionStmt(expr=expr): + return _stringify(expr) raise ValueError diff --git a/test/data/err_132.py b/test/data/err_132.py index 25678e5..93843ca 100644 --- a/test/data/err_132.py +++ b/test/data/err_132.py @@ -5,6 +5,15 @@ if "x" in s: s.remove("x") + +class Wrapper: + s: set[int] + +w = Wrapper() + +if 0 in w.s: + w.s.remove(0) + # these should not if "x" in s: @@ -32,3 +41,8 @@ def __contains__(self, other) -> bool: if "x" in c: c.remove("x") + +if "x" in s: + s.remove("x") +else: + pass diff --git a/test/data/err_132.txt b/test/data/err_132.txt index 8d5d007..974801c 100644 --- a/test/data/err_132.txt +++ b/test/data/err_132.txt @@ -1 +1,2 @@ -test/data/err_132.py:5:1 [FURB132]: Replace `if x in s: s.remove(x)` with `s.discard(x)` +test/data/err_132.py:5:1 [FURB132]: Replace `if "x" in s: s.remove("x")` with `s.discard("x")` +test/data/err_132.py:14:1 [FURB132]: Replace `if 0 in w.s: w.s.remove(0)` with `w.s.discard(0)`