Skip to content

Commit

Permalink
Cached inference of all definitions involved in unpacking
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvmanila committed Oct 30, 2024
1 parent bd0d782 commit 71c4911
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 119 deletions.
45 changes: 45 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,24 @@ pub enum DefinitionKind {
}

impl DefinitionKind {
/// Returns the [`UnpackTarget`] described by this definition, if there is one.
///
/// Only assignment definitions currently produce a target; every other kind, including
/// `for` statement definitions, yields `None`.
pub(crate) fn as_unpack_target(&self) -> Option<UnpackTarget> {
    if let DefinitionKind::Assignment(AssignmentDefinitionKind {
        assignment,
        target_index,
        ..
    }) = self
    {
        let assignment_target = UnpackTargetAssignment {
            assignment: assignment.clone(),
            target_index: *target_index,
        };
        return Some(UnpackTarget::Assignment(assignment_target));
    }

    // TODO: `DefinitionKind::For` is not handled yet; it yields `None` like all other kinds.
    None
}

pub(crate) fn category(&self) -> DefinitionCategory {
match self {
// functions, classes, and imports always bind, and we consider them declarations
Expand Down Expand Up @@ -661,3 +679,30 @@ impl From<&ast::ExceptHandlerExceptHandler> for DefinitionNodeKey {
Self(NodeKey::from_node(handler))
}
}

/// The target side of an unpacking, identified by the AST node it comes from.
#[derive(Clone, Debug)]
pub enum UnpackTarget {
    /// One target of an assignment statement.
    Assignment(UnpackTargetAssignment),
    /// A bare target expression (presumably the target of a `for` statement, going by the
    /// variant name; confirm against the code that constructs it).
    For(AstNodeRef<ast::Expr>),
}

impl UnpackTarget {
    /// Returns the target expression this unpack target refers to.
    pub(crate) fn node(&self) -> &ast::Expr {
        match self {
            Self::Assignment(inner) => inner.target(),
            Self::For(expr) => expr,
        }
    }
}

/// One target of an assignment statement: the statement itself plus the position of the
/// target inside the statement's `targets`.
#[derive(Clone, Debug)]
pub struct UnpackTargetAssignment {
    /// The assignment statement that contains the target.
    assignment: AstNodeRef<ast::StmtAssign>,
    /// Index into `assignment.targets` selecting the target expression.
    target_index: usize,
}

impl UnpackTargetAssignment {
    /// Returns the target expression found at `target_index` in the assignment's `targets`.
    ///
    /// The lookup uses plain indexing, so an out-of-range `target_index` would panic.
    pub(crate) fn target(&self) -> &ast::Expr {
        let Self {
            assignment,
            target_index,
        } = self;
        &assignment.targets[*target_index]
    }
}
1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ mod diagnostic;
mod display;
mod infer;
mod narrow;
mod unpack;

pub fn check_types(db: &dyn Db, file: File) -> TypeCheckDiagnostics {
let _span = tracing::trace_span!("check_types", file=?file.path(db)).entered();
Expand Down
4 changes: 4 additions & 0 deletions crates/red_knot_python_semantic/src/types/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ impl<'db> TypeCheckDiagnosticsBuilder<'db> {
});
}

/// Adds the diagnostics from `diagnostics` to this builder by forwarding them to `extend`
/// on the builder's own diagnostics collection.
pub(super) fn extend(&mut self, diagnostics: &TypeCheckDiagnostics) {
self.diagnostics.extend(diagnostics);
}

pub(super) fn finish(mut self) -> TypeCheckDiagnostics {
self.diagnostics.shrink_to_fit();
self.diagnostics
Expand Down
148 changes: 29 additions & 119 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
//! stringified annotations. We have a fourth Salsa query for inferring the deferred types
//! associated with a particular definition. Scope-level inference infers deferred types for all
//! definitions once the rest of the types in the scope have been inferred.
use std::borrow::Cow;
use std::num::NonZeroU32;

use itertools::Itertools;
Expand All @@ -52,6 +51,7 @@ use crate::stdlib::builtins_module_scope;
use crate::types::diagnostic::{
TypeCheckDiagnostic, TypeCheckDiagnostics, TypeCheckDiagnosticsBuilder,
};
use crate::types::unpack::{Unpack, UnpackResult, Unpacker};
use crate::types::{
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty,
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionType, IterationOutcome,
Expand Down Expand Up @@ -161,6 +161,18 @@ pub(crate) fn infer_expression_types<'db>(
TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index).finish()
}

/// Salsa query that runs an [`Unpacker`] over `unpack` and returns the resulting
/// [`UnpackResult`].
///
/// The work happens inside an `infer_unpack_types` trace span that records the id of the
/// unpack and the path of the file it belongs to.
#[salsa::tracked(return_ref)]
fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult<'db> {
    let file = unpack.file(db);
    let span =
        tracing::trace_span!("infer_unpack_types", unpack=?unpack.as_id(), file=%file.path(db));
    // Hold the guard until the function returns so the span covers the whole unpacking.
    let _entered = span.entered();

