Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flake8-pyi] Implement autofix for redundant-numeric-union (PYI041) #14273

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI041.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,38 @@ async def f4(**kwargs: int | int | float) -> None:
...


def f5(
arg: Union[ # comment
float, # another
complex, int]
) -> None:
...

def f6(
arg: (
int | # comment
float | # another
complex
)
) -> None:
...


class Foo:
def good(self, arg: int) -> None:
...

def bad(self, arg: int | float | complex) -> None:
...

def bad2(self, arg: int | Union[float, complex]) -> None:
...

def bad3(self, arg: Union[Union[float, complex], int]) -> None:
...

def bad4(self, arg: Union[float | complex, int]) -> None:
...

def bad5(self, arg: int | (float | complex)) -> None:
...
21 changes: 21 additions & 0 deletions crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI041.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,29 @@ def f3(arg1: int, *args: Union[int | int | float]) -> None: ... # PYI041

async def f4(**kwargs: int | int | float) -> None: ... # PYI041

def f5(
arg: Union[ # comment
float, # another
complex, int]
) -> None: ... # PYI041

def f6(
arg: (
int | # comment
float | # another
complex
)
) -> None: ... # PYI041

class Foo:
def good(self, arg: int) -> None: ...

def bad(self, arg: int | float | complex) -> None: ... # PYI041

def bad2(self, arg: int | Union[float, complex]) -> None: ... # PYI041

def bad3(self, arg: Union[Union[float, complex], int]) -> None: ... # PYI041

def bad4(self, arg: Union[float | complex, int]) -> None: ... # PYI041

def bad5(self, arg: int | (float | complex)) -> None: ... # PYI041
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::checkers::ast::Checker;
/// ## Fix safety
/// This rule's fix is marked as safe, unless the type annotation contains comments.
///
/// Note that the fix will flatten nested literals into a single top-level
/// literal.
/// Note that while the fix may flatten nested literals into a single top-level literal,
/// the semantics of the annotation will remain unchanged.
///
/// ## References
/// - [Python documentation: `typing.Literal`](https://docs.python.org/3/library/typing.html#typing.Literal)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
use ruff_diagnostics::{Diagnostic, Violation};
use bitflags::bitflags;

use anyhow::Result;

use ruff_diagnostics::{Applicability, Diagnostic, Edit, Fix, FixAvailability, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::{AnyParameterRef, Expr, Parameters};
use ruff_python_ast::{
name::Name, AnyParameterRef, Expr, ExprBinOp, ExprContext, ExprName, ExprSubscript, ExprTuple,
Operator, Parameters,
};
use ruff_python_semantic::analyze::typing::traverse_union;
use ruff_text_size::Ranged;
use ruff_text_size::{Ranged, TextRange};

use crate::checkers::ast::Checker;
use crate::{checkers::ast::Checker, importer::ImportRequest};

/// ## What it does
/// Checks for parameter annotations that contain redundant unions between
Expand Down Expand Up @@ -37,6 +44,12 @@ use crate::checkers::ast::Checker;
/// def foo(x: float | str) -> None: ...
/// ```
///
/// ## Fix safety
/// This rule's fix is marked as safe, unless the type annotation contains comments.
///
/// Note that while the fix may flatten nested unions into a single top-level union,
/// the semantics of the annotation will remain unchanged.
///
/// ## References
/// - [Python documentation: The numeric tower](https://docs.python.org/3/library/numbers.html#the-numeric-tower)
/// - [PEP 484: The numeric tower](https://peps.python.org/pep-0484/#the-numeric-tower)
Expand All @@ -48,15 +61,23 @@ pub struct RedundantNumericUnion {
}

impl Violation for RedundantNumericUnion {
// Always fixable, but currently under preview.
const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes;

#[derive_message_formats]
fn message(&self) -> String {
let (subtype, supertype) = match self.redundancy {
Redundancy::IntFloatComplex => ("int | float", "complex"),
Redundancy::FloatComplex => ("float", "complex"),
sbrugman marked this conversation as resolved.
Show resolved Hide resolved
Redundancy::IntComplex => ("int", "complex"),
Redundancy::IntFloat => ("int", "float"),
};
format!("Use `{supertype}` instead of `{subtype} | {supertype}`")
}

fn fix_title(&self) -> Option<String> {
Some("Remove redundant type".to_string())
}
}

/// PYI041
Expand All @@ -66,57 +87,210 @@ pub(crate) fn redundant_numeric_union(checker: &mut Checker, parameters: &Parame
}
}

