diff --git a/corelib/src/test/language_features/while_test.cairo b/corelib/src/test/language_features/while_test.cairo index d671ca0c91f..700d3c4d093 100644 --- a/corelib/src/test/language_features/while_test.cairo +++ b/corelib/src/test/language_features/while_test.cairo @@ -23,3 +23,34 @@ fn test_outer_loop_break() { }; assert_eq!(i, 10); } + +#[test] +fn test_borrow_usage() { + let mut i = 0; + let arr = array![1, 2, 3, 4]; + while i != arr.len() { + i += 1; + }; + assert_eq!(arr.len(), 4); +} + +#[derive(Drop)] +struct NonCopy { + x: felt252, +} + +fn assert_x_eq(a: @NonCopy, x: felt252) { + assert_eq!(a.x, @x); +} + +#[test] +fn test_borrow_with_inner_change() { + let mut a = NonCopy { x: 0 }; + let mut i = 0; + while i != 5 { + a.x = i; + assert_x_eq(@a, i); + i += 1; + }; +} + diff --git a/crates/cairo-lang-lowering/src/lower/block_builder.rs b/crates/cairo-lang-lowering/src/lower/block_builder.rs index 762f34f03d0..279806d781a 100644 --- a/crates/cairo-lang-lowering/src/lower/block_builder.rs +++ b/crates/cairo-lang-lowering/src/lower/block_builder.rs @@ -1,6 +1,7 @@ use cairo_lang_defs::ids::{MemberId, NamedLanguageElementId}; use cairo_lang_diagnostics::Maybe; use cairo_lang_semantic as semantic; +use cairo_lang_semantic::types::{peel_snapshots, wrap_in_snapshots}; use cairo_lang_syntax::node::TypedStablePtr; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; use cairo_lang_utils::ordered_hash_set::OrderedHashSet; @@ -26,6 +27,8 @@ use crate::{ pub struct BlockBuilder { /// A store for semantic variables, owning their OwnedVariable instances. pub semantics: SemanticLoweringMapping, + /// The semantic variables that are captured as snapshots in this block. + pub snapped_semantics: OrderedHashMap, /// The semantic variables that are added/changed in this block. changed_member_paths: OrderedHashSet, /// Current sequence of lowered statements emitted. @@ -38,6 +41,7 @@ impl BlockBuilder { pub fn root(_ctx: &mut LoweringContext<'_, '_>, block_id: BlockId) -> Self { BlockBuilder { semantics: Default::default(), + snapped_semantics: Default::default(), changed_member_paths: Default::default(), statements: Default::default(), block_id, @@ -48,6 +52,7 @@ impl BlockBuilder { pub fn child_block_builder(&self, block_id: BlockId) -> BlockBuilder { BlockBuilder { semantics: self.semantics.clone(), + snapped_semantics: self.snapped_semantics.clone(), changed_member_paths: Default::default(), statements: Default::default(), block_id, @@ -59,6 +64,7 @@ impl BlockBuilder { pub fn sibling_block_builder(&self, block_id: BlockId) -> BlockBuilder { BlockBuilder { semantics: self.semantics.clone(), + snapped_semantics: self.snapped_semantics.clone(), changed_member_paths: self.changed_member_paths.clone(), statements: Default::default(), block_id, @@ -119,6 +125,49 @@ impl BlockBuilder { .map(|var_id| VarUsage { var_id, location }) } + /// Updates the reference of a semantic variable to a snapshot of its lowered variable. + pub fn update_snap_ref(&mut self, member_path: &ExprVarMemberPath, var: VariableId) { + self.snapped_semantics.insert(member_path.into(), var); + } + + /// Gets the reference of a snapshot of semantic variable, possibly by deconstructing a + /// its parents. + pub fn get_snap_ref( + &mut self, + ctx: &mut LoweringContext<'_, '_>, + member_path: &ExprVarMemberPath, + ) -> Option { + let location = ctx.get_location(member_path.stable_ptr().untyped()); + if let Some(var_id) = self.snapped_semantics.get::(&member_path.into()) { + return Some(VarUsage { var_id: *var_id, location }); + } + let ExprVarMemberPath::Member { parent, member_id, concrete_struct_id, .. } = member_path + else { + return None; + }; + // TODO(TomerStarkware): Consider adding the result to snap_semantics to avoid + // recomputation. + let parent_var = self.get_snap_ref(ctx, parent)?; + let members = ctx.db.concrete_struct_members(*concrete_struct_id).ok()?; + let (parent_number_of_snapshots, _) = + peel_snapshots(ctx.db.upcast(), ctx.variables[parent_var.var_id].ty); + let member_idx = members.iter().position(|(_, member)| member.id == *member_id)?; + Some( + generators::StructMemberAccess { + input: parent_var, + member_tys: members + .into_iter() + .map(|(_, member)| { + wrap_in_snapshots(ctx.db.upcast(), member.ty, parent_number_of_snapshots) + }) + .collect(), + member_idx, + location, + } + .add(ctx, &mut self.statements), + ) + } + /// Gets the type of a semantic variable. pub fn get_ty( &mut self, diff --git a/crates/cairo-lang-lowering/src/lower/context.rs b/crates/cairo-lang-lowering/src/lower/context.rs index 8e7de0fa071..a6d3b701263 100644 --- a/crates/cairo-lang-lowering/src/lower/context.rs +++ b/crates/cairo-lang-lowering/src/lower/context.rs @@ -13,9 +13,8 @@ use cairo_lang_utils::Intern; use defs::diagnostic_utils::StableLocation; use id_arena::Arena; use itertools::{zip_eq, Itertools}; -use semantic::corelib::{core_module, get_ty_by_name, get_usize_ty}; +use semantic::corelib::{core_module, get_ty_by_name}; use semantic::expr::inference::InferenceError; -use semantic::items::constant::value_as_const_value; use semantic::types::wrap_in_snapshots; use semantic::{ExprVarMemberPath, MatchArmSelector, TypeLongId}; use {cairo_lang_defs as defs, cairo_lang_semantic as semantic}; @@ -300,17 +299,7 @@ impl LoweredExpr { LoweredExpr::Snapshot { expr, .. } => { wrap_in_snapshots(ctx.db.upcast(), expr.ty(ctx), 1) } - LoweredExpr::FixedSizeArray { exprs, .. } => semantic::TypeLongId::FixedSizeArray { - type_id: exprs[0].ty(ctx), - size: value_as_const_value( - ctx.db.upcast(), - get_usize_ty(ctx.db.upcast()), - &exprs.len().into(), - ) - .unwrap() - .intern(ctx.db), - } - .intern(ctx.db), + LoweredExpr::FixedSizeArray { ty, .. } => *ty, } } pub fn location(&self) -> LocationId { diff --git a/crates/cairo-lang-lowering/src/lower/mod.rs b/crates/cairo-lang-lowering/src/lower/mod.rs index c4632247ea4..ebccdaf2603 100644 --- a/crates/cairo-lang-lowering/src/lower/mod.rs +++ b/crates/cairo-lang-lowering/src/lower/mod.rs @@ -3,9 +3,9 @@ use std::vec; use block_builder::BlockBuilder; use cairo_lang_debug::DebugWithDb; use cairo_lang_diagnostics::{Diagnostics, Maybe}; -use cairo_lang_semantic::corelib::{self, unwrap_error_propagation_type, ErrorPropagationType}; +use cairo_lang_semantic::corelib::{unwrap_error_propagation_type, ErrorPropagationType}; use cairo_lang_semantic::db::SemanticGroup; -use cairo_lang_semantic::{LocalVariable, VarId}; +use cairo_lang_semantic::{corelib, ExprVar, LocalVariable, VarId}; use cairo_lang_syntax::node::ids::SyntaxStablePtrId; use cairo_lang_syntax::node::TypedStablePtr; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; @@ -27,6 +27,7 @@ use semantic::{ ExprFunctionCallArg, ExprId, ExprPropagateError, ExprVarMemberPath, GenericArgumentId, MatchArmSelector, SemanticDiagnostic, TypeLongId, }; +use usage::MemberPath; use {cairo_lang_defs as defs, cairo_lang_semantic as semantic}; use self::block_builder::SealedBlockBuilder; @@ -413,6 +414,7 @@ pub fn lower_loop_function( function_id: FunctionWithBodyId, loop_signature: Signature, loop_expr_id: semantic::ExprId, + snapped_params: &OrderedHashMap, ) -> Maybe { let mut ctx = LoweringContext::new(encapsulating_ctx, function_id, loop_signature.clone())?; let old_loop_expr_id = std::mem::replace(&mut ctx.current_loop_expr_id, Some(loop_expr_id)); @@ -429,7 +431,11 @@ pub fn lower_loop_function( .map(|param| { let location = ctx.get_location(param.stable_ptr().untyped()); let var = ctx.new_var(VarRequest { ty: param.ty(), location }); - builder.semantics.introduce((¶m).into(), var); + if snapped_params.contains_key::(&(¶m).into()) { + builder.update_snap_ref(¶m, var) + } else { + builder.semantics.introduce((¶m).into(), var); + } var }) .collect_vec(); @@ -1116,8 +1122,31 @@ fn lower_expr_snapshot( builder: &mut BlockBuilder, ) -> LoweringResult { log::trace!("Lowering a snapshot: {:?}", expr.debug(&ctx.expr_formatter)); + // If the inner expression is a variable, or a member access, and we already have a snapshot var + // we can use it without creating a new one. + // Note that in a closure we might only have a snapshot of the variable and not the original. + match &ctx.function_body.exprs[expr.inner] { + semantic::Expr::Var(expr_var) => { + let member_path = ExprVarMemberPath::Var(expr_var.clone()); + if let Some(var) = builder.get_snap_ref(ctx, &member_path) { + return Ok(LoweredExpr::AtVariable(var)); + } + } + semantic::Expr::MemberAccess(expr) => { + if let Some(var) = expr + .member_path + .clone() + .and_then(|member_path| builder.get_snap_ref(ctx, &member_path)) + { + return Ok(LoweredExpr::AtVariable(var)); + } + } + _ => {} + } + let lowered = lower_expr(ctx, builder, expr.inner)?; + let location = ctx.get_location(expr.stable_ptr.untyped()); - let expr = Box::new(lower_expr(ctx, builder, expr.inner)?); + let expr = Box::new(lowered); Ok(LoweredExpr::Snapshot { expr, location }) } @@ -1348,7 +1377,26 @@ fn lower_expr_loop( let usage = &ctx.block_usages.block_usages[&loop_expr_id]; // Determine signature. - let params = usage.usage.iter().map(|(_, expr)| expr.clone()).collect_vec(); + let params = usage + .usage + .iter() + .map(|(_, expr)| expr.clone()) + .chain(usage.snap_usage.iter().map(|(_, expr)| match expr { + ExprVarMemberPath::Var(var) => ExprVarMemberPath::Var(ExprVar { + ty: wrap_in_snapshots(ctx.db.upcast(), var.ty, 1), + ..*var + }), + ExprVarMemberPath::Member { parent, member_id, stable_ptr, concrete_struct_id, ty } => { + ExprVarMemberPath::Member { + parent: parent.clone(), + member_id: *member_id, + stable_ptr: *stable_ptr, + concrete_struct_id: *concrete_struct_id, + ty: wrap_in_snapshots(ctx.db.upcast(), *ty, 1), + } + } + })) + .collect_vec(); let extra_rets = usage.changes.iter().map(|(_, expr)| expr.clone()).collect_vec(); let loop_signature = Signature { @@ -1367,17 +1415,37 @@ fn lower_expr_loop( } .intern(ctx.db); + let snap_usage = ctx.block_usages.block_usages[&loop_expr_id].snap_usage.clone(); + // Generate the function. let encapsulating_ctx = std::mem::take(&mut ctx.encapsulating_ctx).unwrap(); - let lowered = - lower_loop_function(encapsulating_ctx, function, loop_signature.clone(), loop_expr_id) - .map_err(LoweringFlowError::Failed)?; + let lowered = lower_loop_function( + encapsulating_ctx, + function, + loop_signature.clone(), + loop_expr_id, + &snap_usage, + ) + .map_err(LoweringFlowError::Failed)?; // TODO(spapini): Recursive call. encapsulating_ctx.lowerings.insert(loop_expr_id, lowered); - ctx.encapsulating_ctx = Some(encapsulating_ctx); let old_loop_expr_id = std::mem::replace(&mut ctx.current_loop_expr_id, Some(loop_expr_id)); + for snapshot_param in snap_usage.values() { + // if we have access to the real member we generate a snapshot, otherwise it should be + // accessible with `builder.get_snap_ref` + if let Some(input) = builder.get_ref(ctx, snapshot_param) { + let (original, snapped) = generators::Snapshot { + input, + location: ctx.get_location(snapshot_param.stable_ptr().untyped()), + } + .add(ctx, &mut builder.statements); + builder.update_snap_ref(snapshot_param, snapped); + builder.update_ref(ctx, snapshot_param, original); + } + } let call = call_loop_func(ctx, loop_signature, builder, loop_expr_id, stable_ptr.untyped()); + ctx.current_loop_expr_id = old_loop_expr_id; call } @@ -1402,9 +1470,17 @@ fn call_loop_func( .params .into_iter() .map(|param| { - builder.get_ref(ctx, ¶m).ok_or_else(|| { - LoweringFlowError::Failed(ctx.diagnostics.report(stable_ptr, MemberPathLoop)) - }) + builder + .get_ref(ctx, ¶m) + .and_then(|var| (ctx.variables[var.var_id].ty == param.ty()).then_some(var)) + .or_else(|| { + let var = builder.get_snap_ref(ctx, ¶m)?; + (ctx.variables[var.var_id].ty == param.ty()).then_some(var) + }) + .ok_or_else(|| { + // TODO(TomerStaskware): remove this error. + LoweringFlowError::Failed(ctx.diagnostics.report(stable_ptr, MemberPathLoop)) + }) }) .collect::>>()?; let extra_ret_tys = loop_signature.extra_rets.iter().map(|path| path.ty()).collect_vec(); diff --git a/crates/cairo-lang-lowering/src/lower/test_data/loop b/crates/cairo-lang-lowering/src/lower/test_data/loop index c68d69df474..618ebc85a16 100644 --- a/crates/cairo-lang-lowering/src/lower/test_data/loop +++ b/crates/cairo-lang-lowering/src/lower/test_data/loop @@ -910,3 +910,582 @@ Statements: (v20: core::panics::PanicResult::<(core::integer::u8, ())>) <- PanicResult::Err(v19) End: Return(v5, v6, v20) + +//! > ========================================================================== + +//! > Test snap usage after loop. + +//! > test_runner_name +test_generated_function + +//! > function +fn foo() -> bool { + let mut s = S {}; + loop { + s.foo(); + break; + }; + s.foo(); + false +} + +//! > function_name +foo + +//! > module_code +#[derive(Drop)] +struct S {} +trait T { + fn foo(self: @S); +} +impl I of T { + fn foo(self: @S) {} +} + +//! > semantic_diagnostics + +//! > lowering_diagnostics + +//! > lowering_flat +Parameters: + +//! > lowering +Main: +Parameters: +blk0 (root): +Statements: + (v0: test::S) <- struct_construct() + (v1: test::S, v2: @test::S) <- snapshot(v0) + (v3: ()) <- test::foo[expr5](v2) + (v4: ()) <- test::I::foo(v2) + (v5: ()) <- struct_construct() + (v6: core::bool) <- bool::False(v5) +End: + Return(v6) + + +Final lowering: +Parameters: +blk0 (root): +Statements: + (v0: ()) <- struct_construct() + (v1: core::bool) <- bool::False(v0) +End: + Return(v1) + + +Generated lowering for source location: + loop { + ^****^ + +Parameters: v0: @test::S +blk0 (root): +Statements: + (v1: ()) <- test::I::foo(v0) + (v2: ()) <- struct_construct() +End: + Return(v2) + + +Final lowering: +Parameters: v0: @test::S +blk0 (root): +Statements: +End: + Return() + +//! > ========================================================================== + +//! > Test snap usage after loop of member. + +//! > test_runner_name +test_generated_function + +//! > function +fn foo() { + let t = T { s: S {} }; + loop { + TT::f1oo(@t.s); + break; + }; + TT::f1oo(@t.s); +} + +//! > function_name +foo + +//! > module_code +#[derive(Drop)] +struct S {} +#[derive(Drop)] +struct T { + s: S +} +trait TT { + fn f1oo(self: @S); +} +impl STT of TT { + fn f1oo(self: @S) {} +} + +//! > semantic_diagnostics + +//! > lowering_diagnostics + +//! > lowering_flat +Parameters: + +//! > lowering +Main: +Parameters: +blk0 (root): +Statements: + (v0: test::S) <- struct_construct() + (v1: test::T) <- struct_construct(v0) + (v2: test::S) <- struct_destructure(v1) + (v3: test::S, v4: @test::S) <- snapshot(v2) + (v5: ()) <- test::foo[expr7](v4) + (v6: ()) <- test::STT::f1oo(v4) + (v7: ()) <- struct_construct() +End: + Return(v7) + + +Final lowering: +Parameters: +blk0 (root): +Statements: +End: + Return() + + +Generated lowering for source location: + loop { + ^****^ + +Parameters: v0: @test::S +blk0 (root): +Statements: + (v1: ()) <- test::STT::f1oo(v0) + (v2: ()) <- struct_construct() +End: + Return(v2) + + +Final lowering: +Parameters: v0: @test::S +blk0 (root): +Statements: +End: + Return() + +//! > ========================================================================== + +//! > Test real usage of inner with snap usage of outer. + +//! > test_runner_name +test_generated_function + +//! > function +fn foo() { + let a = A { b: B { c: 3 } }; + loop { + let _x = @a.b; + ex1(_x); + loop { + let _y = a.b.c; + ex(_y); + break; + }; + break; + }; +} + +//! > function_name +foo + +//! > module_code +extern fn ex(a: u32) nopanic; +extern fn ex1(a: @B) nopanic; +#[derive(Drop)] +struct B { + c: u32 +} +#[derive(Drop)] +struct A { + b: B +} + +//! > semantic_diagnostics + +//! > lowering_diagnostics + +//! > lowering_flat +Parameters: + +//! > lowering +Main: +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u32) <- 3 + (v1: test::B) <- struct_construct(v0) + (v2: test::A) <- struct_construct(v1) + (v3: test::B) <- struct_destructure(v2) + (v4: test::B, v5: @test::B) <- snapshot(v3) + (v6: core::integer::u32) <- struct_destructure(v4) + (v7: test::B) <- struct_construct(v6) + (v8: ()) <- test::foo[expr16](v6, v5) + (v9: ()) <- struct_construct() +End: + Return(v9) + + +Final lowering: +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u32) <- 3 + (v1: test::B) <- struct_construct(v0) + (v2: test::B, v3: @test::B) <- snapshot(v1) + () <- test::ex1(v3) + (v4: core::integer::u32) <- struct_destructure(v2) + () <- test::ex(v4) +End: + Return() + + +Generated lowering for source location: + loop { + ^****^ + +Parameters: v0: core::integer::u32 +blk0 (root): +Statements: + () <- test::ex(v0) + (v1: ()) <- struct_construct() +End: + Return(v1) + + +Final lowering: +Parameters: v0: core::integer::u32 +blk0 (root): +Statements: + () <- test::ex(v0) +End: + Return() + + +Generated lowering for source location: + loop { + ^****^ + +Parameters: v0: core::integer::u32, v1: @test::B +blk0 (root): +Statements: + () <- test::ex1(v1) + (v2: ()) <- test::foo[expr14](v0) + (v3: ()) <- struct_construct() +End: + Return(v3) + + +Final lowering: +Parameters: v0: core::integer::u32, v1: @test::B +blk0 (root): +Statements: + () <- test::ex1(v1) + () <- test::ex(v0) +End: + Return() + +//! > ========================================================================== + +//! > Test snap usage of inner with real usage of outer. + +//! > test_runner_name +test_generated_function + +//! > function +fn foo() { + let a = A { b: B { c: 3 } }; + loop { + let _x = a.b; + ex1(_x); + loop { + let _y = @a.b.c; + ex(_y); + break; + }; + break; + }; +} + +//! > function_name +foo + +//! > module_code +extern fn ex(a: @u32) nopanic; +extern fn ex1(a: B) nopanic; +#[derive(Drop, Copy)] +struct B { + c: u32 +} +#[derive(Drop)] +struct A { + b: B +} + +//! > semantic_diagnostics + +//! > lowering_diagnostics + +//! > lowering_flat +Parameters: + +//! > lowering +Main: +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u32) <- 3 + (v1: test::B) <- struct_construct(v0) + (v2: test::A) <- struct_construct(v1) + (v3: test::B) <- struct_destructure(v2) + (v4: ()) <- test::foo[expr16](v3) + (v5: ()) <- struct_construct() +End: + Return(v5) + + +Final lowering: +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u32) <- 3 + (v1: test::B) <- struct_construct(v0) + () <- test::ex1(v1) + (v2: core::integer::u32, v3: @core::integer::u32) <- snapshot(v0) + () <- test::ex(v3) +End: + Return() + + +Generated lowering for source location: + loop { + ^****^ + +Parameters: v0: @core::integer::u32 +blk0 (root): +Statements: + () <- test::ex(v0) + (v1: ()) <- struct_construct() +End: + Return(v1) + + +Final lowering: +Parameters: v0: @core::integer::u32 +blk0 (root): +Statements: + () <- test::ex(v0) +End: + Return() + + +Generated lowering for source location: + loop { + ^****^ + +Parameters: v0: test::B +blk0 (root): +Statements: + () <- test::ex1(v0) + (v1: core::integer::u32) <- struct_destructure(v0) + (v2: core::integer::u32, v3: @core::integer::u32) <- snapshot(v1) + (v4: ()) <- test::foo[expr14](v3) + (v5: ()) <- struct_construct() +End: + Return(v5) + + +Final lowering: +Parameters: v0: test::B +blk0 (root): +Statements: + () <- test::ex1(v0) + (v1: core::integer::u32) <- struct_destructure(v0) + (v2: core::integer::u32, v3: @core::integer::u32) <- snapshot(v1) + () <- test::ex(v3) +End: + Return() + +//! > ========================================================================== + +//! > Test change usage of inner with snap usage of outer. + +//! > test_runner_name +test_generated_function + +//! > function +fn foo() { + let mut a = A { x: 0 }; + let mut i = 0; + while i != 5 { + a.x = i; + use_a(@a); + i += 1; + }; +} + +//! > function_name +foo + +//! > module_code +#[derive(Drop)] +struct A { + x: felt252, +} +extern fn use_a(a: @A) nopanic; + +//! > semantic_diagnostics + +//! > lowering_diagnostics + +//! > lowering_flat +Parameters: + +//! > lowering +Main: +Parameters: +blk0 (root): +Statements: + (v0: core::felt252) <- 0 + (v1: test::A) <- struct_construct(v0) + (v2: core::felt252) <- 0 + (v4: test::A, v5: core::felt252, v3: ()) <- test::foo[expr19](v2, v1) + (v6: ()) <- struct_construct() +End: + Return(v6) + + +Final lowering: +Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin +blk0 (root): +Statements: + (v2: core::felt252) <- 0 + (v3: core::felt252) <- 0 + (v4: test::A) <- struct_construct(v2) + (v5: core::RangeCheck, v6: core::gas::GasBuiltin, v7: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[expr19](v0, v1, v3, v4) +End: + Match(match_enum(v7) { + PanicResult::Ok(v8) => blk1, + PanicResult::Err(v9) => blk2, + }) + +blk1: +Statements: + (v10: ()) <- struct_construct() + (v11: ((),)) <- struct_construct(v10) + (v12: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v11) +End: + Return(v5, v6, v12) + +blk2: +Statements: + (v13: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v9) +End: + Return(v5, v6, v13) + + +Generated lowering for source location: + while i != 5 { + ^************^ + +Parameters: v0: core::felt252, v1: test::A +blk0 (root): +Statements: + (v2: core::felt252, v3: @core::felt252) <- snapshot(v0) + (v4: core::felt252) <- 5 + (v5: core::felt252, v6: @core::felt252) <- snapshot(v4) + (v7: core::bool) <- core::Felt252PartialEq::ne(v3, v6) +End: + Match(match_enum(v7) { + bool::False(v19) => blk2, + bool::True(v8) => blk1, + }) + +blk1: +Statements: + (v9: core::felt252) <- struct_destructure(v1) + (v10: test::A) <- struct_construct(v2) + (v11: test::A, v12: @test::A) <- snapshot(v10) + () <- test::use_a(v12) + (v13: core::felt252) <- 1 + (v15: core::felt252, v14: ()) <- core::ops::arith::DeprecatedAddAssign::::add_assign(v2, v13) + (v17: test::A, v18: core::felt252, v16: ()) <- test::foo[expr19](v15, v11) +End: + Goto(blk3, {v17 -> v21, v18 -> v22, v16 -> v20}) + +blk2: +Statements: + (v23: ()) <- struct_construct() +End: + Goto(blk3, {v1 -> v21, v2 -> v22, v23 -> v20}) + +blk3: +Statements: +End: + Return(v21, v22, v20) + + +Final lowering: +Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin, v2: core::felt252, v3: test::A +blk0 (root): +Statements: +End: + Match(match core::gas::withdraw_gas(v0, v1) { + Option::Some(v4, v5) => blk1, + Option::None(v6, v7) => blk4, + }) + +blk1: +Statements: + (v8: core::felt252) <- 5 + (v9: core::felt252) <- core::felt252_sub(v2, v8) +End: + Match(match core::felt252_is_zero(v9) { + IsZeroResult::Zero => blk2, + IsZeroResult::NonZero(v10) => blk3, + }) + +blk2: +Statements: + (v11: ()) <- struct_construct() + (v12: (test::A, core::felt252, ())) <- struct_construct(v3, v2, v11) + (v13: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- PanicResult::Ok(v12) +End: + Return(v4, v5, v13) + +blk3: +Statements: + (v14: test::A) <- struct_construct(v2) + (v15: test::A, v16: @test::A) <- snapshot(v14) + () <- test::use_a(v16) + (v17: core::felt252) <- 1 + (v18: core::felt252) <- core::felt252_add(v2, v17) + (v19: core::RangeCheck, v20: core::gas::GasBuiltin, v21: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[expr19](v4, v5, v18, v15) +End: + Return(v19, v20, v21) + +blk4: +Statements: + (v22: core::array::Array::) <- core::array::array_new::() + (v23: core::felt252) <- 375233589013918064796019 + (v24: core::array::Array::) <- core::array::array_append::(v22, v23) + (v25: core::panics::Panic) <- struct_construct() + (v26: (core::panics::Panic, core::array::Array::)) <- struct_construct(v25, v24) + (v27: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- PanicResult::Err(v26) +End: + Return(v6, v7, v27) diff --git a/crates/cairo-lang-lowering/src/lower/test_data/usage b/crates/cairo-lang-lowering/src/lower/test_data/usage index 35daae3ea2c..b523a8c1892 100644 --- a/crates/cairo-lang-lowering/src/lower/test_data/usage +++ b/crates/cairo-lang-lowering/src/lower/test_data/usage @@ -20,12 +20,14 @@ foo //! > usage Block 2:4: - Usage: ParamId(test::b), - Changes: - Introductions: + Usage: ParamId(test::b), + Changes: + Snapshot_Usage: + Introductions: Block 0:27: - Usage: ParamId(test::a), ParamId(test::b), - Changes: + Usage: ParamId(test::a), ParamId(test::b), + Changes: + Snapshot_Usage: Introductions: //! > ========================================================================== @@ -66,20 +68,24 @@ struct B { //! > usage Block 11:16: - Usage: LocalVarId(test::h), ParamId(test::b), ParamId(test::a)::b, - Changes: LocalVarId(test::h), ParamId(test::b), - Introductions: LocalVarId(test::x), + Usage: LocalVarId(test::h), ParamId(test::b), ParamId(test::a)::b, + Changes: LocalVarId(test::h), ParamId(test::b), + Snapshot_Usage: + Introductions: LocalVarId(test::x), Block 8:9: - Usage: LocalVarId(test::c), ParamId(test::a)::b, ParamId(test::b), - Changes: ParamId(test::a)::b::c, ParamId(test::b), - Introductions: LocalVarId(test::h), + Usage: LocalVarId(test::c), ParamId(test::a)::b, ParamId(test::b), + Changes: ParamId(test::a)::b::c, ParamId(test::b), + Snapshot_Usage: + Introductions: LocalVarId(test::h), Loop 8:4: - Usage: LocalVarId(test::c), ParamId(test::a)::b, ParamId(test::b), - Changes: ParamId(test::a)::b::c, ParamId(test::b), - Introductions: LocalVarId(test::h), + Usage: LocalVarId(test::c), ParamId(test::a)::b, ParamId(test::b), + Changes: ParamId(test::a)::b::c, ParamId(test::b), + Snapshot_Usage: + Introductions: LocalVarId(test::h), Block 6:27: - Usage: ParamId(test::b), ParamId(test::a)::b, - Changes: ParamId(test::a)::b::c, ParamId(test::b), + Usage: ParamId(test::b), ParamId(test::a)::b, + Changes: ParamId(test::a)::b::c, ParamId(test::b), + Snapshot_Usage: Introductions: LocalVarId(test::c), //! > ========================================================================== @@ -113,14 +119,75 @@ struct B { //! > usage Block 9:38: - Usage: LocalVarId(test::c), ParamId(test::a)::b::c, - Changes: ParamId(test::a)::b::c, - Introductions: + Usage: LocalVarId(test::c), ParamId(test::a)::b::c, + Changes: ParamId(test::a)::b::c, + Snapshot_Usage: + Introductions: While 9:4: - Usage: LocalVarId(test::only_used_in_condition), LocalVarId(test::c), ParamId(test::a)::b::c, - Changes: ParamId(test::a)::b::c, - Introductions: + Usage: LocalVarId(test::c), ParamId(test::a)::b::c, + Changes: ParamId(test::a)::b::c, + Snapshot_Usage: LocalVarId(test::only_used_in_condition), + Introductions: Block 6:27: - Usage: ParamId(test::a)::b::c, - Changes: ParamId(test::a)::b::c, + Usage: ParamId(test::a)::b::c, + Changes: ParamId(test::a)::b::c, + Snapshot_Usage: + Introductions: LocalVarId(test::c), LocalVarId(test::only_used_in_condition), + +//! > ========================================================================== + +//! > Test snap usage + +//! > test_runner_name +test_function_usage + +//! > function +fn foo(mut a: A, ref b: B) { + let c = 5_usize; + let only_used_in_condition = 5; + while only_used_in_condition != c { + { + borrow_B(@b); + borrow_B(@a.b); + } + let _ = c; + consume_B(b); + }; +} + +//! > function_name +foo + +//! > module_code +struct A { + b: B +} +struct B { + c: usize, +} +fn consume_B(b: B) {} +fn borrow_B(b: @B) {} + +//! > semantic_diagnostics + +//! > usage +Block 12:8: + Usage: + Changes: + Snapshot_Usage: ParamId(test::b), ParamId(test::a)::b, + Introductions: +Block 11:38: + Usage: LocalVarId(test::c), ParamId(test::b), + Changes: + Snapshot_Usage: ParamId(test::a)::b, + Introductions: +While 11:4: + Usage: LocalVarId(test::c), ParamId(test::b), + Changes: + Snapshot_Usage: LocalVarId(test::only_used_in_condition), ParamId(test::a)::b, + Introductions: +Block 8:27: + Usage: ParamId(test::b), + Changes: + Snapshot_Usage: ParamId(test::a)::b, Introductions: LocalVarId(test::c), LocalVarId(test::only_used_in_condition), diff --git a/crates/cairo-lang-lowering/src/lower/usage.rs b/crates/cairo-lang-lowering/src/lower/usage.rs index d4f76157f5c..187478fd909 100644 --- a/crates/cairo-lang-lowering/src/lower/usage.rs +++ b/crates/cairo-lang-lowering/src/lower/usage.rs @@ -56,6 +56,8 @@ pub struct Usage { pub usage: OrderedHashMap, /// Member paths that are assigned to. pub changes: OrderedHashMap, + /// Member paths that are read as snapshots. + pub snap_usage: OrderedHashMap, /// Variables that are defined. pub introductions: OrderedHashSet, } @@ -69,6 +71,9 @@ impl Usage { for (path, expr) in usage.changes.iter() { self.changes.insert(path.clone(), expr.clone()); } + for (path, expr) in usage.snap_usage.iter() { + self.snap_usage.insert(path.clone(), expr.clone()); + } } /// Removes usage that was introduced current block and usage that is already covered @@ -91,6 +96,30 @@ impl Usage { } } } + for (member_path, _) in self.snap_usage.clone() { + // Prune usages from snap_usage. + if self.usage.contains_key(&member_path) { + self.snap_usage.swap_remove(&member_path); + continue; + } + + // Prune introductions from snap_usage. + if self.introductions.contains(&member_path.base_var()) { + self.snap_usage.swap_remove(&member_path); + } + + // Prune snap_usage that are members of other snap_usage or usages. + let mut current_path = member_path.clone(); + while let MemberPath::Member { parent, .. } = current_path { + current_path = *parent.clone(); + if self.snap_usage.contains_key(¤t_path) + | self.usage.contains_key(¤t_path) + { + self.snap_usage.swap_remove(&member_path); + break; + } + } + } for (member_path, _) in self.changes.clone() { // Prune introductions from changes. if self.introductions.contains(&member_path.base_var()) { @@ -98,9 +127,19 @@ impl Usage { } // Prune changes that are members of other changes. + // Also if a child is changed and its parent is used, then we change the parent. + // TODO(TomerStarkware): Deconstruct the parent, and snap_use other members. let mut current_path = member_path.clone(); while let MemberPath::Member { parent, .. } = current_path { current_path = *parent.clone(); + if self.snap_usage.contains_key(¤t_path) { + // Note that current_path must be top most usage as we prune snap_usage and + // usage. + if let Some(value) = self.snap_usage.swap_remove(¤t_path) { + self.usage.insert(current_path.clone(), value.clone()); + self.changes.insert(current_path.clone(), value); + }; + } if self.changes.contains_key(¤t_path) { self.changes.swap_remove(&member_path); break; @@ -142,7 +181,26 @@ impl BlockUsages { self.handle_expr(function_body, *value, current); } }, - Expr::Snapshot(expr) => self.handle_expr(function_body, expr.inner, current), + Expr::Snapshot(expr) => { + let expr_id = expr.inner; + + match &function_body.exprs[expr_id] { + Expr::Var(expr_var) => { + current.snap_usage.insert( + MemberPath::Var(expr_var.var), + ExprVarMemberPath::Var(expr_var.clone()), + ); + } + Expr::MemberAccess(expr) => { + if let Some(member_path) = &expr.member_path { + current.snap_usage.insert(member_path.into(), member_path.clone()); + } else { + self.handle_expr(function_body, expr.expr, current); + } + } + _ => self.handle_expr(function_body, expr_id, current), + } + } Expr::Desnap(expr) => self.handle_expr(function_body, expr.inner, current), Expr::Assignment(expr) => { self.handle_expr(function_body, expr.rhs, current); diff --git a/crates/cairo-lang-lowering/src/lower/usage_test.rs b/crates/cairo-lang-lowering/src/lower/usage_test.rs index cd8ce6edb7f..bb465cd733c 100644 --- a/crates/cairo-lang-lowering/src/lower/usage_test.rs +++ b/crates/cairo-lang-lowering/src/lower/usage_test.rs @@ -58,20 +58,26 @@ fn test_function_usage( _ => unreachable!(), } writeln!(usages_str, " {}:{}:", position.line, position.col).unwrap(); - write!(usages_str, " Usage: ").unwrap(); + write!(usages_str, " Usage:").unwrap(); for (_, expr) in usage.usage.iter() { - write!(usages_str, "{:?}, ", expr.debug(&expr_formatter)).unwrap(); + write!(usages_str, " {:?},", expr.debug(&expr_formatter)).unwrap(); } writeln!(usages_str).unwrap(); - write!(usages_str, " Changes: ").unwrap(); + write!(usages_str, " Changes:").unwrap(); for (_, expr) in usage.changes.iter() { - write!(usages_str, "{:?}, ", expr.debug(&expr_formatter)).unwrap(); + write!(usages_str, " {:?},", expr.debug(&expr_formatter)).unwrap(); } writeln!(usages_str).unwrap(); - write!(usages_str, " Introductions: ").unwrap(); + write!(usages_str, " Snapshot_Usage:").unwrap(); + for (_, expr) in usage.snap_usage.iter() { + write!(usages_str, " {:?},", expr.debug(&expr_formatter)).unwrap(); + } + writeln!(usages_str).unwrap(); + write!(usages_str, " Introductions:").unwrap(); for var in &usage.introductions { - write!(usages_str, "{:?}, ", var.debug(&expr_formatter)).unwrap(); + write!(usages_str, " {:?},", var.debug(&expr_formatter)).unwrap(); } + writeln!(usages_str).unwrap(); } diff --git a/crates/cairo-lang-lowering/src/test_data/while b/crates/cairo-lang-lowering/src/test_data/while index 2a0031a01f0..bb7aa00b4c9 100644 --- a/crates/cairo-lang-lowering/src/test_data/while +++ b/crates/cairo-lang-lowering/src/test_data/while @@ -25,26 +25,27 @@ Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin, v2: core::felt252 blk0 (root): Statements: (v3: core::felt252) <- 5 - (v4: core::RangeCheck, v5: core::gas::GasBuiltin, v6: core::panics::PanicResult::<(core::felt252, ())>) <- test::foo[expr12](v0, v1, v3, v2) + (v4: core::felt252, v5: @core::felt252) <- snapshot(v2) + (v6: core::RangeCheck, v7: core::gas::GasBuiltin, v8: core::panics::PanicResult::<(core::felt252, ())>) <- test::foo[expr12](v0, v1, v3, v5) End: - Match(match_enum(v6) { - PanicResult::Ok(v7) => blk1, - PanicResult::Err(v8) => blk2, + Match(match_enum(v8) { + PanicResult::Ok(v9) => blk1, + PanicResult::Err(v10) => blk2, }) blk1: Statements: - (v9: core::felt252, v10: ()) <- struct_destructure(v7) - (v11: ((),)) <- struct_construct(v10) - (v12: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v11) + (v11: core::felt252, v12: ()) <- struct_destructure(v9) + (v13: ((),)) <- struct_construct(v12) + (v14: core::panics::PanicResult::<((),)>) <- PanicResult::Ok(v13) End: - Return(v4, v5, v12) + Return(v6, v7, v14) blk2: Statements: - (v13: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v8) + (v15: core::panics::PanicResult::<((),)>) <- PanicResult::Err(v10) End: - Return(v4, v5, v13) + Return(v6, v7, v15) //! > ========================================================================== diff --git a/tests/bug_samples/issue2816.cairo b/tests/bug_samples/issue2816.cairo new file mode 100644 index 00000000000..5e5440303d1 --- /dev/null +++ b/tests/bug_samples/issue2816.cairo @@ -0,0 +1,27 @@ +fn contains, +Drop, +Copy>( + ref self: Array, item: T +) -> bool { + let mut index = 0_usize; + loop { + if index >= self.len() { + break false; + } else if *self[index] == item { + break true; + } else { + index = index + 1_usize; + }; + } +} + +#[test] +fn test_contains() { + let mut arr: Array = array![1, 2, 3, 4]; + assert!(contains(ref arr, 1)); + assert!(contains(ref arr, 2)); + assert!(contains(ref arr, 3)); + assert!(contains(ref arr, 4)); + assert!(!contains(ref arr, 5)); + assert!(!contains(ref arr, 6)); + assert!(!contains(ref arr, 7)); + assert!(!contains(ref arr, 8)); +} diff --git a/tests/bug_samples/lib.cairo b/tests/bug_samples/lib.cairo index 798bcfd5c91..27cf287bf10 100644 --- a/tests/bug_samples/lib.cairo +++ b/tests/bug_samples/lib.cairo @@ -13,6 +13,7 @@ mod issue2480; mod issue2530; mod issue2567; mod issue2612; +mod issue2816; mod issue2819; mod issue2820; mod issue2932; diff --git a/tests/test_data/fib_loop.casm b/tests/test_data/fib_loop.casm index dbc64e0de37..28bb4c7db12 100644 --- a/tests/test_data/fib_loop.casm +++ b/tests/test_data/fib_loop.casm @@ -1,16 +1,16 @@ -[ap + 0] = [fp + -3], ap++; [ap + 0] = [fp + -5], ap++; +[ap + 0] = [fp + -3], ap++; [ap + 0] = [fp + -4], ap++; call rel 3; ret; -jmp rel 7 if [fp + -5] != 0; -[ap + 0] = [fp + -5], ap++; -[ap + 0] = [fp + -3], ap++; -[ap + 0] = [fp + -4], ap++; +jmp rel 7 if [fp + -4] != 0; [ap + 0] = [fp + -4], ap++; +[ap + 0] = [fp + -3], ap++; +[ap + 0] = [fp + -5], ap++; +[ap + 0] = [fp + -5], ap++; ret; -[fp + -5] = [ap + 0] + 1, ap++; [ap + 0] = [fp + -3], ap++; -[ap + 0] = [fp + -4] + [fp + -3], ap++; +[fp + -4] = [ap + 0] + 1, ap++; +[ap + 0] = [fp + -5] + [fp + -3], ap++; call rel -11; ret; diff --git a/tests/test_data/fib_loop.sierra b/tests/test_data/fib_loop.sierra index 1e0da787391..44fd88a9b1a 100644 --- a/tests/test_data/fib_loop.sierra +++ b/tests/test_data/fib_loop.sierra @@ -15,34 +15,34 @@ libfunc felt252_sub = felt252_sub; libfunc felt252_add = felt252_add; disable_ap_tracking() -> (); // 0 -store_temp([2]) -> ([2]); // 1 -store_temp([0]) -> ([0]); // 2 +store_temp([0]) -> ([0]); // 1 +store_temp([2]) -> ([2]); // 2 store_temp([1]) -> ([1]); // 3 -function_call([2], [0], [1]) -> ([3], [4], [5], [6]); // 4 +function_call([0], [2], [1]) -> ([3], [4], [5], [6]); // 4 drop([3]) -> (); // 5 drop([4]) -> (); // 6 drop([5]) -> (); // 7 return([6]); // 8 disable_ap_tracking() -> (); // 9 -dup([0]) -> ([0], [3]); // 10 +dup([1]) -> ([1], [3]); // 10 felt252_is_zero([3]) { fallthrough() 19([4]) }; // 11 branch_align() -> (); // 12 -store_temp([0]) -> ([0]); // 13 +store_temp([1]) -> ([1]); // 13 store_temp([2]) -> ([2]); // 14 -dup([1]) -> ([1], [5]); // 15 +dup([0]) -> ([0], [5]); // 15 store_temp([5]) -> ([5]); // 16 -store_temp([1]) -> ([1]); // 17 -return([0], [2], [5], [1]); // 18 +store_temp([0]) -> ([0]); // 17 +return([1], [2], [5], [0]); // 18 branch_align() -> (); // 19 drop>([4]) -> (); // 20 const_as_immediate>() -> ([6]); // 21 -felt252_sub([0], [6]) -> ([7]); // 22 +felt252_sub([1], [6]) -> ([7]); // 22 dup([2]) -> ([2], [8]); // 23 -felt252_add([1], [8]) -> ([9]); // 24 -store_temp([7]) -> ([7]); // 25 -store_temp([2]) -> ([2]); // 26 +felt252_add([0], [8]) -> ([9]); // 24 +store_temp([2]) -> ([2]); // 25 +store_temp([7]) -> ([7]); // 26 store_temp([9]) -> ([9]); // 27 -function_call([7], [2], [9]) -> ([10], [11], [12], [13]); // 28 +function_call([2], [7], [9]) -> ([10], [11], [12], [13]); // 28 return([10], [11], [12], [13]); // 29 examples::fib_loop::fib@0([0]: felt252, [1]: felt252, [2]: felt252) -> (felt252);