Skip to content

Commit

Permalink
Implement use_sites manually to get more control. (#4961)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware authored Jan 31, 2024
1 parent c2e3829 commit df58ef2
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 63 deletions.
83 changes: 20 additions & 63 deletions crates/cairo-lang-lowering/src/optimizations/cancel_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,9 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use itertools::{izip, zip_eq, Itertools};

use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
use crate::borrow_check::demand::DemandReporter;
use crate::borrow_check::Demand;
use crate::utils::{Rebuilder, RebuilderEx};
use crate::{BlockId, FlatLowered, MatchInfo, Statement, VarRemapping, VarUsage, VariableId};

pub type CancelOpsDemand = Demand<VariableId, StatementLocation, ()>;

/// Demand reporter for cancel ops.
/// use sites are reported using `dup` and `last_use`.
impl DemandReporter<VariableId> for CancelOpsContext<'_> {
type IntroducePosition = ();
type UsePosition = StatementLocation;

fn dup(
&mut self,
position: StatementLocation,
var: VariableId,
_next_usage_position: StatementLocation,
) {
self.use_sites.entry(var).or_default().push(position);
}

fn last_use(&mut self, position: StatementLocation, var: VariableId) {
self.use_sites.entry(var).or_default().push(position);
}
}

/// Cancels out a (StructConstruct, StructDestructure) and (Snap, Desnap) pair.
///
///
Expand Down Expand Up @@ -100,10 +76,6 @@ fn get_use_sites<'a>(
}
}

#[derive(Clone)]
pub struct AnalysisInfo {
demand: CancelOpsDemand,
}
impl<'a> CancelOpsContext<'a> {
fn rename_var(&mut self, from: VariableId, to: VariableId) {
assert!(
Expand All @@ -118,6 +90,10 @@ impl<'a> CancelOpsContext<'a> {
}
}

fn add_use_site(&mut self, var: VariableId, use_site: StatementLocation) {
self.use_sites.entry(var).or_default().push(use_site);
}

/// Handles a statement and returns true if it can be removed.
fn handle_stmt(&mut self, stmt: &'a Statement, statement_location: StatementLocation) -> bool {
match stmt {
Expand Down Expand Up @@ -260,71 +236,52 @@ impl<'a> CancelOpsContext<'a> {
}

impl<'a> Analyzer<'a> for CancelOpsContext<'a> {
type Info = AnalysisInfo;
type Info = ();

fn visit_stmt(
&mut self,
info: &mut Self::Info,
_info: &mut Self::Info,
statement_location: StatementLocation,
stmt: &'a Statement,
) {
if !self.handle_stmt(stmt, statement_location) {
info.demand.variables_introduced(self, &stmt.outputs(), ());
info.demand.variables_used(
self,
stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, statement_location)),
);
for input in stmt.inputs() {
self.add_use_site(input.var_id, statement_location);
}
}
}

fn visit_goto(
&mut self,
info: &mut Self::Info,
_info: &mut Self::Info,
statement_location: StatementLocation,
_target_block_id: BlockId,
remapping: &VarRemapping,
) {
info.demand.apply_remapping(
self,
remapping.iter().map(|(dst, src)| (dst, (&src.var_id, statement_location))),
);
for src in remapping.values() {
self.add_use_site(src.var_id, statement_location);
}
}

fn merge_match(
&mut self,
statement_location: StatementLocation,
match_info: &'a MatchInfo,
infos: &[Self::Info],
_infos: &[Self::Info],
) -> Self::Info {
let arm_demands = zip_eq(match_info.arms(), infos)
.map(|(arm, info)| {
let mut demand = info.demand.clone();
demand.variables_introduced(self, &arm.var_ids, ());

(demand, ())
})
.collect_vec();
let mut demand = CancelOpsDemand::merge_demands(&arm_demands, self);

demand.variables_used(
self,
match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, statement_location)),
);

Self::Info { demand }
for var in match_info.inputs() {
self.add_use_site(var.var_id, statement_location);
}
}

fn info_from_return(
&mut self,
statement_location: StatementLocation,
vars: &[VarUsage],
) -> Self::Info {
let mut demand = CancelOpsDemand::default();
demand.variables_used(
self,
vars.iter().map(|VarUsage { var_id, .. }| (var_id, statement_location)),
);
Self::Info { demand }
for var in vars {
self.add_use_site(var.var_id, statement_location);
}
}
}

Expand Down
85 changes: 85 additions & 0 deletions crates/cairo-lang-lowering/src/optimizations/test_data/cancel_ops
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,88 @@ Statements:
(v2: test::MyStruct) <- struct_construct(v1)
End:
Return(v2)

//! > ==========================================================================

//! > destracture remapped to snapshot.

//! > test_runner_name
test_cancel_ops

//! > function
fn foo(a: (u32,), b: felt252) -> u32 {
let d = @if b == 0 {
let (c, ) = a;
c
} else {
let (c, ) = a;
c
};

*d
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > before
Parameters: v0: (core::integer::u32,), v1: core::felt252
blk0 (root):
Statements:
End:
Match(match core::felt252_is_zero(v1) {
IsZeroResult::Zero => blk1,
IsZeroResult::NonZero(v3) => blk2,
})

blk1:
Statements:
(v2: core::integer::u32) <- struct_destructure(v0)
End:
Goto(blk3, {v2 -> v5})

blk2:
Statements:
(v4: core::integer::u32) <- struct_destructure(v0)
End:
Goto(blk3, {v4 -> v5})

blk3:
Statements:
(v6: core::integer::u32, v7: @core::integer::u32) <- snapshot(v5)
(v8: core::integer::u32) <- desnap(v7)
End:
Return(v8)

//! > after
Parameters: v0: (core::integer::u32,), v1: core::felt252
blk0 (root):
Statements:
End:
Match(match core::felt252_is_zero(v1) {
IsZeroResult::Zero => blk1,
IsZeroResult::NonZero(v3) => blk2,
})

blk1:
Statements:
(v2: core::integer::u32) <- struct_destructure(v0)
End:
Goto(blk3, {v2 -> v5})

blk2:
Statements:
(v4: core::integer::u32) <- struct_destructure(v0)
End:
Goto(blk3, {v4 -> v5})

blk3:
Statements:
End:
Return(v5)

0 comments on commit df58ef2

Please sign in to comment.