#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Redundancy {
FloatComplex,
IntComplex,
IntFloat,
}

fn check_annotation(checker: &mut Checker, annotation: &Expr) {
let mut has_float = false;
let mut has_complex = false;
sbrugman marked this conversation as resolved.
Show resolved Hide resolved
let mut has_int = false;
fn check_annotation<'a>(checker: &mut Checker, annotation: &'a Expr) {
let mut numeric_flags = NumericFlags::empty();

let mut find_numeric_type = |expr: &Expr, _parent: &Expr| {
let Some(builtin_type) = checker.semantic().resolve_builtin_symbol(expr) else {
return;
};

match builtin_type {
"int" => has_int = true,
"float" => has_float = true,
"complex" => has_complex = true,
_ => {}
}
numeric_flags.seen_builtin_type(builtin_type);
};

// Traverse the union, and remember which numeric types are found.
traverse_union(&mut find_numeric_type, checker.semantic(), annotation);

sbrugman marked this conversation as resolved.
Show resolved Hide resolved
if has_complex {
if has_float {
checker.diagnostics.push(Diagnostic::new(
RedundantNumericUnion {
redundancy: Redundancy::FloatComplex,
},
annotation.range(),
));
let Some(redundancy) = Redundancy::from_numeric_flags(numeric_flags) else {
return;
};

// Traverse the union a second time to construct the fix.
let mut necessary_nodes: Vec<&Expr> = Vec::new();

let mut union_type = UnionKind::TypingUnion;
let mut remove_numeric_type = |expr: &'a Expr, parent: &'a Expr| {
let Some(builtin_type) = checker.semantic().resolve_builtin_symbol(expr) else {
// Keep type annotations that are not numeric.
necessary_nodes.push(expr);
return;
};

if matches!(parent, Expr::BinOp(_)) {
union_type = UnionKind::PEP604;
}

if has_int {
checker.diagnostics.push(Diagnostic::new(
RedundantNumericUnion {
redundancy: Redundancy::IntComplex,
},
annotation.range(),
));
// `int` is always dropped, since `float` or `complex` must be present.
// `float` is only dropped if `complex`` is present.
if (builtin_type == "float" && !numeric_flags.contains(NumericFlags::COMPLEX))
|| (builtin_type != "float" && builtin_type != "int")
{
sbrugman marked this conversation as resolved.
Show resolved Hide resolved
necessary_nodes.push(expr);
}
};

// Traverse the union a second time to construct a [`Fix`].
traverse_union(&mut remove_numeric_type, checker.semantic(), annotation);

let mut diagnostic = Diagnostic::new(RedundantNumericUnion { redundancy }, annotation.range());
if checker.settings.preview.is_enabled() {
// Mark [`Fix`] as unsafe when comments are in range.
let applicability = if checker.comment_ranges().intersects(annotation.range()) {
Applicability::Unsafe
} else {
Applicability::Safe
};

// Generate the flattened fix once.
let fix = if let &[edit_expr] = necessary_nodes.as_slice() {
// Generate a [`Fix`] for a single type expression, e.g. `int`.
Fix::applicable_edit(
Edit::range_replacement(checker.generator().expr(edit_expr), annotation.range()),
applicability,
MichaReiser marked this conversation as resolved.
Show resolved Hide resolved
)
} else {
match union_type {
UnionKind::PEP604 => {
generate_pep604_fix(checker, necessary_nodes, annotation, applicability)
}
UnionKind::TypingUnion => {
generate_union_fix(checker, necessary_nodes, annotation, applicability)
.ok()
.unwrap()
}
}
};
diagnostic.set_fix(fix);
};

checker.diagnostics.push(diagnostic);
}

#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Redundancy {
IntFloatComplex,
FloatComplex,
IntComplex,
IntFloat,
}

impl Redundancy {
pub(super) fn from_numeric_flags(numeric_flags: NumericFlags) -> Option<Self> {
if numeric_flags == NumericFlags::INT | NumericFlags::FLOAT | NumericFlags::COMPLEX {
Some(Self::IntFloatComplex)
} else if numeric_flags == NumericFlags::FLOAT | NumericFlags::COMPLEX {
Some(Self::FloatComplex)
} else if numeric_flags == NumericFlags::INT | NumericFlags::COMPLEX {
Some(Self::IntComplex)
} else if numeric_flags == NumericFlags::FLOAT | NumericFlags::INT {
Some(Self::IntFloat)
} else {
None
}
} else if has_float && has_int {
checker.diagnostics.push(Diagnostic::new(
RedundantNumericUnion {
redundancy: Redundancy::IntFloat,
},
annotation.range(),
));
}
}

bitflags! {
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) struct NumericFlags: u8 {
/// `int`
const INT = 1 << 0;
/// `float`
const FLOAT = 1 << 1;
/// `complex`
const COMPLEX = 1 << 2;
}
}

impl NumericFlags {
pub(super) fn seen_builtin_type(&mut self, name: &str) {
let flag: NumericFlags = match name {
"int" => NumericFlags::INT,
"float" => NumericFlags::FLOAT,
"complex" => NumericFlags::COMPLEX,
_ => {
return;
}
};
self.insert(flag);
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UnionKind {
/// E.g., `typing.Union[int, str]`
TypingUnion,
/// E.g., `int | str`
PEP604,
}

// Generate a [`Fix`] for two or more type expressions, e.g. `int | float | complex`.
fn generate_pep604_fix(
checker: &Checker,
nodes: Vec<&Expr>,
annotation: &Expr,
applicability: Applicability,
) -> Fix {
debug_assert!(nodes.len() >= 2, "At least two nodes required");

sbrugman marked this conversation as resolved.
Show resolved Hide resolved
let new_expr = nodes
.into_iter()
.fold(None, |acc: Option<Expr>, right: &Expr| {
if let Some(left) = acc {
Some(Expr::BinOp(ExprBinOp {
left: Box::new(left),
op: Operator::BitOr,
right: Box::new(right.clone()),
range: TextRange::default(),
}))
} else {
Some(right.clone())
}
})
.unwrap();

Fix::applicable_edit(
Edit::range_replacement(checker.generator().expr(&new_expr), annotation.range()),
applicability,
)
}

// Generate a [`Fix`] for two or more type expresisons, e.g. `typing.Union[int, float, complex]`.
fn generate_union_fix(
checker: &Checker,
nodes: Vec<&Expr>,
annotation: &Expr,
applicability: Applicability,
) -> Result<Fix> {
debug_assert!(nodes.len() >= 2, "At least two nodes required");

// Request `typing.Union`
let (import_edit, binding) = checker.importer().get_or_import_symbol(
&ImportRequest::import_from("typing", "Union"),
annotation.start(),
checker.semantic(),
)?;

// Construct the expression as `Subscript[typing.Union, Tuple[expr, [expr, ...]]]`
let new_expr = Expr::Subscript(ExprSubscript {
range: TextRange::default(),
value: Box::new(Expr::Name(ExprName {
id: Name::new(binding),
ctx: ExprContext::Store,
range: TextRange::default(),
})),
slice: Box::new(Expr::Tuple(ExprTuple {
elts: nodes.into_iter().cloned().collect(),
range: TextRange::default(),
ctx: ExprContext::Load,
parenthesized: false,
})),
ctx: ExprContext::Load,
});

Ok(Fix::applicable_edits(
Edit::range_replacement(checker.generator().expr(&new_expr), annotation.range()),
[import_edit],
applicability,
))
}
Loading
Loading