diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 56df1c44d9adee..71e15f03dc2e70 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -38,8 +38,8 @@ pub(super) struct SemanticIndexBuilder<'db> { file: File, module: &'db ParsedModule, scope_stack: Vec, - /// The assignment we're currently visiting. - current_assignment: Option>, + /// The assignments we're currently visiting. + current_assignments: Vec>, /// The match case we're currently visiting. current_match_case: Option>, /// Flow states at each `break` in the current loop. @@ -67,7 +67,7 @@ impl<'db> SemanticIndexBuilder<'db> { file, module: parsed, scope_stack: Vec::new(), - current_assignment: None, + current_assignments: vec![], current_match_case: None, loop_break_states: vec![], @@ -359,12 +359,13 @@ impl<'db> SemanticIndexBuilder<'db> { self.visit_expr(&generator.iter); self.push_scope(scope); - self.current_assignment = Some(CurrentAssignment::Comprehension { - node: generator, - first: true, - }); + self.current_assignments + .push(CurrentAssignment::Comprehension { + node: generator, + first: true, + }); self.visit_expr(&generator.target); - self.current_assignment = None; + debug_assert!(self.current_assignments.pop().is_some()); for expr in &generator.ifs { self.visit_expr(expr); @@ -374,12 +375,13 @@ impl<'db> SemanticIndexBuilder<'db> { self.add_standalone_expression(&generator.iter); self.visit_expr(&generator.iter); - self.current_assignment = Some(CurrentAssignment::Comprehension { - node: generator, - first: false, - }); + self.current_assignments + .push(CurrentAssignment::Comprehension { + node: generator, + first: false, + }); self.visit_expr(&generator.target); - self.current_assignment = None; + debug_assert!(self.current_assignments.pop().is_some()); for expr in &generator.ifs { self.visit_expr(expr); @@ -415,7 +417,7 @@ impl<'db> SemanticIndexBuilder<'db> { self.pop_scope(); assert!(self.scope_stack.is_empty()); - assert!(self.current_assignment.is_none()); + assert_eq!(self.current_assignments.len(), 0); let mut symbol_tables: IndexVec<_, _> = self .symbol_tables @@ -563,24 +565,22 @@ where } } ast::Stmt::Assign(node) => { - debug_assert!(self.current_assignment.is_none()); self.visit_expr(&node.value); self.add_standalone_expression(&node.value); - self.current_assignment = Some(node.into()); + self.current_assignments.push(node.into()); for target in &node.targets { self.visit_expr(target); } - self.current_assignment = None; + debug_assert!(self.current_assignments.pop().is_some()); } ast::Stmt::AnnAssign(node) => { - debug_assert!(self.current_assignment.is_none()); self.visit_expr(&node.annotation); if let Some(value) = &node.value { self.visit_expr(value); } - self.current_assignment = Some(node.into()); + self.current_assignments.push(node.into()); self.visit_expr(&node.target); - self.current_assignment = None; + debug_assert!(self.current_assignments.pop().is_some()); } ast::Stmt::AugAssign( aug_assign @ ast::StmtAugAssign { @@ -590,11 +590,10 @@ where value, }, ) => { - debug_assert!(self.current_assignment.is_none()); self.visit_expr(value); - self.current_assignment = Some(aug_assign.into()); + self.current_assignments.push(aug_assign.into()); self.visit_expr(target); - self.current_assignment = None; + debug_assert!(self.current_assignments.pop().is_some()); } ast::Stmt::If(node) => { self.visit_expr(&node.test); @@ -662,9 +661,9 @@ where self.visit_expr(&item.context_expr); if let Some(optional_vars) = item.optional_vars.as_deref() { self.add_standalone_expression(&item.context_expr); - self.current_assignment = Some(item.into()); + self.current_assignments.push(item.into()); self.visit_expr(optional_vars); - self.current_assignment = None; + self.current_assignments.pop(); } } self.visit_body(body); @@ -689,10 +688,9 @@ where let pre_loop = self.flow_snapshot(); let saved_break_states = std::mem::take(&mut self.loop_break_states); - debug_assert!(self.current_assignment.is_none()); - self.current_assignment = Some(for_stmt.into()); + self.current_assignments.push(for_stmt.into()); self.visit_expr(target); - self.current_assignment = None; + debug_assert!(self.current_assignments.pop().is_some()); // TODO: Definitions created by loop variables // (and definitions created inside the body) @@ -802,7 +800,7 @@ where match expr { ast::Expr::Name(name_node @ ast::ExprName { id, ctx, .. }) => { - let (is_use, is_definition) = match (ctx, self.current_assignment) { + let (is_use, is_definition) = match (ctx, self.current_assignments.last()) { (ast::ExprContext::Store, Some(CurrentAssignment::AugAssign(_))) => { // For augmented assignment, the target expression is also used. (true, true) @@ -813,8 +811,9 @@ where (ast::ExprContext::Invalid, _) => (false, false), }; let symbol = self.add_symbol(id.clone()); + if is_definition { - match self.current_assignment { + match self.current_assignments.last().copied() { Some(CurrentAssignment::Assign(assignment)) => { self.add_definition( symbol, @@ -879,12 +878,11 @@ where walk_expr(self, expr); } ast::Expr::Named(node) => { - debug_assert!(self.current_assignment.is_none()); // TODO walrus in comprehensions is implicitly nonlocal self.visit_expr(&node.value); - self.current_assignment = Some(node.into()); + self.current_assignments.push(node.into()); self.visit_expr(&node.target); - self.current_assignment = None; + debug_assert!(self.current_assignments.pop().is_some()); } ast::Expr::Lambda(lambda) => { if let Some(parameters) = &lambda.parameters { diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index f894d6dc24133a..149801c8fd997e 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -3433,6 +3433,23 @@ mod tests { Ok(()) } + #[test] + fn test_assignment_in_assignment() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + x = [0] + x[0 if y := 2 else 1] = 1 + ", + )?; + + assert_file_diagnostics(&db, "/src/a.py", &[]); + + Ok(()) + } + #[test] fn follow_import_to_class() -> anyhow::Result<()> { let mut db = setup_db();