Skip to content

Commit

Permalink
fixed loop capture of snapshoted variables (#5934)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomerStarkware authored Jul 14, 2024
1 parent 0372232 commit 5474cf3
Show file tree
Hide file tree
Showing 13 changed files with 971 additions and 86 deletions.
31 changes: 31 additions & 0 deletions corelib/src/test/language_features/while_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,34 @@ fn test_outer_loop_break() {
};
assert_eq!(i, 10);
}

#[test]
fn test_borrow_usage() {
let mut i = 0;
let arr = array![1, 2, 3, 4];
while i != arr.len() {
i += 1;
};
assert_eq!(arr.len(), 4);
}

#[derive(Drop)]
struct NonCopy {
x: felt252,
}

fn assert_x_eq(a: @NonCopy, x: felt252) {
assert_eq!(a.x, @x);
}

#[test]
fn test_borrow_with_inner_change() {
let mut a = NonCopy { x: 0 };
let mut i = 0;
while i != 5 {
a.x = i;
assert_x_eq(@a, i);
i += 1;
};
}

49 changes: 49 additions & 0 deletions crates/cairo-lang-lowering/src/lower/block_builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use cairo_lang_defs::ids::{MemberId, NamedLanguageElementId};
use cairo_lang_diagnostics::Maybe;
use cairo_lang_semantic as semantic;
use cairo_lang_semantic::types::{peel_snapshots, wrap_in_snapshots};
use cairo_lang_syntax::node::TypedStablePtr;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
Expand All @@ -26,6 +27,8 @@ use crate::{
pub struct BlockBuilder {
/// A store for semantic variables, owning their OwnedVariable instances.
pub semantics: SemanticLoweringMapping,
/// The semantic variables that are captured as snapshots in this block.
pub snapped_semantics: OrderedHashMap<MemberPath, VariableId>,
/// The semantic variables that are added/changed in this block.
changed_member_paths: OrderedHashSet<MemberPath>,
/// Current sequence of lowered statements emitted.
Expand All @@ -38,6 +41,7 @@ impl BlockBuilder {
pub fn root(_ctx: &mut LoweringContext<'_, '_>, block_id: BlockId) -> Self {
BlockBuilder {
semantics: Default::default(),
snapped_semantics: Default::default(),
changed_member_paths: Default::default(),
statements: Default::default(),
block_id,
Expand All @@ -48,6 +52,7 @@ impl BlockBuilder {
pub fn child_block_builder(&self, block_id: BlockId) -> BlockBuilder {
BlockBuilder {
semantics: self.semantics.clone(),
snapped_semantics: self.snapped_semantics.clone(),
changed_member_paths: Default::default(),
statements: Default::default(),
block_id,
Expand All @@ -59,6 +64,7 @@ impl BlockBuilder {
pub fn sibling_block_builder(&self, block_id: BlockId) -> BlockBuilder {
BlockBuilder {
semantics: self.semantics.clone(),
snapped_semantics: self.snapped_semantics.clone(),
changed_member_paths: self.changed_member_paths.clone(),
statements: Default::default(),
block_id,
Expand Down Expand Up @@ -119,6 +125,49 @@ impl BlockBuilder {
.map(|var_id| VarUsage { var_id, location })
}

/// Updates the reference of a semantic variable to a snapshot of its lowered variable.
pub fn update_snap_ref(&mut self, member_path: &ExprVarMemberPath, var: VariableId) {
self.snapped_semantics.insert(member_path.into(), var);
}

/// Gets the reference of a snapshot of semantic variable, possibly by deconstructing a
/// its parents.
pub fn get_snap_ref(
&mut self,
ctx: &mut LoweringContext<'_, '_>,
member_path: &ExprVarMemberPath,
) -> Option<VarUsage> {
let location = ctx.get_location(member_path.stable_ptr().untyped());
if let Some(var_id) = self.snapped_semantics.get::<MemberPath>(&member_path.into()) {
return Some(VarUsage { var_id: *var_id, location });
}
let ExprVarMemberPath::Member { parent, member_id, concrete_struct_id, .. } = member_path
else {
return None;
};
// TODO(TomerStarkware): Consider adding the result to snap_semantics to avoid
// recomputation.
let parent_var = self.get_snap_ref(ctx, parent)?;
let members = ctx.db.concrete_struct_members(*concrete_struct_id).ok()?;
let (parent_number_of_snapshots, _) =
peel_snapshots(ctx.db.upcast(), ctx.variables[parent_var.var_id].ty);
let member_idx = members.iter().position(|(_, member)| member.id == *member_id)?;
Some(
generators::StructMemberAccess {
input: parent_var,
member_tys: members
.into_iter()
.map(|(_, member)| {
wrap_in_snapshots(ctx.db.upcast(), member.ty, parent_number_of_snapshots)
})
.collect(),
member_idx,
location,
}
.add(ctx, &mut self.statements),
)
}

/// Gets the type of a semantic variable.
pub fn get_ty(
&mut self,
Expand Down
15 changes: 2 additions & 13 deletions crates/cairo-lang-lowering/src/lower/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ use cairo_lang_utils::Intern;
use defs::diagnostic_utils::StableLocation;
use id_arena::Arena;
use itertools::{zip_eq, Itertools};
use semantic::corelib::{core_module, get_ty_by_name, get_usize_ty};
use semantic::corelib::{core_module, get_ty_by_name};
use semantic::expr::inference::InferenceError;
use semantic::items::constant::value_as_const_value;
use semantic::types::wrap_in_snapshots;
use semantic::{ExprVarMemberPath, MatchArmSelector, TypeLongId};
use {cairo_lang_defs as defs, cairo_lang_semantic as semantic};
Expand Down Expand Up @@ -300,17 +299,7 @@ impl LoweredExpr {
LoweredExpr::Snapshot { expr, .. } => {
wrap_in_snapshots(ctx.db.upcast(), expr.ty(ctx), 1)
}
LoweredExpr::FixedSizeArray { exprs, .. } => semantic::TypeLongId::FixedSizeArray {
type_id: exprs[0].ty(ctx),
size: value_as_const_value(
ctx.db.upcast(),
get_usize_ty(ctx.db.upcast()),
&exprs.len().into(),
)
.unwrap()
.intern(ctx.db),
}
.intern(ctx.db),
LoweredExpr::FixedSizeArray { ty, .. } => *ty,
}
}
pub fn location(&self) -> LocationId {
Expand Down
101 changes: 89 additions & 12 deletions crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::vec;
use block_builder::BlockBuilder;
use cairo_lang_debug::DebugWithDb;
use cairo_lang_diagnostics::{Diagnostics, Maybe};
use cairo_lang_semantic::corelib::{self, unwrap_error_propagation_type, ErrorPropagationType};
use cairo_lang_semantic::corelib::{unwrap_error_propagation_type, ErrorPropagationType};
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::{LocalVariable, VarId};
use cairo_lang_semantic::{corelib, ExprVar, LocalVariable, VarId};
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_syntax::node::TypedStablePtr;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
Expand All @@ -27,6 +27,7 @@ use semantic::{
ExprFunctionCallArg, ExprId, ExprPropagateError, ExprVarMemberPath, GenericArgumentId,
MatchArmSelector, SemanticDiagnostic, TypeLongId,
};
use usage::MemberPath;
use {cairo_lang_defs as defs, cairo_lang_semantic as semantic};

use self::block_builder::SealedBlockBuilder;
Expand Down Expand Up @@ -413,6 +414,7 @@ pub fn lower_loop_function(
function_id: FunctionWithBodyId,
loop_signature: Signature,
loop_expr_id: semantic::ExprId,
snapped_params: &OrderedHashMap<MemberPath, semantic::ExprVarMemberPath>,
) -> Maybe<FlatLowered> {
let mut ctx = LoweringContext::new(encapsulating_ctx, function_id, loop_signature.clone())?;
let old_loop_expr_id = std::mem::replace(&mut ctx.current_loop_expr_id, Some(loop_expr_id));
Expand All @@ -429,7 +431,11 @@ pub fn lower_loop_function(
.map(|param| {
let location = ctx.get_location(param.stable_ptr().untyped());
let var = ctx.new_var(VarRequest { ty: param.ty(), location });
builder.semantics.introduce((&param).into(), var);
if snapped_params.contains_key::<MemberPath>(&(&param).into()) {
builder.update_snap_ref(&param, var)
} else {
builder.semantics.introduce((&param).into(), var);
}
var
})
.collect_vec();
Expand Down Expand Up @@ -1116,8 +1122,31 @@ fn lower_expr_snapshot(
builder: &mut BlockBuilder,
) -> LoweringResult<LoweredExpr> {
log::trace!("Lowering a snapshot: {:?}", expr.debug(&ctx.expr_formatter));
// If the inner expression is a variable, or a member access, and we already have a snapshot var
// we can use it without creating a new one.
// Note that in a closure we might only have a snapshot of the variable and not the original.
match &ctx.function_body.exprs[expr.inner] {
semantic::Expr::Var(expr_var) => {
let member_path = ExprVarMemberPath::Var(expr_var.clone());
if let Some(var) = builder.get_snap_ref(ctx, &member_path) {
return Ok(LoweredExpr::AtVariable(var));
}
}
semantic::Expr::MemberAccess(expr) => {
if let Some(var) = expr
.member_path
.clone()
.and_then(|member_path| builder.get_snap_ref(ctx, &member_path))
{
return Ok(LoweredExpr::AtVariable(var));
}
}
_ => {}
}
let lowered = lower_expr(ctx, builder, expr.inner)?;

let location = ctx.get_location(expr.stable_ptr.untyped());
let expr = Box::new(lower_expr(ctx, builder, expr.inner)?);
let expr = Box::new(lowered);
Ok(LoweredExpr::Snapshot { expr, location })
}

Expand Down Expand Up @@ -1348,7 +1377,26 @@ fn lower_expr_loop(
let usage = &ctx.block_usages.block_usages[&loop_expr_id];

// Determine signature.
let params = usage.usage.iter().map(|(_, expr)| expr.clone()).collect_vec();
let params = usage
.usage
.iter()
.map(|(_, expr)| expr.clone())
.chain(usage.snap_usage.iter().map(|(_, expr)| match expr {
ExprVarMemberPath::Var(var) => ExprVarMemberPath::Var(ExprVar {
ty: wrap_in_snapshots(ctx.db.upcast(), var.ty, 1),
..*var
}),
ExprVarMemberPath::Member { parent, member_id, stable_ptr, concrete_struct_id, ty } => {
ExprVarMemberPath::Member {
parent: parent.clone(),
member_id: *member_id,
stable_ptr: *stable_ptr,
concrete_struct_id: *concrete_struct_id,
ty: wrap_in_snapshots(ctx.db.upcast(), *ty, 1),
}
}
}))
.collect_vec();
let extra_rets = usage.changes.iter().map(|(_, expr)| expr.clone()).collect_vec();

let loop_signature = Signature {
Expand All @@ -1367,17 +1415,37 @@ fn lower_expr_loop(
}
.intern(ctx.db);

let snap_usage = ctx.block_usages.block_usages[&loop_expr_id].snap_usage.clone();

// Generate the function.
let encapsulating_ctx = std::mem::take(&mut ctx.encapsulating_ctx).unwrap();
let lowered =
lower_loop_function(encapsulating_ctx, function, loop_signature.clone(), loop_expr_id)
.map_err(LoweringFlowError::Failed)?;
let lowered = lower_loop_function(
encapsulating_ctx,
function,
loop_signature.clone(),
loop_expr_id,
&snap_usage,
)
.map_err(LoweringFlowError::Failed)?;
// TODO(spapini): Recursive call.
encapsulating_ctx.lowerings.insert(loop_expr_id, lowered);

ctx.encapsulating_ctx = Some(encapsulating_ctx);
let old_loop_expr_id = std::mem::replace(&mut ctx.current_loop_expr_id, Some(loop_expr_id));
for snapshot_param in snap_usage.values() {
// if we have access to the real member we generate a snapshot, otherwise it should be
// accessible with `builder.get_snap_ref`
if let Some(input) = builder.get_ref(ctx, snapshot_param) {
let (original, snapped) = generators::Snapshot {
input,
location: ctx.get_location(snapshot_param.stable_ptr().untyped()),
}
.add(ctx, &mut builder.statements);
builder.update_snap_ref(snapshot_param, snapped);
builder.update_ref(ctx, snapshot_param, original);
}
}
let call = call_loop_func(ctx, loop_signature, builder, loop_expr_id, stable_ptr.untyped());

ctx.current_loop_expr_id = old_loop_expr_id;
call
}
Expand All @@ -1402,9 +1470,18 @@ fn call_loop_func(
.params
.into_iter()
.map(|param| {
builder.get_ref(ctx, &param).ok_or_else(|| {
LoweringFlowError::Failed(ctx.diagnostics.report(stable_ptr, MemberPathLoop))
})
builder
.get_ref(ctx, &param)
.and_then(|var| (ctx.variables[var.var_id].ty == param.ty()).then_some(var))
.or_else(|| {
let var = builder.get_snap_ref(ctx, &param)?;
(ctx.variables[var.var_id].ty == param.ty()).then_some(var)
})
.ok_or_else(|| {
// TODO(TomerStaskware): make sure this is unreachable and remove
// `MemberPathLoop` diagnostic.
LoweringFlowError::Failed(ctx.diagnostics.report(stable_ptr, MemberPathLoop))
})
})
.collect::<LoweringResult<Vec<_>>>()?;
let extra_ret_tys = loop_signature.extra_rets.iter().map(|path| path.ty()).collect_vec();
Expand Down
Loading

0 comments on commit 5474cf3

Please sign in to comment.