Skip to content

Commit

Permalink
Refactor topo-order alg out of sierra-gas.
Browse files Browse the repository at this point in the history
In preparation for ap-change solve.

commit-id:eb28e7c1
  • Loading branch information
orizi committed Oct 12, 2023
1 parent da581bf commit d8357e1
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 104 deletions.
31 changes: 13 additions & 18 deletions crates/cairo-lang-sierra-gas/src/compute_costs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::hash_map;
use std::ops::{Add, Sub};

use cairo_lang_sierra::algorithm::topological_order::get_topological_ordering;
use cairo_lang_sierra::extensions::gas::{BuiltinCostWithdrawGasLibfunc, CostTokenType};
use cairo_lang_sierra::ids::ConcreteLibfuncId;
use cairo_lang_sierra::program::{BranchInfo, Invocation, Program, Statement, StatementIdx};
Expand All @@ -13,7 +14,6 @@ use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use itertools::zip_eq;

use crate::gas_info::GasInfo;
use crate::generate_equations::{calculate_reverse_topological_ordering, TopologicalOrderStatus};
use crate::objects::{BranchCost, ConstCost, PreCost};
use crate::CostError;

Expand Down Expand Up @@ -402,7 +402,7 @@ impl<'a, CostType: CostTypeTrait> CostContext<'a, CostType> {
specific_cost_context: &SpecificCostContext,
) -> Result<(), CostError> {
let topological_order =
compute_topological_order(self.program.statements.len(), true, &|current_idx| {
compute_topological_order(self.program.statements.len(), true, |current_idx| {
match &self.program.get_statement(current_idx).unwrap() {
Statement::Return(_) => {
// Return has no dependencies.
Expand Down Expand Up @@ -471,8 +471,8 @@ impl<'a, CostType: CostTypeTrait> CostContext<'a, CostType> {
//
// Note, that we allow cycles, but the result may not be optimal in such a case.
let topological_order =
compute_topological_order(self.program.statements.len(), false, &|current_idx| {
match &self.program.get_statement(current_idx).unwrap() {
compute_topological_order(self.program.statements.len(), false, |current_idx| {
match self.program.get_statement(current_idx).unwrap() {
Statement::Return(_) => {
// Return has no dependencies.
vec![]
Expand Down Expand Up @@ -616,21 +616,16 @@ impl<'a, CostType: CostTypeTrait> CostContext<'a, CostType> {
fn compute_topological_order(
n_statements: usize,
detect_cycles: bool,
dependencies_callback: &dyn Fn(&StatementIdx) -> Vec<StatementIdx>,
dependencies_callback: impl Fn(&StatementIdx) -> Vec<StatementIdx>,
) -> Result<Vec<StatementIdx>, CostError> {
let mut topological_order: Vec<StatementIdx> = Default::default();
let mut status = vec![TopologicalOrderStatus::NotStarted; n_statements];
for idx in 0..n_statements {
calculate_reverse_topological_ordering(
&mut topological_order,
&mut status,
&StatementIdx(idx),
detect_cycles,
dependencies_callback,
)?;
}

Ok(topological_order)
get_topological_ordering(
detect_cycles,
(0..n_statements).map(StatementIdx),
n_statements,
|idx| Ok(dependencies_callback(&idx)),
CostError::StatementOutOfBounds,
|_| CostError::UnexpectedCycle,
)
}

pub struct PreCostContext {}
Expand Down
98 changes: 12 additions & 86 deletions crates/cairo-lang-sierra-gas/src/generate_equations.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use cairo_lang_sierra::algorithm::topological_order::get_topological_ordering;
use cairo_lang_sierra::extensions::gas::CostTokenType;
use cairo_lang_sierra::ids::ConcreteLibfuncId;
use cairo_lang_sierra::program::{Program, StatementIdx};
Expand Down Expand Up @@ -99,27 +100,14 @@ impl StatementFutureCost for EquationGenerator {
}
}

#[derive(Clone, Debug)]
pub enum TopologicalOrderStatus {
/// The computation for that statement did not start.
NotStarted,
/// The computation is in progress.
InProgress,
/// The computation was completed, and all the children were visited.
Done,
}

/// Returns the reverse topological ordering of the program statements.
fn get_reverse_topological_ordering(program: &Program) -> Result<Vec<StatementIdx>, CostError> {
let mut ordering = vec![];
let mut status = vec![TopologicalOrderStatus::NotStarted; program.statements.len()];
for f in &program.funcs {
calculate_reverse_topological_ordering(
&mut ordering,
&mut status,
&f.entry_point,
false,
|idx| match program.get_statement(idx).unwrap() {
get_topological_ordering(
false,
program.funcs.iter().map(|f| f.entry_point),
program.statements.len(),
|idx| {
Ok(match program.get_statement(&idx).unwrap() {
cairo_lang_sierra::program::Statement::Invocation(invocation) => invocation
.branches
.iter()
Expand All @@ -129,71 +117,9 @@ fn get_reverse_topological_ordering(program: &Program) -> Result<Vec<StatementId
cairo_lang_sierra::program::Statement::Return(_) => {
vec![]
}
},
)?;
}
Ok(ordering)
}

/// Recursively calculates the topological ordering of the program.
pub fn calculate_reverse_topological_ordering(
ordering: &mut Vec<StatementIdx>,
status: &mut [TopologicalOrderStatus],
idx0: &StatementIdx,
detect_cycles: bool,
children_callback: impl Fn(&StatementIdx) -> Vec<StatementIdx>,
) -> Result<(), CostError> {
// A stack of statements to visit.
// When the pair is popped out of the stack, `is_done=true` means that we've already visited
// all of its children, and we just need to add it to the ordering.
let mut stack = vec![*idx0];

while let Some(idx) = stack.pop() {
match status.get(idx.0) {
Some(TopologicalOrderStatus::NotStarted) => {
// Mark the statement as `InProgress`.
status[idx.0] = TopologicalOrderStatus::InProgress;

// Push the statement back to the stack, so that after visiting all
// of its children, we would add it to the ordering.
// Add the missing children on top of it.
stack.push(idx);
for child in children_callback(&idx) {
match status.get(child.0) {
Some(TopologicalOrderStatus::InProgress) => {
if detect_cycles {
return Err(CostError::UnexpectedCycle);
}
continue;
}
Some(TopologicalOrderStatus::Done) => {
continue;
}
Some(TopologicalOrderStatus::NotStarted) => {
stack.push(child);
}
None => {
return Err(CostError::StatementOutOfBounds(child));
}
}
}
}
Some(TopologicalOrderStatus::InProgress) => {
// Mark the statement as `Done`.
status[idx.0] = TopologicalOrderStatus::Done;

// Add the element to the ordering after visiting all its children.
// This gives us reverse topological ordering.
ordering.push(idx);
}
Some(TopologicalOrderStatus::Done) => {
// Do nothing.
}
None => {
return Err(CostError::StatementOutOfBounds(idx));
}
}
}

Ok(())
})
},
CostError::StatementOutOfBounds,
|_| unreachable!("Cycles are not detected."),
)
}
1 change: 1 addition & 0 deletions crates/cairo-lang-sierra/src/algorithm/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod topological_order;
110 changes: 110 additions & 0 deletions crates/cairo-lang-sierra/src/algorithm/topological_order.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
use crate::program::StatementIdx;

/// The status of a node during the topological ordering finding algorithm.
#[derive(Clone, Debug)]
enum TopologicalOrderStatus {
/// The computation for that statement did not start.
NotStarted,
/// The computation is in progress.
InProgress,
/// The computation was completed, and all the children were visited.
Done,
}

/// Returns the topological ordering.
/// `detect_cycles` - if true, the function will return an error if a cycle is detected.
/// `roots` - the roots of the graph.
/// `node_count` - the number of nodes in the graph.
/// `get_children` - a function that returns the children of a node.
/// `out_of_bounds_err` - a function that returns an error for a node is out of bounds.
/// `cycle_err` - a function that returns an error for a node that is part of a cycle.
/// Note: Will only work properly if the nodes are in the range [0, node_count).
pub fn get_topological_ordering<E>(
detect_cycles: bool,
roots: impl Iterator<Item = StatementIdx>,
node_count: usize,
get_children: impl Fn(StatementIdx) -> Result<Vec<StatementIdx>, E>,
out_of_bounds_err: impl Fn(StatementIdx) -> E,
cycle_err: impl Fn(StatementIdx) -> E,
) -> Result<Vec<StatementIdx>, E> {
let mut ordering = vec![];
let mut status = vec![TopologicalOrderStatus::NotStarted; node_count];
for root in roots {
calculate_topological_ordering(
detect_cycles,
&mut ordering,
&mut status,
root,
&get_children,
&out_of_bounds_err,
&cycle_err,
)?;
}
Ok(ordering)
}

/// Calculates the topological ordering starting from `root`. For more info see
/// `get_topological_ordering`.
fn calculate_topological_ordering<E>(
detect_cycles: bool,
ordering: &mut Vec<StatementIdx>,
status: &mut [TopologicalOrderStatus],
root: StatementIdx,
get_children: &impl Fn(StatementIdx) -> Result<Vec<StatementIdx>, E>,
out_of_bounds_err: &impl Fn(StatementIdx) -> E,
cycle_err: &impl Fn(StatementIdx) -> E,
) -> Result<(), E> {
// A stack of statements to visit.
// When the pair is popped out of the stack, `is_done=true` means that we've already visited
// all of its children, and we just need to add it to the ordering.
let mut stack = vec![root];

while let Some(idx) = stack.pop() {
match status.get(idx.0) {
Some(TopologicalOrderStatus::NotStarted) => {
// Mark the statement as `InProgress`.
status[idx.0] = TopologicalOrderStatus::InProgress;

// Push the statement back to the stack, so that after visiting all
// of its children, we would add it to the ordering.
// Add the missing children on top of it.
stack.push(idx);
for child in get_children(idx)? {
match status.get(child.0) {
Some(TopologicalOrderStatus::InProgress) => {
if detect_cycles {
return Err(cycle_err(child));
}
continue;
}
Some(TopologicalOrderStatus::Done) => {
continue;
}
Some(TopologicalOrderStatus::NotStarted) => {
stack.push(child);
}
None => {
return Err(out_of_bounds_err(child));
}
}
}
}
Some(TopologicalOrderStatus::InProgress) => {
// Mark the statement as `Done`.
status[idx.0] = TopologicalOrderStatus::Done;

// Add the element to the ordering after visiting all its children.
// This gives us reverse topological ordering.
ordering.push(idx);
}
Some(TopologicalOrderStatus::Done) => {
// Do nothing.
}
None => {
return Err(out_of_bounds_err(idx));
}
}
}

Ok(())
}
1 change: 1 addition & 0 deletions crates/cairo-lang-sierra/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

use lalrpop_util::lalrpop_mod;

pub mod algorithm;
pub mod debug_info;
pub mod edit_state;
pub mod extensions;
Expand Down

0 comments on commit d8357e1

Please sign in to comment.