Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement use_sites manually to get more control. #4961

Merged
merged 1 commit into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 20 additions & 63 deletions crates/cairo-lang-lowering/src/optimizations/cancel_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,9 @@ use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use itertools::{izip, zip_eq, Itertools};

use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
use crate::borrow_check::demand::DemandReporter;
use crate::borrow_check::Demand;
use crate::utils::{Rebuilder, RebuilderEx};
use crate::{BlockId, FlatLowered, MatchInfo, Statement, VarRemapping, VarUsage, VariableId};

pub type CancelOpsDemand = Demand<VariableId, StatementLocation, ()>;

/// Demand reporter for cancel ops.
/// use sites are reported using `dup` and `last_use`.
impl DemandReporter<VariableId> for CancelOpsContext<'_> {
type IntroducePosition = ();
type UsePosition = StatementLocation;

fn dup(
&mut self,
position: StatementLocation,
var: VariableId,
_next_usage_position: StatementLocation,
) {
self.use_sites.entry(var).or_default().push(position);
}

fn last_use(&mut self, position: StatementLocation, var: VariableId) {
self.use_sites.entry(var).or_default().push(position);
}
}

/// Cancels out a (StructConstruct, StructDestructure) and (Snap, Desnap) pair.
///
///
Expand Down Expand Up @@ -100,10 +76,6 @@ fn get_use_sites<'a>(
}
}

#[derive(Clone)]
pub struct AnalysisInfo {
demand: CancelOpsDemand,
}
impl<'a> CancelOpsContext<'a> {
fn rename_var(&mut self, from: VariableId, to: VariableId) {
assert!(
Expand All @@ -118,6 +90,10 @@ impl<'a> CancelOpsContext<'a> {
}
}

fn add_use_site(&mut self, var: VariableId, use_site: StatementLocation) {
self.use_sites.entry(var).or_default().push(use_site);
}

/// Handles a statement and returns true if it can be removed.
fn handle_stmt(&mut self, stmt: &'a Statement, statement_location: StatementLocation) -> bool {
match stmt {
Expand Down Expand Up @@ -260,71 +236,52 @@ impl<'a> CancelOpsContext<'a> {
}

impl<'a> Analyzer<'a> for CancelOpsContext<'a> {
type Info = AnalysisInfo;
type Info = ();

fn visit_stmt(
&mut self,
info: &mut Self::Info,
_info: &mut Self::Info,
statement_location: StatementLocation,
stmt: &'a Statement,
) {
if !self.handle_stmt(stmt, statement_location) {
info.demand.variables_introduced(self, &stmt.outputs(), ());
info.demand.variables_used(
self,
stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, statement_location)),
);
for input in stmt.inputs() {
self.add_use_site(input.var_id, statement_location);
}
}
}

fn visit_goto(
&mut self,
info: &mut Self::Info,
_info: &mut Self::Info,
statement_location: StatementLocation,
_target_block_id: BlockId,
remapping: &VarRemapping,
) {
info.demand.apply_remapping(
self,
remapping.iter().map(|(dst, src)| (dst, (&src.var_id, statement_location))),
);
for src in remapping.values() {
self.add_use_site(src.var_id, statement_location);
}
}

fn merge_match(
&mut self,
statement_location: StatementLocation,
match_info: &'a MatchInfo,
infos: &[Self::Info],
_infos: &[Self::Info],
) -> Self::Info {
let arm_demands = zip_eq(match_info.arms(), infos)
.map(|(arm, info)| {
let mut demand = info.demand.clone();
demand.variables_introduced(self, &arm.var_ids, ());

(demand, ())
})
.collect_vec();
let mut demand = CancelOpsDemand::merge_demands(&arm_demands, self);

demand.variables_used(
self,
match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, statement_location)),
);

Self::Info { demand }
for var in match_info.inputs() {
self.add_use_site(var.var_id, statement_location);
}
}

fn info_from_return(
&mut self,
statement_location: StatementLocation,
vars: &[VarUsage],
) -> Self::Info {
let mut demand = CancelOpsDemand::default();
demand.variables_used(
self,
vars.iter().map(|VarUsage { var_id, .. }| (var_id, statement_location)),
);
Self::Info { demand }
for var in vars {
self.add_use_site(var.var_id, statement_location);
}
}
}

Expand Down
85 changes: 85 additions & 0 deletions crates/cairo-lang-lowering/src/optimizations/test_data/cancel_ops
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,88 @@ Statements:
(v2: test::MyStruct) <- struct_construct(v1)
End:
Return(v2)

//! > ==========================================================================

//! > destracture remapped to snapshot.

//! > test_runner_name
test_cancel_ops

//! > function
fn foo(a: (u32,), b: felt252) -> u32 {
let d = @if b == 0 {
let (c, ) = a;
c
} else {
let (c, ) = a;
c
};

*d
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > before
Parameters: v0: (core::integer::u32,), v1: core::felt252
blk0 (root):
Statements:
End:
Match(match core::felt252_is_zero(v1) {
IsZeroResult::Zero => blk1,
IsZeroResult::NonZero(v3) => blk2,
})

blk1:
Statements:
(v2: core::integer::u32) <- struct_destructure(v0)
End:
Goto(blk3, {v2 -> v5})

blk2:
Statements:
(v4: core::integer::u32) <- struct_destructure(v0)
End:
Goto(blk3, {v4 -> v5})

blk3:
Statements:
(v6: core::integer::u32, v7: @core::integer::u32) <- snapshot(v5)
(v8: core::integer::u32) <- desnap(v7)
End:
Return(v8)

//! > after
Parameters: v0: (core::integer::u32,), v1: core::felt252
blk0 (root):
Statements:
End:
Match(match core::felt252_is_zero(v1) {
IsZeroResult::Zero => blk1,
IsZeroResult::NonZero(v3) => blk2,
})

blk1:
Statements:
(v2: core::integer::u32) <- struct_destructure(v0)
End:
Goto(blk3, {v2 -> v5})

blk2:
Statements:
(v4: core::integer::u32) <- struct_destructure(v0)
End:
Goto(blk3, {v4 -> v5})

blk3:
Statements:
End:
Return(v5)
Loading