    let mut unpacker = Unpacker::new(db, file);
    unpacker.unpack(unpack);
    unpacker.finish()
}

/// A region within which we can infer types.
pub(crate) enum InferenceRegion<'db> {
/// infer types for a standalone [`Expression`]
Expand Down Expand Up @@ -439,7 +451,6 @@ impl<'db> TypeInferenceBuilder<'db> {
}
DefinitionKind::Assignment(assignment) => {
self.infer_assignment_definition(
assignment.target(),
assignment.value(),
assignment.name(),
assignment.kind(),
Expand Down Expand Up @@ -1179,7 +1190,6 @@ impl<'db> TypeInferenceBuilder<'db> {

fn infer_assignment_definition(
&mut self,
target: &ast::Expr,
value: &ast::Expr,
name: &ast::ExprName,
kind: AssignmentKind,
Expand All @@ -1190,128 +1200,28 @@ impl<'db> TypeInferenceBuilder<'db> {
self.extend(result);

let value_ty = self.expression_ty(value);
let name_ast_id = name.scoped_ast_id(self.db, self.scope());

let target_ty = match kind {
AssignmentKind::Sequence => self.infer_sequence_unpacking(target, value_ty, name),
AssignmentKind::Sequence => {
let target = definition.kind(self.db).as_unpack_target().unwrap();
let unpack = Unpack::new(
self.db,
self.file,
definition.file_scope(self.db),
target,
value_ty,
countme::Count::default(),
);
let unpacked = infer_unpack_types(self.db, unpack);
self.diagnostics.extend(unpacked.diagnostics());
unpacked.get(name_ast_id).unwrap_or(Type::Unknown)
}
AssignmentKind::Name => value_ty,
};

self.add_binding(name.into(), definition, target_ty);
self.types
.expressions
.insert(name.scoped_ast_id(self.db, self.scope()), target_ty);
}

fn infer_sequence_unpacking(
&mut self,
target: &ast::Expr,
value_ty: Type<'db>,
name: &ast::ExprName,
) -> Type<'db> {
// Recursively walks `target` looking for the name expression `name` and returns the type
// that unpacking `value_ty` into `target` gives it, or `None` if `name` does not occur in
// `target`. The caller turns the `None` case into `Type::Unknown`.
//
// - A matching name expression receives `value_ty` as is.
// - A starred expression forwards `value_ty` to the expression it wraps.
// - A list/tuple target distributes the element types of `value_ty` over its elements.
fn inner<'db>(
    builder: &mut TypeInferenceBuilder<'db>,
    target: &ast::Expr,
    value_ty: Type<'db>,
    name: &ast::ExprName,
) -> Option<Type<'db>> {
    match target {
        ast::Expr::Name(target_name) if target_name == name => {
            return Some(value_ty);
        }
        ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
            return inner(builder, value, value_ty, name);
        }
        ast::Expr::List(ast::ExprList { elts, .. })
        | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => match value_ty {
            Type::Tuple(tuple_ty) => {
                let starred_index = elts.iter().position(ast::Expr::is_starred_expr);

                // Build one type per target element so that both can be walked in lockstep
                // below.
                let element_types = if let Some(starred_index) = starred_index {
                    if tuple_ty.len(builder.db) >= elts.len() - 1 {
                        // The tuple has enough elements for every non-starred target.
                        let mut element_types = Vec::with_capacity(elts.len());
                        element_types.extend_from_slice(
                            // Cannot panic: `starred_index <= elts.len() - 1`, which the
                            // length check above bounds by the tuple length.
                            &tuple_ty.elements(builder.db)[..starred_index],
                        );

                        // E.g., in `(a, *b, c, d) = ...`, the index of starred element `b`
                        // is 1 and the remaining elements after that are 2.
                        let remaining = elts.len() - (starred_index + 1);
                        // This index represents the type of the last element that belongs
                        // to the starred expression, in an exclusive manner.
                        let starred_end_index = tuple_ty.len(builder.db) - remaining;
                        // Cannot panic: the length check above guarantees
                        // `starred_index <= starred_end_index <= tuple length`.
                        let _starred_element_types = &tuple_ty.elements(builder.db)
                            [starred_index..starred_end_index];
                        // TODO: Combine the types into a list type. If the
                        // starred_element_types is empty, then it should be `List[Any]`.
                        // combine_types(starred_element_types);
                        element_types.push(Type::Todo);

                        element_types.extend_from_slice(
                            // Cannot panic: `starred_end_index <= tuple length`, see above.
                            &tuple_ty.elements(builder.db)[starred_end_index..],
                        );
                        Cow::Owned(element_types)
                    } else {
                        // The tuple is too short for the non-starred targets.
                        let mut element_types = tuple_ty.elements(builder.db).to_vec();
                        // The starred target may sit beyond the end of the tuple (e.g.
                        // `a, b, *c, d = (1,)`). `Vec::insert` panics when the index is
                        // greater than the length, so pad the missing positions with
                        // `Unknown` first; that is the same type the lookup below falls
                        // back to for targets without a matching element.
                        if element_types.len() < starred_index {
                            element_types.resize(starred_index, Type::Unknown);
                        }
                        element_types.insert(starred_index, Type::Todo);
                        Cow::Owned(element_types)
                    }
                } else {
                    Cow::Borrowed(tuple_ty.elements(builder.db).as_ref())
                };

                for (index, element) in elts.iter().enumerate() {
                    if let Some(ty) = inner(
                        builder,
                        element,
                        // Targets without a corresponding tuple element get `Unknown`.
                        element_types.get(index).copied().unwrap_or(Type::Unknown),
                        name,
                    ) {
                        return Some(ty);
                    }
                }
            }
            Type::StringLiteral(string_literal_ty) => {
                // Deconstruct the string literal to delegate the inference back to the
                // tuple type for correct handling of starred expressions. We could go
                // further and deconstruct to an array of `StringLiteral` with each
                // individual character, instead of just an array of `LiteralString`, but
                // there would be a cost and it's not clear that it's worth it.
                let value_ty = Type::Tuple(TupleType::new(
                    builder.db,
                    vec![Type::LiteralString; string_literal_ty.len(builder.db)]
                        .into_boxed_slice(),
                ));
                if let Some(ty) = inner(builder, target, value_ty, name) {
                    return Some(ty);
                }
            }
            _ => {
                // Any other value type: every element of the target receives the same
                // type, either `LiteralString` or whatever iterating the value yields.
                // A failed iteration is reported through the builder's diagnostics.
                let value_ty = if value_ty.is_literal_string() {
                    Type::LiteralString
                } else {
                    value_ty.iterate(builder.db).unwrap_with_diagnostic(
                        AnyNodeRef::from(target),
                        &mut builder.diagnostics,
                    )
                };
                for element in elts {
                    if let Some(ty) = inner(builder, element, value_ty, name) {
                        return Some(ty);
                    }
                }
            }
        },
        _ => {}
    }
    None
}

inner(self, target, value_ty, name).unwrap_or(Type::Unknown)
self.types.expressions.insert(name_ast_id, target_ty);
}

fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) {
Expand Down
Loading

0 comments on commit 71c4911

Please sign in to comment.