[red-knot] Infer subscript expression types for bytes literals #13901

Merged on Oct 24, 2024 (10 commits)
@@ -0,0 +1,10 @@
# Bytes literals

## Simple

```py
reveal_type(b"red" b"knot") # revealed: Literal[b"redknot"]
reveal_type(b"hello") # revealed: Literal[b"hello"]
reveal_type(b"world" + b"!") # revealed: Literal[b"world!"]
reveal_type(b"\xff\x00") # revealed: Literal[b"\xff\x00"]
```
@@ -3,8 +3,33 @@
## Simple

```py
reveal_type(b"red" b"knot") # revealed: Literal[b"redknot"]
reveal_type(b"hello") # revealed: Literal[b"hello"]
reveal_type(b"world" + b"!") # revealed: Literal[b"world!"]
reveal_type(b"\xff\x00") # revealed: Literal[b"\xff\x00"]
Comment on lines -6 to -9 (Contributor Author):

These were placed in the wrong file => moved to mdtest/literal/bytes.md.

b = b"\x00abc\xff"

reveal_type(b[0]) # revealed: Literal[b"\x00"]
reveal_type(b[1]) # revealed: Literal[b"a"]
reveal_type(b[4]) # revealed: Literal[b"\xff"]

reveal_type(b[-1]) # revealed: Literal[b"\xff"]
reveal_type(b[-2]) # revealed: Literal[b"c"]
reveal_type(b[-5]) # revealed: Literal[b"\x00"]

reveal_type(b[False]) # revealed: Literal[b"\x00"]
reveal_type(b[True]) # revealed: Literal[b"a"]

x = b[5] # error: [index-out-of-bounds] "Index 5 is out of bounds for bytes literal `Literal[b"\x00abc\xff"]` with length 5"
reveal_type(x) # revealed: Unknown

y = b[-6] # error: [index-out-of-bounds] "Index -6 is out of bounds for bytes literal `Literal[b"\x00abc\xff"]` with length 5"
reveal_type(y) # revealed: Unknown
```

## Function return

```py
def int_instance() -> int: ...


a = b"abcde"[int_instance()]
# TODO: Support overloads... Should be `bytes`
reveal_type(a) # revealed: @Todo
```
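The expectations in the bytes test above can be modelled outside the type checker. The following is a minimal Rust sketch, not code from this PR: `Index`, `normalize`, and `subscript_bytes` are hypothetical names. The result is a one-byte vector because that is what the tests reveal (for example `Literal[b"a"]`); this mirrors the PR's modelling and is not a claim about runtime `bytes.__getitem__`. The error text follows the format string used in the diff.

```rust
/// Hypothetical stand-in for the two literal index kinds the tests exercise.
#[derive(Clone, Copy)]
enum Index {
    Int(i64),
    Bool(bool),
}

/// Map a Python-style index (negative counts from the end) to a position
/// in a sequence of length `len`. `None` means out of bounds.
fn normalize(index: i64, len: usize) -> Option<usize> {
    let len = i64::try_from(len).ok()?;
    let shifted = if index < 0 { index.checked_add(len)? } else { index };
    if (0..len).contains(&shifted) {
        usize::try_from(shifted).ok()
    } else {
        None
    }
}

/// Return the selected byte as a one-element vector, or the diagnostic
/// message when the index is out of bounds. `display` is the rendered
/// type of the literal and is passed in as an assumption of this sketch.
fn subscript_bytes(literal: &[u8], display: &str, index: Index) -> Result<Vec<u8>, String> {
    // A boolean index is treated as the integer 0 or 1.
    let int = match index {
        Index::Int(int) => int,
        Index::Bool(flag) => i64::from(flag),
    };
    match normalize(int, literal.len()) {
        Some(position) => Ok(vec![literal[position]]),
        None => Err(format!(
            "Index {int} is out of bounds for bytes literal `{display}` with length {}",
            literal.len()
        )),
    }
}
```

With `b"\x00abc\xff"` as the literal, `Index::Bool(true)` takes the same path as `Index::Int(1)`, and `Index::Int(5)` or `Index::Int(-6)` produce the error branch, matching the `b[True]`, `b[5]`, and `b[-6]` cases in the test file.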
@@ -10,6 +10,9 @@ reveal_type(s[1]) # revealed: Literal["b"]
reveal_type(s[-1]) # revealed: Literal["e"]
reveal_type(s[-2]) # revealed: Literal["d"]

reveal_type(s[False]) # revealed: Literal["a"]
reveal_type(s[True]) # revealed: Literal["b"]

a = s[8] # error: [index-out-of-bounds] "Index 8 is out of bounds for string `Literal["abcde"]` with length 5"
reveal_type(a) # revealed: Unknown

@@ -20,11 +23,10 @@ reveal_type(b) # revealed: Unknown
## Function return

```py
def add(x: int, y: int) -> int:
return x + y
Comment on lines -23 to -24 (Contributor Author):

I am assuming `add` was intended to be a function that definitely returns an `int`, not a `LiteralInt`. Elsewhere, we just use `def int_instance() -> int: ...` for this. Let me know if I'm misunderstanding the intention.

def int_instance() -> int: ...


a = "abcde"[add(0, 1)]
a = "abcde"[int_instance()]
# TODO: Support overloads... Should be `str`
Contributor Author:

I understand that we want to infer `str` here. But I don't understand the "Support overloads" comment. Can someone explain?

Member @AlexWaygood, Oct 24, 2024:

We attempt to infer the type here by looking up `str.__getitem__` in typeshed and looking at what that method is annotated as returning. But `str.__getitem__` is an overloaded function in typeshed, and we don't yet have the ability to understand those.

In fact, we just infer `@Todo` as the return type for all decorated functions currently, since a decorated function (and its return type) is transformed by its decorator(s) before the user sees it. That means it's less simple than just "believe the return type annotation" when it comes to understanding the return type of decorated functions.

Contributor Author:

Thanks!

reveal_type(a) # revealed: @Todo
```
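The fallback described in the review thread above (decorated functions get `@Todo`, so an overloaded `str.__getitem__` does too) can be illustrated with a small Rust sketch. Everything here is hypothetical: `Ty`, `FunctionDef`, and `inferred_return_type` are simplified stand-ins invented for illustration and are not red-knot's real types or API.

```rust
/// Hypothetical, simplified type representation for this sketch only.
#[derive(Debug, Clone, PartialEq)]
enum Ty {
    /// An instance of a named class, e.g. `int` or `str`.
    Instance(&'static str),
    /// Placeholder for "inference not implemented yet".
    Todo,
    /// No information available (e.g. no annotation).
    Unknown,
}

/// Hypothetical function definition: decorator names plus the declared
/// return annotation, if any.
struct FunctionDef {
    decorators: Vec<&'static str>,
    return_annotation: Option<Ty>,
}

/// Believe the return annotation only for undecorated functions. A
/// decorator (such as `overload`) may transform the function, so the
/// annotation alone is not trusted and `Todo` is returned instead.
fn inferred_return_type(func: &FunctionDef) -> Ty {
    if !func.decorators.is_empty() {
        return Ty::Todo;
    }
    func.return_annotation.clone().unwrap_or(Ty::Unknown)
}
```

Under this model, an undecorated `def int_instance() -> int: ...` yields `Instance("int")`, while a method carrying an `overload` decorator yields `Todo`, which is why the test reveals `@Todo` for `"abcde"[int_instance()]`.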
126 changes: 57 additions & 69 deletions crates/red_knot_python_semantic/src/types/infer.rs
@@ -1466,8 +1466,9 @@ impl<'db> TypeInferenceBuilder<'db> {
}

/// Emit a diagnostic declaring that an index is out of bounds for a tuple.
pub(super) fn tuple_index_out_of_bounds_diagnostic(
pub(super) fn index_out_of_bounds_diagnostic(
&mut self,
kind: &'static str,
node: AnyNodeRef,
tuple_ty: Type<'db>,
length: usize,
@@ -1477,30 +1478,12 @@ impl<'db> TypeInferenceBuilder<'db> {
node,
"index-out-of-bounds",
format_args!(
"Index {index} is out of bounds for tuple of type `{}` with length {length}",
"Index {index} is out of bounds for {kind} `{}` with length {length}",
tuple_ty.display(self.db)
),
);
}

/// Emit a diagnostic declaring that an index is out of bounds for a string.
pub(super) fn string_index_out_of_bounds_diagnostic(
&mut self,
node: AnyNodeRef,
string_ty: Type<'db>,
length: usize,
index: i64,
) {
self.add_diagnostic(
node,
"index-out-of-bounds",
format_args!(
"Index {index} is out of bounds for string `{}` with length {length}",
string_ty.display(self.db)
),
);
}

/// Emit a diagnostic declaring that a type does not support subscripting.
pub(super) fn non_subscriptable_diagnostic(
&mut self,
@@ -3190,39 +3173,50 @@ impl<'db> TypeInferenceBuilder<'db> {
value_ty: Type<'db>,
slice_ty: Type<'db>,
) -> Type<'db> {
fn iterator_at_index<T>(
mut iter: impl DoubleEndedIterator<Item = T>,
index: i64,
) -> Option<T> {
if index < 0 {
let nth_rev = index
.checked_neg()
.and_then(|int| usize::try_from(int).ok())?
.checked_sub(1)?;

iter.rev().nth(nth_rev)
} else {
let nth = usize::try_from(index).ok()?;
iter.nth(nth)
}
}

fn slice_at_index<T>(slice: &[T], index: i64) -> Option<&T> {
let positive_index = if index < 0 {
slice.len().checked_sub(
index
.checked_neg()
.and_then(|int| usize::try_from(int).ok())?,
)
} else {
usize::try_from(index).ok()
};
slice.get(positive_index?)
}

match (value_ty, slice_ty) {
// Ex) Given `("a", "b", "c", "d")[1]`, return `"b"`
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) if int >= 0 => {
let elements = tuple_ty.elements(self.db);
usize::try_from(int)
.ok()
.and_then(|index| elements.get(index).copied())
.unwrap_or_else(|| {
self.tuple_index_out_of_bounds_diagnostic(
value_node.into(),
value_ty,
elements.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `("a", "b", "c", "d")[-1]`, return `"d"`
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) if int < 0 => {
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) => {
let elements = tuple_ty.elements(self.db);
int.checked_neg()
.and_then(|int| usize::try_from(int).ok())
.and_then(|index| elements.len().checked_sub(index))
.and_then(|index| elements.get(index).copied())
.unwrap_or_else(|| {
self.tuple_index_out_of_bounds_diagnostic(
value_node.into(),
value_ty,
elements.len(),
int,
);
Type::Unknown
})
slice_at_index(elements, int).copied().unwrap_or_else(|| {
self.index_out_of_bounds_diagnostic(
"tuple",
value_node.into(),
value_ty,
elements.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `("a", "b", "c", "d")[True]`, return `"b"`
(Type::Tuple(_), Type::BooleanLiteral(bool)) => self.infer_subscript_expression_types(
@@ -3231,19 +3225,18 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::IntLiteral(i64::from(bool)),
),
// Ex) Given `"value"[1]`, return `"a"`
(Type::StringLiteral(literal_ty), Type::IntLiteral(int)) if int >= 0 => {
(Type::StringLiteral(literal_ty), Type::IntLiteral(int)) => {
let literal_value = literal_ty.value(self.db);
usize::try_from(int)
.ok()
.and_then(|index| literal_value.chars().nth(index))
iterator_at_index(literal_value.chars(), int)
.map(|ch| {
Type::StringLiteral(StringLiteralType::new(
self.db,
ch.to_string().into_boxed_str(),
))
})
.unwrap_or_else(|| {
self.string_index_out_of_bounds_diagnostic(
self.index_out_of_bounds_diagnostic(
"string",
value_node.into(),
value_ty,
literal_value.chars().count(),
@@ -3252,31 +3245,26 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::Unknown
})
}
// Ex) Given `"value"[-1]`, return `"e"`
(Type::StringLiteral(literal_ty), Type::IntLiteral(int)) if int < 0 => {
// Ex) Given `b"value"[1]`, return `b"a"`
(Type::BytesLiteral(literal_ty), Type::IntLiteral(int)) => {
let literal_value = literal_ty.value(self.db);
int.checked_neg()
.and_then(|int| usize::try_from(int).ok())
.and_then(|index| index.checked_sub(1))
.and_then(|index| literal_value.chars().rev().nth(index))
.map(|ch| {
Type::StringLiteral(StringLiteralType::new(
self.db,
ch.to_string().into_boxed_str(),
))
slice_at_index(literal_value.as_ref(), int)
.map(|byte| {
Type::BytesLiteral(BytesLiteralType::new(self.db, [*byte].as_slice()))
})
.unwrap_or_else(|| {
self.string_index_out_of_bounds_diagnostic(
self.index_out_of_bounds_diagnostic(
"bytes literal",
value_node.into(),
value_ty,
literal_value.chars().count(),
literal_value.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `"value"[True]`, return `"a"`
(Type::StringLiteral(_), Type::BooleanLiteral(bool)) => self
(Type::StringLiteral(_) | Type::BytesLiteral(_), Type::BooleanLiteral(bool)) => self
.infer_subscript_expression_types(
value_node,
value_ty,
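For experimenting with the two indexing helpers outside the type checker, here they are as a standalone sketch, taken from the `infer.rs` hunk above with only comments added. The string arm uses the iterator variant because `str::chars()` is double-ended but not random-access, while tuple elements and bytes are plain slices.

```rust
/// Python-style indexing into a double-ended iterator.
/// Negative indices count from the back: -1 is the last item.
fn iterator_at_index<T>(
    mut iter: impl DoubleEndedIterator<Item = T>,
    index: i64,
) -> Option<T> {
    if index < 0 {
        // -1 maps to rev().nth(0), -2 to rev().nth(1), and so on.
        // `checked_neg` guards against overflow for `i64::MIN`.
        let nth_rev = index
            .checked_neg()
            .and_then(|int| usize::try_from(int).ok())?
            .checked_sub(1)?;

        iter.rev().nth(nth_rev)
    } else {
        let nth = usize::try_from(index).ok()?;
        iter.nth(nth)
    }
}

/// Python-style indexing into a slice.
/// A negative index is converted to `len - |index|`; `checked_sub`
/// yields `None` when that would go below zero.
fn slice_at_index<T>(slice: &[T], index: i64) -> Option<&T> {
    let positive_index = if index < 0 {
        slice.len().checked_sub(
            index
                .checked_neg()
                .and_then(|int| usize::try_from(int).ok())?,
        )
    } else {
        usize::try_from(index).ok()
    };
    slice.get(positive_index?)
}
```

Both helpers return `None` for out-of-range indices in either direction, which is the case where the match arms emit the `index-out-of-bounds` diagnostic and fall back to `Type::Unknown`.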