diff --git a/crates/ruff/resources/test/fixtures/flake8_pyi/PYI030.pyi b/crates/ruff/resources/test/fixtures/flake8_pyi/PYI030.pyi index 875e76477d178..e54264d331add 100644 --- a/crates/ruff/resources/test/fixtures/flake8_pyi/PYI030.pyi +++ b/crates/ruff/resources/test/fixtures/flake8_pyi/PYI030.pyi @@ -70,8 +70,14 @@ field20: typing.Union[ # Should handle multiple unions with multiple members field21: Literal[1, 2] | Literal[3, 4] # Error -# Should emit in cases with `typing.Union`` instead of `|` +# Should emit in cases with `typing.Union` instead of `|` field22: typing.Union[Literal[1], Literal[2]] # Error # Should emit in cases with `typing_extensions.Literal` field23: typing_extensions.Literal[1] | typing_extensions.Literal[2] # Error + +# Should emit in cases with nested `typing.Union` +field24: typing.Union[Literal[1], typing.Union[Literal[2], str]] # Error + +# Should emit in cases with mixed `typing.Union` and `|` +field24: typing.Union[Literal[1], Literal[2] | str] # Error diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index 30fad458626a3..3468ac61eed05 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -2189,6 +2189,11 @@ where } } + // Ex) Union[...] + if self.enabled(Rule::UnnecessaryLiteralUnion) { + flake8_pyi::rules::unnecessary_literal_union(self, expr); + } + if self.semantic.match_typing_expr(value, "Literal") { self.semantic.flags |= SemanticModelFlags::LITERAL; } diff --git a/crates/ruff/src/rules/flake8_pyi/rules/unnecessary_literal_union.rs b/crates/ruff/src/rules/flake8_pyi/rules/unnecessary_literal_union.rs index cba7648de3781..1aceb1ba22f32 100644 --- a/crates/ruff/src/rules/flake8_pyi/rules/unnecessary_literal_union.rs +++ b/crates/ruff/src/rules/flake8_pyi/rules/unnecessary_literal_union.rs @@ -10,7 +10,8 @@ use crate::checkers::ast::Checker; /// Checks for the presence of multiple literal types in a union. /// /// ## Why is this bad? -/// Literal types accept multiple arguments and it is clearer to specify them as a single literal. +/// Literal types accept multiple arguments and it is clearer to specify them +/// as a single literal. /// /// ## Example /// ```python @@ -37,15 +38,26 @@ impl Violation for UnnecessaryLiteralUnion { } /// PYI030 -pub(crate) fn unnecessary_literal_union(checker: &mut Checker, expr: &Expr) { - let mut literal_members = Vec::new(); - collect_literal_members(&mut literal_members, checker.semantic(), expr); +pub(crate) fn unnecessary_literal_union<'a>(checker: &mut Checker, expr: &'a Expr) { + let mut literal_exprs = Vec::new(); + + // Adds a member to `literal_exprs` if it is a `Literal` annotation + let mut collect_literal_expr = |expr: &'a Expr| { + if let Expr::Subscript(ast::ExprSubscript { value, slice, .. }) = expr { + if checker.semantic().match_typing_expr(&*value, "Literal") { + literal_exprs.push(slice); + } + } + }; + + // Traverse the union, collect all literal members + traverse_union(&mut collect_literal_expr, expr, checker.semantic()); // Raise a violation if more than one - if literal_members.len() > 1 { + if literal_exprs.len() > 1 { let diagnostic = Diagnostic::new( UnnecessaryLiteralUnion { - members: literal_members + members: literal_exprs .into_iter() .map(|m| checker.locator.slice(m.range()).to_string()) .collect(), @@ -57,20 +69,12 @@ pub(crate) fn unnecessary_literal_union(checker: &mut Checker, expr: &Expr) { } } -/// Collect literal expressions from a union. -fn collect_literal_members<'a>( - literal_members: &mut Vec<&'a Expr>, - model: &SemanticModel, - expr: &'a Expr, -) { - // The union data structure usually looks like this: - // a | b | c -> (a | b) | c - // - // However, parenthesized expressions can coerce it into any structure: - // a | (b | c) - // - // So we have to traverse both branches in order (left, then right), to report members - // in the order they appear in the source code. +/// Traverse a "union" type annotation, calling `func` on each expression in the union. +fn traverse_union<'a, F>(func: &mut F, expr: &'a Expr, semantic: &SemanticModel) +where + F: FnMut(&'a Expr), +{ + // Ex) x | y if let Expr::BinOp(ast::ExprBinOp { op: Operator::BitOr, left, @@ -78,15 +82,34 @@ fn collect_literal_members<'a>( range: _, }) = expr { - // Traverse left subtree, then the right subtree, propagating the previous node. - collect_literal_members(literal_members, model, left); - collect_literal_members(literal_members, model, right); + // The union data structure usually looks like this: + // a | b | c -> (a | b) | c + // + // However, parenthesized expressions can coerce it into any structure: + // a | (b | c) + // + // So we have to traverse both branches in order (left, then right), to report members + // in the order they appear in the source code. + + // Traverse the left then right arms + traverse_union(func, left, semantic); + traverse_union(func, right, semantic); + return; } - // If it's a literal expression add it to the members + // Ex) Union[x, y] if let Expr::Subscript(ast::ExprSubscript { value, slice, .. }) = expr { - if model.match_typing_expr(value, "Literal") { - literal_members.push(slice); + if semantic.match_typing_expr(value, "Union") { + if let Expr::Tuple(ast::ExprTuple { elts, .. }) = slice.as_ref() { + // Traverse each element of the tuple within the union recursively to handle cases + // such as `Union[..., Union[...]] + elts.iter() + .for_each(|elt| traverse_union(func, elt, semantic)); + return; + } } } + + // Otherwise, call the function on expression + func(expr) } diff --git a/crates/ruff/src/rules/flake8_pyi/snapshots/ruff__rules__flake8_pyi__tests__PYI030_PYI030.pyi.snap b/crates/ruff/src/rules/flake8_pyi/snapshots/ruff__rules__flake8_pyi__tests__PYI030_PYI030.pyi.snap index 1485ceaefaf35..74eb4c0cadc6a 100644 --- a/crates/ruff/src/rules/flake8_pyi/snapshots/ruff__rules__flake8_pyi__tests__PYI030_PYI030.pyi.snap +++ b/crates/ruff/src/rules/flake8_pyi/snapshots/ruff__rules__flake8_pyi__tests__PYI030_PYI030.pyi.snap @@ -237,13 +237,37 @@ PYI030.pyi:60:10: PYI030 Multiple literal members in a union. Use a single liter 62 | # Should emit in cases with newlines | +PYI030.pyi:63:10: PYI030 Multiple literal members in a union. Use a single literal, e.g. `Literal[1, 2]` + | +62 | # Should emit in cases with newlines +63 | field20: typing.Union[ + | __________^ +64 | | Literal[ +65 | | 1 # test +66 | | ], +67 | | Literal[2], +68 | | ] # Error, newline and comment will not be emitted in message + | |_^ PYI030 +69 | +70 | # Should handle multiple unions with multiple members + | + PYI030.pyi:71:10: PYI030 Multiple literal members in a union. Use a single literal, e.g. `Literal[1, 2, 3, 4]` | 70 | # Should handle multiple unions with multiple members 71 | field21: Literal[1, 2] | Literal[3, 4] # Error | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI030 72 | -73 | # Should emit in cases with `typing.Union`` instead of `|` +73 | # Should emit in cases with `typing.Union` instead of `|` + | + +PYI030.pyi:74:10: PYI030 Multiple literal members in a union. Use a single literal, e.g. `Literal[1, 2]` + | +73 | # Should emit in cases with `typing.Union` instead of `|` +74 | field22: typing.Union[Literal[1], Literal[2]] # Error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI030 +75 | +76 | # Should emit in cases with `typing_extensions.Literal` | PYI030.pyi:77:10: PYI030 Multiple literal members in a union. Use a single literal, e.g. `Literal[1, 2]` @@ -251,6 +275,24 @@ PYI030.pyi:77:10: PYI030 Multiple literal members in a union. Use a single liter 76 | # Should emit in cases with `typing_extensions.Literal` 77 | field23: typing_extensions.Literal[1] | typing_extensions.Literal[2] # Error | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI030 +78 | +79 | # Should emit in cases with nested `typing.Union` + | + +PYI030.pyi:80:10: PYI030 Multiple literal members in a union. Use a single literal, e.g. `Literal[1, 2]` + | +79 | # Should emit in cases with nested `typing.Union` +80 | field24: typing.Union[Literal[1], typing.Union[Literal[2], str]] # Error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI030 +81 | +82 | # Should emit in cases with mixed `typing.Union` and `|` + | + +PYI030.pyi:83:10: PYI030 Multiple literal members in a union. Use a single literal, e.g. `Literal[1, 2]` + | +82 | # Should emit in cases with mixed `typing.Union` and `|` +83 | field24: typing.Union[Literal[1], Literal[2] | str] # Error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI030 |