Skip to content

Commit

Permalink
[red-knot] Add diagnostic for invalid unpacking
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvmanila committed Dec 24, 2024
1 parent 2835d94 commit ff53cb1
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 21 deletions.
27 changes: 15 additions & 12 deletions crates/red_knot_python_semantic/resources/mdtest/unpacking.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ reveal_type(c) # revealed: Literal[4]
### Uneven unpacking (1)

```py
# TODO: Add diagnostic (there aren't enough values to unpack)
# error: "Not enough values to unpack (expected 3, got 2)"
(a, b, c) = (1, 2)
reveal_type(a) # revealed: Literal[1]
reveal_type(b) # revealed: Literal[2]
Expand All @@ -71,7 +71,7 @@ reveal_type(c) # revealed: Unknown
### Uneven unpacking (2)

```py
# TODO: Add diagnostic (too many values to unpack)
# error: "Too many values to unpack (expected 2, got 3)"
(a, b) = (1, 2, 3)
reveal_type(a) # revealed: Literal[1]
reveal_type(b) # revealed: Literal[2]
Expand All @@ -80,7 +80,7 @@ reveal_type(b) # revealed: Literal[2]
### Starred expression (1)

```py
# TODO: Add diagnostic (need more values to unpack)
# error: "Not enough values to unpack (expected 3 or more, got 2)"
[a, *b, c, d] = (1, 2)
reveal_type(a) # revealed: Literal[1]
# TODO: Should be list[Any] once support for assigning to starred expression is added
Expand Down Expand Up @@ -133,7 +133,7 @@ reveal_type(c) # revealed: @Todo(starred unpacking)
### Starred expression (6)

```py
# TODO: Add diagnostic (need more values to unpack)
# error: "Not enough values to unpack (expected 5 or more, got 1)"
(a, b, c, *d, e, f) = (1,)
reveal_type(a) # revealed: Literal[1]
reveal_type(b) # revealed: Unknown
Expand Down Expand Up @@ -199,7 +199,7 @@ reveal_type(b) # revealed: LiteralString
### Uneven unpacking (1)

```py
# TODO: Add diagnostic (there aren't enough values to unpack)
# error: "Not enough values to unpack (expected 3, got 2)"
a, b, c = "ab"
reveal_type(a) # revealed: LiteralString
reveal_type(b) # revealed: LiteralString
Expand All @@ -209,7 +209,7 @@ reveal_type(c) # revealed: Unknown
### Uneven unpacking (2)

```py
# TODO: Add diagnostic (too many values to unpack)
# error: "Too many values to unpack (expected 2, got 3)"
a, b = "abc"
reveal_type(a) # revealed: LiteralString
reveal_type(b) # revealed: LiteralString
Expand All @@ -218,7 +218,7 @@ reveal_type(b) # revealed: LiteralString
### Starred expression (1)

```py
# TODO: Add diagnostic (need more values to unpack)
# error: "Not enough values to unpack (expected 3 or more, got 2)"
(a, *b, c, d) = "ab"
reveal_type(a) # revealed: LiteralString
# TODO: Should be list[LiteralString] once support for assigning to starred expression is added
Expand Down Expand Up @@ -271,7 +271,7 @@ reveal_type(c) # revealed: @Todo(starred unpacking)
### Unicode

```py
# TODO: Add diagnostic (need more values to unpack)
# error: "Not enough values to unpack (expected 2, got 1)"
(a, b) = "é"

reveal_type(a) # revealed: LiteralString
Expand All @@ -281,7 +281,7 @@ reveal_type(b) # revealed: Unknown
### Unicode escape (1)

```py
# TODO: Add diagnostic (need more values to unpack)
# error: "Not enough values to unpack (expected 2, got 1)"
(a, b) = "\u9E6C"

reveal_type(a) # revealed: LiteralString
Expand All @@ -291,7 +291,7 @@ reveal_type(b) # revealed: Unknown
### Unicode escape (2)

```py
# TODO: Add diagnostic (need more values to unpack)
# error: "Not enough values to unpack (expected 2, got 1)"
(a, b) = "\U0010FFFF"

reveal_type(a) # revealed: LiteralString
Expand Down Expand Up @@ -383,7 +383,8 @@ def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]):

```py
def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]):
# TODO: Add diagnostic (too many values to unpack)
# error: "Too many values to unpack (expected 2, got 3)"
# error: "Too many values to unpack (expected 2, got 5)"
a, b = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: bytes | int
Expand All @@ -393,7 +394,8 @@ def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]):

```py
def _(arg: tuple[int, bytes] | tuple[int, str]):
# TODO: Add diagnostic (there aren't enough values to unpack)
# error: "Not enough values to unpack (expected 3, got 2)"
# error: "Not enough values to unpack (expected 3, got 2)"
a, b, c = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: bytes | str
Expand Down Expand Up @@ -536,6 +538,7 @@ for a, b in ((1, 2), ("a", "b")):
# error: "Object of type `Literal[1]` is not iterable"
# error: "Object of type `Literal[2]` is not iterable"
# error: "Object of type `Literal[4]` is not iterable"
# error: "Not enough values to unpack (expected 2, got 1)"
for a, b in (1, 2, (3, "a"), 4, (5, "b"), "c"):
reveal_type(a) # revealed: Unknown | Literal[3, 5] | LiteralString
reveal_type(b) # revealed: Unknown | Literal["a", "b"]
Expand Down
6 changes: 6 additions & 0 deletions crates/red_knot_python_semantic/src/types/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,12 @@ impl<'db> TypeArrayDisplay<'db> for Vec<Type<'db>> {
}
}

