diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index ac5463053d08c..f5c101921cdb1 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -198,8 +198,10 @@ impl<'db> SemanticIndex<'db> { pub(crate) fn definition( &self, definition_key: impl Into, - ) -> Definition<'db> { - self.definitions_by_node[&definition_key.into()] + ) -> Option> { + self.definitions_by_node + .get(&definition_key.into()) + .copied() } /// Returns the [`Expression`] ingredient for an expression node. @@ -734,8 +736,9 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): .elt .as_name_expr() .unwrap(); - let element_use_id = - element.scoped_use_id(&db, comprehension_scope_id.to_scope_id(&db, file)); + let element_use_id = element + .scoped_use_id(&db, comprehension_scope_id.to_scope_id(&db, file)) + .unwrap(); let binding = use_def.first_binding_at_use(element_use_id).unwrap(); let DefinitionKind::Comprehension(comprehension) = binding.kind(&db) else { @@ -985,7 +988,7 @@ class C[T]: let ast::Expr::Name(x_use_expr_name) = x_use_expr.as_ref() else { panic!("expected a Name"); }; - let x_use_id = x_use_expr_name.scoped_use_id(&db, scope); + let x_use_id = x_use_expr_name.scoped_use_id(&db, scope).unwrap(); let use_def = use_def_map(&db, scope); let binding = use_def.first_binding_at_use(x_use_id).unwrap(); let DefinitionKind::Assignment(assignment) = binding.kind(&db) else { diff --git a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs index 77750b730368f..528837fa66609 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs @@ -33,12 +33,12 @@ pub(crate) struct AstIds { } impl AstIds { - fn expression_id(&self, key: impl Into) -> ScopedExpressionId { - self.expressions_map[&key.into()] + fn expression_id(&self, key: impl Into) -> Option { + self.expressions_map.get(&key.into()).copied() } - fn use_id(&self, key: impl Into) -> ScopedUseId { - self.uses_map[&key.into()] + fn use_id(&self, key: impl Into) -> Option { + self.uses_map.get(&key.into()).copied() } } @@ -51,7 +51,7 @@ pub trait HasScopedUseId { type Id: Copy; /// Returns the ID that uniquely identifies the use in `scope`. - fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id; + fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Option; } /// Uniquely identifies a use of a name in a [`crate::semantic_index::symbol::FileScopeId`]. @@ -61,7 +61,7 @@ pub struct ScopedUseId; impl HasScopedUseId for ast::ExprName { type Id = ScopedUseId; - fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id { + fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Option { let expression_ref = ExpressionRef::from(self); expression_ref.scoped_use_id(db, scope) } @@ -70,7 +70,7 @@ impl HasScopedUseId for ast::ExprName { impl HasScopedUseId for ast::ExpressionRef<'_> { type Id = ScopedUseId; - fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id { + fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Option { let ast_ids = ast_ids(db, scope); ast_ids.use_id(*self) } @@ -81,7 +81,7 @@ pub trait HasScopedAstId { type Id: Copy; /// Returns the ID that uniquely identifies the node in `scope`. - fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id; + fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Option; } /// Uniquely identifies an [`ast::Expr`] in a [`crate::semantic_index::symbol::FileScopeId`]. @@ -93,7 +93,7 @@ macro_rules! impl_has_scoped_expression_id { impl HasScopedAstId for $ty { type Id = ScopedExpressionId; - fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id { + fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Option { let expression_ref = ExpressionRef::from(self); expression_ref.scoped_ast_id(db, scope) } @@ -138,7 +138,7 @@ impl_has_scoped_expression_id!(ast::Expr); impl HasScopedAstId for ast::ExpressionRef<'_> { type Id = ScopedExpressionId; - fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id { + fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Option { let ast_ids = ast_ids(db, scope); ast_ids.expression_id(*self) } diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 56df1c44d9ade..8c288370cc6b0 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -9,6 +9,7 @@ use ruff_python_ast as ast; use ruff_python_ast::name::Name; use ruff_python_ast::visitor::{walk_expr, walk_pattern, walk_stmt, Visitor}; use ruff_python_ast::AnyParameterRef; +use ruff_python_ast::Expr; use crate::ast_node_ref::AstNodeRef; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; @@ -212,7 +213,9 @@ impl<'db> SemanticIndexBuilder<'db> { let existing_definition = self .definitions_by_node .insert(definition_node.key(), definition); - debug_assert_eq!(existing_definition, None); + if existing_definition.is_some() { + tracing::warn!("Existing definition was unexpectedly evicted"); + } if category.is_binding() { self.mark_symbol_bound(symbol); @@ -574,13 +577,21 @@ where } ast::Stmt::AnnAssign(node) => { debug_assert!(self.current_assignment.is_none()); - self.visit_expr(&node.annotation); - if let Some(value) = &node.value { - self.visit_expr(value); + let valid_target = matches!( + *node.target, + Expr::Attribute(_) | Expr::Subscript(_) | Expr::Name(_) + ); + if valid_target { + self.visit_expr(&node.annotation); + if let Some(value) = &node.value { + self.visit_expr(value); + } + self.current_assignment = Some(node.into()); + self.visit_expr(&node.target); + self.current_assignment = None; + } else { + tracing::warn!("Annotated assignment with invalid target received"); } - self.current_assignment = Some(node.into()); - self.visit_expr(&node.target); - self.current_assignment = None; } ast::Stmt::AugAssign( aug_assign @ ast::StmtAugAssign { @@ -879,12 +890,17 @@ where walk_expr(self, expr); } ast::Expr::Named(node) => { - debug_assert!(self.current_assignment.is_none()); - // TODO walrus in comprehensions is implicitly nonlocal - self.visit_expr(&node.value); - self.current_assignment = Some(node.into()); - self.visit_expr(&node.target); - self.current_assignment = None; + if self.current_assignment.is_some() { + // This can happen if we have something like x = y := 2 + // which is invalid syntax but still is provided in the AST + tracing::warn!("Current assignment is unexpectedly set"); + } else { + // TODO walrus in comprehensions is implicitly nonlocal + self.visit_expr(&node.value); + self.current_assignment = Some(node.into()); + self.visit_expr(&node.target); + self.current_assignment = None; + } } ast::Expr::Lambda(lambda) => { if let Some(parameters) = &lambda.parameters { diff --git a/crates/red_knot_python_semantic/src/semantic_model.rs b/crates/red_knot_python_semantic/src/semantic_model.rs index 411d87b6770e3..478f8dfc5de82 100644 --- a/crates/red_knot_python_semantic/src/semantic_model.rs +++ b/crates/red_knot_python_semantic/src/semantic_model.rs @@ -58,8 +58,13 @@ impl HasTy for ast::ExpressionRef<'_> { let file_scope = index.expression_scope_id(*self); let scope = file_scope.to_scope_id(model.db, model.file); - let expression_id = self.scoped_ast_id(model.db, scope); - infer_scope_types(model.db, scope).expression_ty(expression_id) + if let Some(expression_id) = self.scoped_ast_id(model.db, scope) { + let lookup = infer_scope_types(model.db, scope).try_expression_ty(expression_id); + lookup.unwrap_or(Type::Unknown) + } else { + tracing::warn!("Couldn't find expression ID"); + Type::Unknown + } } } @@ -153,8 +158,10 @@ macro_rules! impl_binding_has_ty { #[inline] fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { let index = semantic_index(model.db, model.file); - let binding = index.definition(self); - binding_ty(model.db, binding) + match index.definition(self) { + Some(binding) => binding_ty(model.db, binding), + None => Type::Unknown, + } } } }; diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index f820331a21c35..acfed5f169132 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -114,12 +114,18 @@ fn definition_expression_ty<'db>( definition: Definition<'db>, expression: &ast::Expr, ) -> Type<'db> { - let expr_id = expression.scoped_ast_id(db, definition.scope(db)); - let inference = infer_definition_types(db, definition); - if let Some(ty) = inference.try_expression_ty(expr_id) { - ty + if let Some(expr_id) = expression.scoped_ast_id(db, definition.scope(db)) { + let inference = infer_definition_types(db, definition); + if let Some(ty) = inference.try_expression_ty(expr_id) { + ty + } else { + infer_deferred_types(db, definition) + .try_expression_ty(expr_id) + .unwrap_or(Type::Unknown) + } } else { - infer_deferred_types(db, definition).expression_ty(expr_id) + tracing::warn!("Can't find expression ID"); + Type::Unknown } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index f894d6dc24133..bef42a6526262 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -96,6 +96,27 @@ fn infer_definition_types_cycle_recovery<'db>( inference } +/// Cycle recovery for [`infer_deferred_types()`]: for now, just [`Type::Unknown`] +/// TODO fixpoint iteration +fn infer_deferred_types_cycle_recovery<'db>( + db: &'db dyn Db, + _cycle: &salsa::Cycle, + input: Definition<'db>, +) -> TypeInference<'db> { + tracing::trace!("infer_deferred_types_cycle_recovery"); + let mut inference = TypeInference::default(); + let category = input.category(db); + if category.is_declaration() { + inference.declarations.insert(input, Type::Unknown); + } + if category.is_binding() { + inference.bindings.insert(input, Type::Unknown); + } + // TODO we don't fill in expression types for the cycle-participant definitions, which can + // later cause a panic when looking up an expression type. + inference +} + /// Infer all types for a [`Definition`] (including sub-expressions). /// Use when resolving a symbol name use or public type of a symbol. #[salsa::tracked(return_ref, recovery_fn=infer_definition_types_cycle_recovery)] @@ -120,7 +141,7 @@ pub(crate) fn infer_definition_types<'db>( /// /// Deferred expressions are type expressions (annotations, base classes, aliases...) in a stub /// file, or in a file with `from __future__ import annotations`, or stringified annotations. -#[salsa::tracked(return_ref)] +#[salsa::tracked(return_ref, recovery_fn=infer_deferred_types_cycle_recovery)] pub(crate) fn infer_deferred_types<'db>( db: &'db dyn Db, definition: Definition<'db>, @@ -190,9 +211,9 @@ pub(crate) struct TypeInference<'db> { } impl<'db> TypeInference<'db> { - pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> { - self.expressions[&expression] - } + // pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> { + // self.expressions[&expression] + // } pub(crate) fn try_expression_ty(&self, expression: ScopedExpressionId) -> Option> { self.expressions.get(&expression).copied() @@ -334,9 +355,11 @@ impl<'db> TypeInferenceBuilder<'db> { /// Get the already-inferred type of an expression node. /// /// PANIC if no type has been inferred for this node. - fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> { - self.types - .expression_ty(expr.scoped_ast_id(self.db, self.scope)) + fn try_expression_ty(&self, expr: &ast::Expr) -> Option> { + match expr.scoped_ast_id(self.db, self.scope) { + Some(id) => self.types.try_expression_ty(id), + None => None, + } } /// Infers types in the given [`InferenceRegion`]. @@ -700,9 +723,15 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_definition(&mut self, node: impl Into) { - let definition = self.index.definition(node); - let result = infer_definition_types(self.db, definition); - self.extend(result); + match self.index.definition(node) { + Some(definition) => { + let result = infer_definition_types(self.db, definition); + self.extend(result); + } + None => { + tracing::warn!("Couldn't find definition"); + } + } } fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) { @@ -1001,12 +1030,17 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO(dhruvmanila): The correct type inference here is the return type of the __enter__ // method of the context manager. - let context_expr_ty = self.expression_ty(&with_item.context_expr); + let context_expr_ty = self.try_expression_ty(&with_item.context_expr).unwrap(); - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), context_expr_ty); - self.add_binding(target.into(), definition, context_expr_ty); + match target.scoped_ast_id(self.db, self.scope) { + Some(id) => { + self.types.expressions.insert(id, context_expr_ty); + self.add_binding(target.into(), definition, context_expr_ty); + } + _ => { + tracing::warn!("Couldn't find ID to infer with"); + } + } } fn infer_except_handler_definition( @@ -1173,11 +1207,18 @@ impl<'db> TypeInferenceBuilder<'db> { let expression = self.index.expression(assignment.value.as_ref()); let result = infer_expression_types(self.db, expression); self.extend(result); - let value_ty = self.expression_ty(&assignment.value); + let value_ty = self + .try_expression_ty(&assignment.value) + .unwrap_or(Type::Unknown); self.add_binding(assignment.into(), definition, value_ty); - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), value_ty); + match target.scoped_ast_id(self.db, self.scope) { + Some(id) => { + self.types.expressions.insert(id, value_ty); + } + None => { + tracing::warn!("Couldn't find ID for target"); + } + } } fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) { @@ -1369,7 +1410,7 @@ impl<'db> TypeInferenceBuilder<'db> { let expression = self.index.expression(iterable); let result = infer_expression_types(self.db, expression); self.extend(result); - let iterable_ty = self.expression_ty(iterable); + let iterable_ty = self.try_expression_ty(iterable).unwrap(); let loop_var_value_ty = if is_async { // TODO(Alex): async iterables/iterators! @@ -1380,10 +1421,15 @@ impl<'db> TypeInferenceBuilder<'db> { .unwrap_with_diagnostic(iterable.into(), self) }; - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), loop_var_value_ty); - self.add_binding(target.into(), definition, loop_var_value_ty); + match target.scoped_ast_id(self.db, self.scope) { + Some(id) => { + self.types.expressions.insert(id, loop_var_value_ty); + self.add_binding(target.into(), definition, loop_var_value_ty); + } + None => { + tracing::warn!("Failed to find target ID"); + } + } } fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) { @@ -1706,12 +1752,29 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression), ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from), ast::Expr::Await(await_expression) => self.infer_await_expression(await_expression), - ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), + ast::Expr::IpyEscapeCommand(_) => { + // todo!("Implement Ipy escape command support"), + return Type::Unknown; + } }; - let expr_id = expression.scoped_ast_id(self.db, self.scope); - let previous = self.types.expressions.insert(expr_id, ty); - assert_eq!(previous, None); + match expression.scoped_ast_id(self.db, self.scope) { + Some(expr_id) => { + let previous = self.types.expressions.insert(expr_id, ty); + match previous { + None => {} + Some(Type::Unknown) => { + tracing::warn!("Already had included an unknown type"); + } + other => { + assert_eq!(other, None); + } + } + } + None => { + tracing::warn!("Could not find ID for expression"); + } + } ty } @@ -2034,10 +2097,20 @@ impl<'db> TypeInferenceBuilder<'db> { .parent_scope_id(self.scope.file_scope_id(self.db)) .expect("A comprehension should never be the top-level scope") .to_scope_id(self.db, self.file); - result.expression_ty(iterable.scoped_ast_id(self.db, lookup_scope)) + if let Some(id) = iterable.scoped_ast_id(self.db, lookup_scope) { + result.try_expression_ty(id).unwrap_or(Type::Unknown) + } else { + tracing::warn!("Couldn't find AST ID for iterable"); + Type::Unknown + } } else { self.extend(result); - result.expression_ty(iterable.scoped_ast_id(self.db, self.scope)) + if let Some(id) = iterable.scoped_ast_id(self.db, self.scope) { + result.try_expression_ty(id).unwrap_or(Type::Unknown) + } else { + tracing::warn!("Couldn't find AST ID for iterable"); + Type::Unknown + } }; let target_ty = if is_async { @@ -2049,17 +2122,26 @@ impl<'db> TypeInferenceBuilder<'db> { .unwrap_with_diagnostic(iterable.into(), self) }; - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), target_ty); + match target.scoped_ast_id(self.db, self.scope) { + Some(id) => { + self.types.expressions.insert(id, target_ty); + } + None => { + tracing::warn!("Couldn't find AST ID for expression"); + } + } self.add_binding(target.into(), definition, target_ty); } fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type<'db> { - let definition = self.index.definition(named); - let result = infer_definition_types(self.db, definition); - self.extend(result); - result.binding_ty(definition) + if let Some(definition) = self.index.definition(named) { + let result = infer_definition_types(self.db, definition); + self.extend(result); + result.binding_ty(definition) + } else { + tracing::warn!("Couldn't find definition"); + Type::Unknown + } } fn infer_named_expression_definition( @@ -2258,11 +2340,13 @@ impl<'db> TypeInferenceBuilder<'db> { match ctx { ExprContext::Load => { let use_def = self.index.use_def_map(file_scope_id); - let symbol = self - .index - .symbol_table(file_scope_id) - .symbol_id_by_name(id) - .expect("Expected the symbol table to create a symbol for every Name node"); + let Some(symbol) = self.index.symbol_table(file_scope_id).symbol_id_by_name(id) + else { + tracing::warn!( + "Expected the symbol table to create a symbol for every Name node" + ); + return Type::Unknown; + }; // if we're inferring types of deferred expressions, always treat them as public symbols let (definitions, may_be_unbound) = if self.is_deferred() { ( @@ -2270,11 +2354,15 @@ impl<'db> TypeInferenceBuilder<'db> { use_def.public_may_be_unbound(symbol), ) } else { - let use_id = name.scoped_use_id(self.db, self.scope); - ( - use_def.bindings_at_use(use_id), - use_def.use_may_be_unbound(use_id), - ) + if let Some(use_id) = name.scoped_use_id(self.db, self.scope) { + ( + use_def.bindings_at_use(use_id), + use_def.use_may_be_unbound(use_id), + ) + } else { + tracing::warn!("Failed to find name"); + return Type::Unknown; + } }; let unbound_ty = if may_be_unbound { @@ -2521,8 +2609,8 @@ impl<'db> TypeInferenceBuilder<'db> { .tuple_windows::<(_, _)>() .zip(ops.iter()) .map(|((left, right), op)| { - let left_ty = self.expression_ty(left); - let right_ty = self.expression_ty(right); + let left_ty = self.try_expression_ty(left).unwrap_or(Type::Unknown); + let right_ty = self.try_expression_ty(right).unwrap_or(Type::Unknown); self.infer_binary_type_comparison(left_ty, *op, right_ty) .unwrap_or_else(|| { @@ -2990,11 +3078,13 @@ impl<'db> TypeInferenceBuilder<'db> { let ty = match expression { ast::Expr::Name(name) => { - debug_assert!( - name.ctx.is_load(), - "name in a type expression is always 'load' but got: '{:?}'", - name.ctx - ); + if !name.ctx.is_load() { + tracing::warn!( + "name in a type expression is always 'load' but got: '{:?}'", + name.ctx + ); + return Type::Unknown; + } self.infer_name_expression(name).to_instance(self.db) } @@ -3122,9 +3212,15 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), }; - let expr_id = expression.scoped_ast_id(self.db, self.scope); - let previous = self.types.expressions.insert(expr_id, ty); - assert!(previous.is_none()); + match expression.scoped_ast_id(self.db, self.scope) { + Some(expr_id) => { + let previous = self.types.expressions.insert(expr_id, ty); + assert!(previous.is_none()); + } + None => { + tracing::warn!("Could not find AST ID for expression"); + } + } ty } diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 8ca57af1168bc..d710697259dc0 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -154,13 +154,22 @@ impl<'db> NarrowingConstraintsBuilder<'db> { let scope = self.scope(); let inference = infer_expression_types(self.db, expression); for (op, comparator) in std::iter::zip(&**ops, &**comparators) { - let comp_ty = inference.expression_ty(comparator.scoped_ast_id(self.db, scope)); - if matches!(op, ast::CmpOp::IsNot) { - let ty = IntersectionBuilder::new(self.db) - .add_negative(comp_ty) - .build(); - self.constraints.insert(symbol, ty); - }; + match comparator.scoped_ast_id(self.db, scope) { + Some(comparator_id) => { + let comp_ty = inference + .try_expression_ty(comparator_id) + .unwrap_or(Type::Unknown); + if matches!(op, ast::CmpOp::IsNot) { + let ty = IntersectionBuilder::new(self.db) + .add_negative(comp_ty) + .build(); + self.constraints.insert(symbol, ty); + }; + } + None => { + tracing::warn!("Can't find ID for comparator"); + } + } // TODO other comparison types } }