Skip to content

Commit

Permalink
[red-knot] Infer subscript expression types for bytes literals (#13901)
Browse files Browse the repository at this point in the history
## Summary

Infer subscript expression types for bytes literals:
```py
b = b"\x00abc\xff"

reveal_type(b[0])  # revealed: Literal[b"\x00"]
reveal_type(b[1])  # revealed: Literal[b"a"]
reveal_type(b[-1])  # revealed: Literal[b"\xff"]
reveal_type(b[-2])  # revealed: Literal[b"c"]

reveal_type(b[False])  # revealed: Literal[b"\x00"]
reveal_type(b[True])  # revealed: Literal[b"a"]
```


part of #13689
(#13689 (comment))

## Test Plan

- New Markdown-based tests (see `mdtest/subscript/bytes.md`)
- Added missing test for `string_literal[bool_literal]`
  • Loading branch information
sharkdp authored Oct 24, 2024
1 parent 73ee72b commit 77ae0cc
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 68 deletions.
10 changes: 10 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/literal/bytes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Bytes literals

## Simple

```py
reveal_type(b"red" b"knot") # revealed: Literal[b"redknot"]
reveal_type(b"hello") # revealed: Literal[b"hello"]
reveal_type(b"world" + b"!") # revealed: Literal[b"world!"]
reveal_type(b"\xff\x00") # revealed: Literal[b"\xff\x00"]
```
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,33 @@
## Simple

```py
reveal_type(b"red" b"knot") # revealed: Literal[b"redknot"]
reveal_type(b"hello") # revealed: Literal[b"hello"]
reveal_type(b"world" + b"!") # revealed: Literal[b"world!"]
reveal_type(b"\xff\x00") # revealed: Literal[b"\xff\x00"]
b = b"\x00abc\xff"

reveal_type(b[0]) # revealed: Literal[b"\x00"]
reveal_type(b[1]) # revealed: Literal[b"a"]
reveal_type(b[4]) # revealed: Literal[b"\xff"]

reveal_type(b[-1]) # revealed: Literal[b"\xff"]
reveal_type(b[-2]) # revealed: Literal[b"c"]
reveal_type(b[-5]) # revealed: Literal[b"\x00"]

reveal_type(b[False]) # revealed: Literal[b"\x00"]
reveal_type(b[True]) # revealed: Literal[b"a"]

x = b[5] # error: [index-out-of-bounds] "Index 5 is out of bounds for bytes literal `Literal[b"\x00abc\xff"]` with length 5"
reveal_type(x) # revealed: Unknown

y = b[-6] # error: [index-out-of-bounds] "Index -6 is out of bounds for bytes literal `Literal[b"\x00abc\xff"]` with length 5"
reveal_type(y) # revealed: Unknown
```

## Function return

```py
def int_instance() -> int: ...


a = b"abcde"[int_instance()]
# TODO: Support overloads... Should be `bytes`
reveal_type(a) # revealed: @Todo
```
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ reveal_type(s[1]) # revealed: Literal["b"]
reveal_type(s[-1]) # revealed: Literal["e"]
reveal_type(s[-2]) # revealed: Literal["d"]

reveal_type(s[False]) # revealed: Literal["a"]
reveal_type(s[True]) # revealed: Literal["b"]

a = s[8] # error: [index-out-of-bounds] "Index 8 is out of bounds for string `Literal["abcde"]` with length 5"
reveal_type(a) # revealed: Unknown

Expand All @@ -20,11 +23,10 @@ reveal_type(b) # revealed: Unknown
## Function return

```py
def add(x: int, y: int) -> int:
return x + y
def int_instance() -> int: ...


a = "abcde"[add(0, 1)]
a = "abcde"[int_instance()]
# TODO: Support overloads... Should be `str`
reveal_type(a) # revealed: @Todo
```
1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ mod semantic_model;
pub(crate) mod site_packages;
mod stdlib;
pub mod types;
mod util;

type FxOrderSet<V> = ordermap::set::OrderSet<V, BuildHasherDefault<FxHasher>>;
89 changes: 28 additions & 61 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use crate::types::{
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionType, KnownFunction,
StringLiteralType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
};
use crate::util::subscript::PythonSubscript;
use crate::Db;

use super::{KnownClass, UnionBuilder};
Expand Down Expand Up @@ -1466,8 +1467,9 @@ impl<'db> TypeInferenceBuilder<'db> {
}

/// Emit a diagnostic declaring that an index is out of bounds for a tuple.
pub(super) fn tuple_index_out_of_bounds_diagnostic(
pub(super) fn index_out_of_bounds_diagnostic(
&mut self,
kind: &'static str,
node: AnyNodeRef,
tuple_ty: Type<'db>,
length: usize,
Expand All @@ -1477,30 +1479,12 @@ impl<'db> TypeInferenceBuilder<'db> {
node,
"index-out-of-bounds",
format_args!(
"Index {index} is out of bounds for tuple of type `{}` with length {length}",
"Index {index} is out of bounds for {kind} `{}` with length {length}",
tuple_ty.display(self.db)
),
);
}

/// Emit a diagnostic declaring that an index is out of bounds for a string.
pub(super) fn string_index_out_of_bounds_diagnostic(
&mut self,
node: AnyNodeRef,
string_ty: Type<'db>,
length: usize,
index: i64,
) {
self.add_diagnostic(
node,
"index-out-of-bounds",
format_args!(
"Index {index} is out of bounds for string `{}` with length {length}",
string_ty.display(self.db)
),
);
}

/// Emit a diagnostic declaring that a type does not support subscripting.
pub(super) fn non_subscriptable_diagnostic(
&mut self,
Expand Down Expand Up @@ -3192,30 +3176,15 @@ impl<'db> TypeInferenceBuilder<'db> {
) -> Type<'db> {
match (value_ty, slice_ty) {
// Ex) Given `("a", "b", "c", "d")[1]`, return `"b"`
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) if int >= 0 => {
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) => {
let elements = tuple_ty.elements(self.db);
usize::try_from(int)
.ok()
.and_then(|index| elements.get(index).copied())
.unwrap_or_else(|| {
self.tuple_index_out_of_bounds_diagnostic(
value_node.into(),
value_ty,
elements.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `("a", "b", "c", "d")[-1]`, return `"c"`
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) if int < 0 => {
let elements = tuple_ty.elements(self.db);
int.checked_neg()
.and_then(|int| usize::try_from(int).ok())
.and_then(|index| elements.len().checked_sub(index))
.and_then(|index| elements.get(index).copied())
elements
.iter()
.python_subscript(int)
.copied()
.unwrap_or_else(|| {
self.tuple_index_out_of_bounds_diagnostic(
self.index_out_of_bounds_diagnostic(
"tuple",
value_node.into(),
value_ty,
elements.len(),
Expand All @@ -3231,19 +3200,20 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::IntLiteral(i64::from(bool)),
),
// Ex) Given `"value"[1]`, return `"a"`
(Type::StringLiteral(literal_ty), Type::IntLiteral(int)) if int >= 0 => {
(Type::StringLiteral(literal_ty), Type::IntLiteral(int)) => {
let literal_value = literal_ty.value(self.db);
usize::try_from(int)
.ok()
.and_then(|index| literal_value.chars().nth(index))
literal_value
.chars()
.python_subscript(int)
.map(|ch| {
Type::StringLiteral(StringLiteralType::new(
self.db,
ch.to_string().into_boxed_str(),
))
})
.unwrap_or_else(|| {
self.string_index_out_of_bounds_diagnostic(
self.index_out_of_bounds_diagnostic(
"string",
value_node.into(),
value_ty,
literal_value.chars().count(),
Expand All @@ -3252,31 +3222,28 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::Unknown
})
}
// Ex) Given `"value"[-1]`, return `"e"`
(Type::StringLiteral(literal_ty), Type::IntLiteral(int)) if int < 0 => {
// Ex) Given `b"value"[1]`, return `b"a"`
(Type::BytesLiteral(literal_ty), Type::IntLiteral(int)) => {
let literal_value = literal_ty.value(self.db);
int.checked_neg()
.and_then(|int| usize::try_from(int).ok())
.and_then(|index| index.checked_sub(1))
.and_then(|index| literal_value.chars().rev().nth(index))
.map(|ch| {
Type::StringLiteral(StringLiteralType::new(
self.db,
ch.to_string().into_boxed_str(),
))
literal_value
.iter()
.python_subscript(int)
.map(|byte| {
Type::BytesLiteral(BytesLiteralType::new(self.db, [*byte].as_slice()))
})
.unwrap_or_else(|| {
self.string_index_out_of_bounds_diagnostic(
self.index_out_of_bounds_diagnostic(
"bytes literal",
value_node.into(),
value_ty,
literal_value.chars().count(),
literal_value.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `"value"[True]`, return `"a"`
(Type::StringLiteral(_), Type::BooleanLiteral(bool)) => self
(Type::StringLiteral(_) | Type::BytesLiteral(_), Type::BooleanLiteral(bool)) => self
.infer_subscript_expression_types(
value_node,
value_ty,
Expand Down
1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/src/util/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(crate) mod subscript;
82 changes: 82 additions & 0 deletions crates/red_knot_python_semantic/src/util/subscript.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
pub(crate) trait PythonSubscript {
type Item;

fn python_subscript(&mut self, index: i64) -> Option<Self::Item>;
}

impl<I, T: DoubleEndedIterator<Item = I>> PythonSubscript for T {
type Item = I;

fn python_subscript(&mut self, index: i64) -> Option<I> {
if index >= 0 {
self.nth(usize::try_from(index).ok()?)
} else {
let nth_rev = usize::try_from(index.checked_neg()?).ok()?.checked_sub(1)?;
self.rev().nth(nth_rev)
}
}
}

#[cfg(test)]
mod tests {
use super::PythonSubscript;

#[test]
fn python_subscript_basic() {
let iter = 'a'..='e';

assert_eq!(iter.clone().python_subscript(0), Some('a'));
assert_eq!(iter.clone().python_subscript(1), Some('b'));
assert_eq!(iter.clone().python_subscript(4), Some('e'));
assert_eq!(iter.clone().python_subscript(5), None);

assert_eq!(iter.clone().python_subscript(-1), Some('e'));
assert_eq!(iter.clone().python_subscript(-2), Some('d'));
assert_eq!(iter.clone().python_subscript(-5), Some('a'));
assert_eq!(iter.clone().python_subscript(-6), None);
}

#[test]
fn python_subscript_empty() {
let iter = 'a'..'a';

assert_eq!(iter.clone().python_subscript(0), None);
assert_eq!(iter.clone().python_subscript(1), None);
assert_eq!(iter.clone().python_subscript(-1), None);
}

#[test]
fn python_subscript_single_element() {
let iter = 'a'..='a';

assert_eq!(iter.clone().python_subscript(0), Some('a'));
assert_eq!(iter.clone().python_subscript(1), None);
assert_eq!(iter.clone().python_subscript(-1), Some('a'));
assert_eq!(iter.clone().python_subscript(-2), None);
}

#[test]
fn python_subscript_uses_full_index_range() {
let iter = 0..=u64::MAX;

assert_eq!(iter.clone().python_subscript(0), Some(0));
assert_eq!(iter.clone().python_subscript(1), Some(1));
assert_eq!(
iter.clone().python_subscript(i64::MAX),
Some(i64::MAX as u64)
);

assert_eq!(iter.clone().python_subscript(-1), Some(u64::MAX));
assert_eq!(iter.clone().python_subscript(-2), Some(u64::MAX - 1));

// i64::MIN is not representable as a positive number, so it is not
// a valid index:
assert_eq!(iter.clone().python_subscript(i64::MIN), None);

// but i64::MIN +1 is:
assert_eq!(
iter.clone().python_subscript(i64::MIN + 1),
Some(2u64.pow(63) + 1)
);
}
}

0 comments on commit 77ae0cc

Please sign in to comment.