impl<'db> TypeArrayDisplay<'db> for [Type<'db>] {
fn display(&self, db: &'db dyn Db) -> DisplayTypeArray {
DisplayTypeArray { types: self, db }
}
}

pub(crate) struct DisplayTypeArray<'b, 'db> {
types: &'b [Type<'db>],
db: &'db dyn Db,
Expand Down
69 changes: 60 additions & 9 deletions crates/red_knot_python_semantic/src/types/unpacker.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::cmp::Ordering;

use rustc_hash::FxHashMap;

Expand All @@ -11,6 +12,7 @@ use crate::unpack::UnpackValue;
use crate::Db;

use super::context::{InferContext, WithDiagnostics};
use super::diagnostic::INVALID_ASSIGNMENT;
use super::{TupleType, UnionType};

/// Unpacks the value expression type to their respective targets.
Expand Down Expand Up @@ -104,9 +106,33 @@ impl<'db> Unpacker<'db> {
};

if let Some(tuple_ty) = ty.into_tuple() {
let tuple_ty_elements = self.tuple_ty_elements(elts, tuple_ty);
let tuple_ty_elements = self.tuple_ty_elements(target, elts, tuple_ty);

// TODO: Add diagnostic for length mismatch
match elts.len().cmp(&tuple_ty_elements.len()) {
Ordering::Less => {
self.context.report_lint(
&INVALID_ASSIGNMENT,
target.into(),
format_args!(
"Too many values to unpack (expected {}, got {})",
elts.len(),
tuple_ty_elements.len()
),
);
}
Ordering::Greater => {
self.context.report_lint(
&INVALID_ASSIGNMENT,
target.into(),
format_args!(
"Not enough values to unpack (expected {}, got {})",
elts.len(),
tuple_ty_elements.len()
),
);
}
Ordering::Equal => {}
}

for (index, ty) in tuple_ty_elements.iter().enumerate() {
if let Some(element_types) = target_types.get_mut(index) {
Expand Down Expand Up @@ -142,29 +168,40 @@ impl<'db> Unpacker<'db> {
/// Returns the [`Type`] elements inside the given [`TupleType`] taking into account that there
/// can be a starred expression in the `elements`.
fn tuple_ty_elements(
&mut self,
&self,
expr: &ast::Expr,
targets: &[ast::Expr],
tuple_ty: TupleType<'db>,
) -> Cow<'_, [Type<'db>]> {
// If there is a starred expression, it will consume all of the entries at that location.
// If there is a starred expression, it will consume all of the types at that location.
let Some(starred_index) = targets.iter().position(ast::Expr::is_starred_expr) else {
// Otherwise, the types will be unpacked 1-1 to the elements.
// Otherwise, the types will be unpacked 1-1 to the targets.
return Cow::Borrowed(tuple_ty.elements(self.db()).as_ref());
};

if tuple_ty.len(self.db()) >= targets.len() - 1 {
// This branch is only taken when there are enough elements in the tuple type to
// combine for the starred expression. So, the arithmetic and indexing operations are
// safe to perform.
let mut element_types = Vec::with_capacity(targets.len());

// Insert all the elements before the starred expression.
element_types.extend_from_slice(
// SAFETY: Safe because of the length check above.
&tuple_ty.elements(self.db())[..starred_index],
);

// E.g., in `(a, *b, c, d) = ...`, the index of starred element `b`
// is 1 and the remaining elements after that are 2.
// The number of target expressions that are remaining after the starred expression.
// For example, in `(a, *b, c, d) = ...`, the index of starred element `b` is 1 and the
// remaining elements after that are 2.
let remaining = targets.len() - (starred_index + 1);
// This index represents the type of the last element that belongs
// to the starred expression, in an exclusive manner.

// This index represents the position of the last element that belongs to the starred
// expression, in an exclusive manner. For example, in `(a, *b, c) = (1, 2, 3, 4)`, the
// starred expression `b` will consume the elements `Literal[2]` and `Literal[3]` and
// the index value would be 3.
let starred_end_index = tuple_ty.len(self.db()) - remaining;

// SAFETY: Safe because of the length check above.
let _starred_element_types =
&tuple_ty.elements(self.db())[starred_index..starred_end_index];
Expand All @@ -173,18 +210,32 @@ impl<'db> Unpacker<'db> {
// combine_types(starred_element_types);
element_types.push(todo_type!("starred unpacking"));

// Insert the types remaining that aren't consumed by the starred expression.
element_types.extend_from_slice(
// SAFETY: Safe because of the length check above.
&tuple_ty.elements(self.db())[starred_end_index..],
);

Cow::Owned(element_types)
} else {
self.context.report_lint(
&INVALID_ASSIGNMENT,
expr.into(),
format_args!(
"Not enough values to unpack (expected {} or more, got {})",
targets.len() - 1,
tuple_ty.len(self.db())
),
);

let mut element_types = tuple_ty.elements(self.db()).to_vec();

// Subtract 1 to insert the starred expression type at the correct
// index.
element_types.resize(targets.len() - 1, Type::Unknown);
// TODO: This should be `list[Unknown]`
element_types.insert(starred_index, todo_type!("starred unpacking"));

Cow::Owned(element_types)
}
}
Expand Down

0 comments on commit ff53cb1

Please sign in to comment.