From bbb47e25a8215851a94164ba3f1492874522c0d5 Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Thu, 12 Oct 2023 12:57:03 +0300 Subject: [PATCH] Made ap-change be calculated without LP. commit-id:b1c84e1f --- .../src/compute.rs | 379 ++++++++++++++++++ crates/cairo-lang-sierra-ap-change/src/lib.rs | 2 + .../cairo-lang-sierra-to-casm/src/metadata.rs | 12 +- .../src/casm_contract_class.rs | 1 + crates/cairo-lang-test-runner/src/lib.rs | 6 +- tests/e2e_test.rs | 12 +- tests/e2e_test_data/metadata_computation | 30 ++ 7 files changed, 436 insertions(+), 6 deletions(-) create mode 100644 crates/cairo-lang-sierra-ap-change/src/compute.rs diff --git a/crates/cairo-lang-sierra-ap-change/src/compute.rs b/crates/cairo-lang-sierra-ap-change/src/compute.rs new file mode 100644 index 00000000000..497d44f52d1 --- /dev/null +++ b/crates/cairo-lang-sierra-ap-change/src/compute.rs @@ -0,0 +1,379 @@ +use std::collections::hash_map::Entry; + +use cairo_lang_sierra::algorithm::topological_order::get_topological_ordering; +use cairo_lang_sierra::extensions::core::{CoreLibfunc, CoreType}; +use cairo_lang_sierra::extensions::gas::CostTokenType; +use cairo_lang_sierra::ids::{ConcreteTypeId, FunctionId}; +use cairo_lang_sierra::program::{Program, Statement, StatementIdx}; +use cairo_lang_sierra::program_registry::ProgramRegistry; +use cairo_lang_sierra_type_size::{get_type_size_map, TypeSizeMap}; +use cairo_lang_utils::casts::IntoOrPanic; +use cairo_lang_utils::ordered_hash_map::OrderedHashMap; +use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; + +use crate::ap_change_info::ApChangeInfo; +use crate::core_libfunc_ap_change::{self, InvocationApChangeInfoProvider}; +use crate::{ApChange, ApChangeError}; + +/// Helper to implement the `InvocationApChangeInfoProvider` for the equation generation. +struct InvocationApChangeInfoProviderForEqGen<'a, TokenUsages: Fn(CostTokenType) -> usize> { + /// Registry for providing the sizes of the types. + type_sizes: &'a TypeSizeMap, + /// Closure providing the token usages for the invocation. + token_usages: TokenUsages, +} + +impl<'a, TokenUsages: Fn(CostTokenType) -> usize> InvocationApChangeInfoProvider + for InvocationApChangeInfoProviderForEqGen<'a, TokenUsages> +{ + fn type_size(&self, ty: &ConcreteTypeId) -> usize { + self.type_sizes[ty].into_or_panic() + } + + fn token_usages(&self, token_type: CostTokenType) -> usize { + (self.token_usages)(token_type) + } +} + +/// A base to start ap tracking from. +#[derive(Clone, Debug)] +enum ApTrackingBase { + FunctionStart(FunctionId), + EnableStatement(StatementIdx), +} + +/// The information for ap tracking of a statement. +#[derive(Clone, Debug)] +struct ApTrackingInfo { + /// The base tracking from. + base: ApTrackingBase, + /// The ap-change from the base. + ap_change: usize, +} + +/// Helper for calculating the ap-changes of a program. +struct ApChangeCalcHelper<'a, TokenUsages: Fn(StatementIdx, CostTokenType) -> usize> { + /// The program. + program: &'a Program, + /// The program registry. + registry: ProgramRegistry, + /// Registry for providing the sizes of the types. + type_sizes: TypeSizeMap, + /// Closure providing the token usages for the invocation. + token_usages: TokenUsages, + /// The size of allocated locals until the statement. + locals_size: UnorderedHashMap, + /// The lower bound of a ap-change to the furthest return per statement. + known_ap_change_to_return: UnorderedHashMap, + /// The ap_change of functions with known ap changes. + function_ap_change: OrderedHashMap, + /// The ap tracking information per statement. + tracking_info: UnorderedHashMap, + /// The effective ap change from the statement's base. + effective_ap_change_from_base: UnorderedHashMap, + /// The variables for ap alignment. + variable_values: OrderedHashMap, +} +impl<'a, TokenUsages: Fn(StatementIdx, CostTokenType) -> usize> + ApChangeCalcHelper<'a, TokenUsages> +{ + /// Creates a new helper. + fn new(program: &'a Program, token_usages: TokenUsages) -> Result { + let registry = ProgramRegistry::::new(program)?; + let type_sizes = get_type_size_map(program, ®istry).unwrap(); + Ok(Self { + program, + registry, + type_sizes, + token_usages, + locals_size: Default::default(), + known_ap_change_to_return: Default::default(), + function_ap_change: Default::default(), + tracking_info: Default::default(), + effective_ap_change_from_base: Default::default(), + variable_values: Default::default(), + }) + } + + /// Calculates the locals size and function ap changes. + fn calc_locals_and_function_ap_changes(&mut self) -> Result<(), ApChangeError> { + let ordering = self.known_ap_change_topological_order()?; + for idx in ordering.iter().rev() { + self.calc_locals_for_statement(*idx)?; + } + for idx in ordering { + self.calc_known_ap_change_for_statement(idx)?; + } + self.function_ap_change = self + .program + .funcs + .iter() + .filter_map(|f| { + self.known_ap_change_to_return + .get(&f.entry_point) + .cloned() + .map(|ap_change| (f.id.clone(), ap_change)) + }) + .collect(); + Ok(()) + } + + /// Calculates the locals size for a statement. + fn calc_locals_for_statement(&mut self, idx: StatementIdx) -> Result<(), ApChangeError> { + for (ap_change, target) in self.get_branches(idx)? { + match ap_change { + ApChange::AtLocalsFinalization(x) => { + self.locals_size.insert(target, self.get_statement_locals(idx) + x); + } + ApChange::Unknown | ApChange::FinalizeLocals => {} + ApChange::FromMetadata + | ApChange::FunctionCall(_) + | ApChange::EnableApTracking + | ApChange::Known(_) + | ApChange::DisableApTracking => { + if let Some(locals) = self.locals_size.get(&idx) { + self.locals_size.insert(target, *locals); + } + } + } + } + Ok(()) + } + + /// Calculates the lower bound of a ap-change to the furthest return per statement. + /// If it is unknown does not set it. + fn calc_known_ap_change_for_statement( + &mut self, + idx: StatementIdx, + ) -> Result<(), ApChangeError> { + let mut max_change = 0; + for (ap_change, target) in self.get_branches(idx)? { + let Some(target_ap_change) = self.known_ap_change_to_return.get(&target) else { + return Ok(()); + }; + if let Some(ap_change) = self.branch_ap_change(idx, &ap_change, |id| { + self.known_ap_change_to_return.get(&self.func_entry_point(id).ok()?).cloned() + }) { + max_change = max_change.max(target_ap_change + ap_change); + } else { + return Ok(()); + }; + } + self.known_ap_change_to_return.insert(idx, max_change); + Ok(()) + } + + /// Returns the topological ordering of the program statements for fully known ap-changes. + fn known_ap_change_topological_order(&self) -> Result, ApChangeError> { + get_topological_ordering( + false, + (0..self.program.statements.len()).map(StatementIdx), + self.program.statements.len(), + |idx| { + let mut res = vec![]; + for (ap_change, target) in self.get_branches(idx)? { + res.push(target); + if let ApChange::FunctionCall(id) = ap_change { + res.push(self.func_entry_point(&id)?); + } + } + Ok(res) + }, + ApChangeError::StatementOutOfBounds, + |_| unreachable!("Cycle isn't an error."), + ) + } + + /// Returns the topological ordering of the program statements where tracked ap changes give the + /// ordering. + fn tracked_ap_change_topological_order(&self) -> Result, ApChangeError> { + get_topological_ordering( + false, + (0..self.program.statements.len()).map(StatementIdx), + self.program.statements.len(), + |idx| { + Ok(self + .get_branches(idx)? + .into_iter() + .flat_map(|(ap_change, target)| match ap_change { + ApChange::Unknown => None, + ApChange::FunctionCall(id) => { + if self.function_ap_change.contains_key(&id) { + Some(target) + } else { + None + } + } + ApChange::Known(_) + | ApChange::DisableApTracking + | ApChange::FromMetadata + | ApChange::AtLocalsFinalization(_) + | ApChange::FinalizeLocals + | ApChange::EnableApTracking => Some(target), + }) + .collect()) + }, + ApChangeError::StatementOutOfBounds, + |_| unreachable!("Cycle isn't an error."), + ) + } + + /// Calculates the tracking information for a statement. + fn calc_tracking_info_for_statement(&mut self, idx: StatementIdx) -> Result<(), ApChangeError> { + for (ap_change, target) in self.get_branches(idx)? { + if matches!(ap_change, ApChange::EnableApTracking) { + self.tracking_info.insert( + target, + ApTrackingInfo { base: ApTrackingBase::EnableStatement(idx), ap_change: 0 }, + ); + continue; + } + let Some(mut base_info) = self.tracking_info.get(&idx).cloned() else { + continue; + }; + if let Some(ap_change) = self + .branch_ap_change(idx, &ap_change, |id| self.function_ap_change.get(id).cloned()) + { + base_info.ap_change += ap_change; + } else { + continue; + } + match self.tracking_info.entry(target) { + Entry::Occupied(mut e) => { + e.get_mut().ap_change = e.get().ap_change.max(base_info.ap_change); + } + Entry::Vacant(e) => { + e.insert(base_info); + } + } + } + Ok(()) + } + + /// Calculates the effective ap change for a statement, and the variables for ap alignment. + fn calc_effective_ap_change_and_variables_per_statement( + &mut self, + idx: StatementIdx, + ) -> Result<(), ApChangeError> { + let Some(base_info) = self.tracking_info.get(&idx).cloned() else { + return Ok(()); + }; + if matches!(self.program.get_statement(&idx), Some(Statement::Return(_))) { + if let ApTrackingBase::FunctionStart(id) = base_info.base { + if let Some(func_change) = self.function_ap_change.get(&id) { + self.effective_ap_change_from_base.insert(idx, *func_change); + } + } + return Ok(()); + } + let mut source_ap_change = None; + let mut paths_ap_change = vec![]; + for (ap_change, target) in self.get_branches(idx)? { + if matches!(ap_change, ApChange::EnableApTracking) { + continue; + } + let Some(change) = self + .branch_ap_change(idx, &ap_change, |id| self.function_ap_change.get(id).cloned()) + else { + source_ap_change = Some(base_info.ap_change); + continue; + }; + let Some(target_ap_change) = self.effective_ap_change_from_base.get(&target) else { + continue; + }; + let calc_ap_change = target_ap_change - change; + paths_ap_change.push((target, calc_ap_change)); + if let Some(source_ap_change) = &mut source_ap_change { + *source_ap_change = (*source_ap_change).min(calc_ap_change); + } else { + source_ap_change = Some(calc_ap_change); + } + } + if let Some(source_ap_change) = source_ap_change { + self.effective_ap_change_from_base.insert(idx, source_ap_change); + for (target, path_ap_change) in paths_ap_change { + self.variable_values.insert(target, path_ap_change - source_ap_change); + } + } + Ok(()) + } + + /// Gets the actual ap-change of a branch. + fn branch_ap_change( + &self, + idx: StatementIdx, + ap_change: &ApChange, + func_ap_change: impl Fn(&FunctionId) -> Option, + ) -> Option { + match ap_change { + ApChange::Unknown | ApChange::DisableApTracking => None, + ApChange::Known(x) => Some(*x), + ApChange::FromMetadata + | ApChange::AtLocalsFinalization(_) + | ApChange::EnableApTracking => Some(0), + ApChange::FinalizeLocals => Some(self.get_statement_locals(idx)), + ApChange::FunctionCall(id) => func_ap_change(id).map(|x| 2 + x), + } + } + + /// Returns the locals size for a statement. + fn get_statement_locals(&self, idx: StatementIdx) -> usize { + self.locals_size.get(&idx).cloned().unwrap_or_default() + } + + /// Returns the branches of a statement. + fn get_branches( + &self, + idx: StatementIdx, + ) -> Result, ApChangeError> { + Ok(match self.program.get_statement(&idx).unwrap() { + Statement::Invocation(invocation) => { + let libfunc = self.registry.get_libfunc(&invocation.libfunc_id)?; + core_libfunc_ap_change::core_libfunc_ap_change( + libfunc, + &InvocationApChangeInfoProviderForEqGen { + type_sizes: &self.type_sizes, + token_usages: |token_type| (self.token_usages)(idx, token_type), + }, + ) + .into_iter() + .zip(&invocation.branches) + .map(|(ap_change, branch_info)| (ap_change, idx.next(&branch_info.target))) + .collect() + } + Statement::Return(_) => vec![], + }) + } + + /// Returns the entry point of a function. + fn func_entry_point(&self, id: &FunctionId) -> Result { + Ok(self.registry.get_function(id)?.entry_point) + } +} + +/// Calculates ap change information for a given program. +pub fn calc_ap_changes usize>( + program: &Program, + token_usages: TokenUsages, +) -> Result { + let mut helper = ApChangeCalcHelper::new(program, token_usages)?; + helper.calc_locals_and_function_ap_changes()?; + let ap_tracked_topological_ordering = helper.tracked_ap_change_topological_order()?; + // Seting tracking info for function entry points. + for f in &program.funcs { + helper.tracking_info.insert( + f.entry_point, + ApTrackingInfo { base: ApTrackingBase::FunctionStart(f.id.clone()), ap_change: 0 }, + ); + } + for idx in ap_tracked_topological_ordering.iter().rev() { + helper.calc_tracking_info_for_statement(*idx)?; + } + for idx in ap_tracked_topological_ordering { + helper.calc_effective_ap_change_and_variables_per_statement(idx)?; + } + Ok(ApChangeInfo { + variable_values: helper.variable_values, + function_ap_change: helper.function_ap_change, + }) +} diff --git a/crates/cairo-lang-sierra-ap-change/src/lib.rs b/crates/cairo-lang-sierra-ap-change/src/lib.rs index 1a0cc2b0722..19007d5e1dc 100644 --- a/crates/cairo-lang-sierra-ap-change/src/lib.rs +++ b/crates/cairo-lang-sierra-ap-change/src/lib.rs @@ -14,6 +14,8 @@ use itertools::Itertools; use thiserror::Error; pub mod ap_change_info; +/// Direct linear computation of AP-Changes instead of equation solver. +pub mod compute; pub mod core_libfunc_ap_change; mod generate_equations; diff --git a/crates/cairo-lang-sierra-to-casm/src/metadata.rs b/crates/cairo-lang-sierra-to-casm/src/metadata.rs index eaa1605d02b..35f3be5f9b7 100644 --- a/crates/cairo-lang-sierra-to-casm/src/metadata.rs +++ b/crates/cairo-lang-sierra-to-casm/src/metadata.rs @@ -2,6 +2,7 @@ use cairo_lang_sierra::extensions::gas::CostTokenType; use cairo_lang_sierra::ids::FunctionId; use cairo_lang_sierra::program::Program; use cairo_lang_sierra_ap_change::ap_change_info::ApChangeInfo; +use cairo_lang_sierra_ap_change::compute::calc_ap_changes as linear_calc_ap_changes; use cairo_lang_sierra_ap_change::{calc_ap_changes, ApChangeError}; use cairo_lang_sierra_gas::gas_info::GasInfo; use cairo_lang_sierra_gas::{ @@ -37,6 +38,9 @@ pub struct MetadataComputationConfig { /// If true, uses a linear-time algorithm for calculating the gas, instead of solving /// equations. pub linear_gas_solver: bool, + /// If true, uses a linear-time algorithm for calculating ap changes, instead of solving + /// equations. + pub linear_ap_change_solver: bool, } /// Calculates the metadata for a Sierra program, with ap change info only. @@ -72,9 +76,11 @@ pub fn calc_metadata( pre_gas_info.assert_eq_variables(&pre_gas_info2); pre_gas_info.assert_eq_functions(&pre_gas_info2); - let ap_change_info = calc_ap_changes(program, |idx, token_type| { - pre_gas_info.variable_values[(idx, token_type)] as usize - })?; + let ap_change_info = + if config.linear_ap_change_solver { linear_calc_ap_changes } else { calc_ap_changes }( + program, + |idx, token_type| pre_gas_info.variable_values[(idx, token_type)] as usize, + )?; let post_function_set_costs = config .function_set_costs diff --git a/crates/cairo-lang-starknet/src/casm_contract_class.rs b/crates/cairo-lang-starknet/src/casm_contract_class.rs index 9c486be2aee..fbdf3152b4a 100644 --- a/crates/cairo-lang-starknet/src/casm_contract_class.rs +++ b/crates/cairo-lang-starknet/src/casm_contract_class.rs @@ -284,6 +284,7 @@ impl CasmContractClass { .map(|id| (id, [(CostTokenType::Const, ENTRY_POINT_COST)].into())) .collect(), linear_gas_solver: false, + linear_ap_change_solver: false, }; let metadata = calc_metadata(&program, metadata_computation_config)?; diff --git a/crates/cairo-lang-test-runner/src/lib.rs b/crates/cairo-lang-test-runner/src/lib.rs index e6837daaa4e..af84deeaeb9 100644 --- a/crates/cairo-lang-test-runner/src/lib.rs +++ b/crates/cairo-lang-test-runner/src/lib.rs @@ -262,7 +262,11 @@ pub fn run_tests( ) -> Result { let runner = SierraCasmRunner::new( sierra_program, - Some(MetadataComputationConfig { function_set_costs, linear_gas_solver: true }), + Some(MetadataComputationConfig { + function_set_costs, + linear_gas_solver: true, + linear_ap_change_solver: true, + }), contracts_info, ) .with_context(|| "Failed setting up runner.")?; diff --git a/tests/e2e_test.rs b/tests/e2e_test.rs index fe4cdb8a614..6b126f5f554 100644 --- a/tests/e2e_test.rs +++ b/tests/e2e_test.rs @@ -188,8 +188,11 @@ fn run_e2e_test( }; // Compute the metadata. - let mut metadata_config = - MetadataComputationConfig { function_set_costs: enforced_costs, linear_gas_solver: false }; + let mut metadata_config = MetadataComputationConfig { + function_set_costs: enforced_costs, + linear_gas_solver: false, + linear_ap_change_solver: false, + }; let metadata = calc_metadata(&sierra_program, metadata_config.clone()).unwrap(); // Compile to casm. @@ -201,10 +204,15 @@ fn run_e2e_test( OrderedHashMap::from([("casm".into(), casm), ("sierra_code".into(), sierra_program_str)]); if params.metadata_computation { metadata_config.linear_gas_solver = true; + metadata_config.linear_ap_change_solver = true; let metadata_no_solver = calc_metadata(&sierra_program, metadata_config).unwrap(); res.insert("gas_solution".into(), format!("{}", metadata.gas_info)); res.insert("gas_solution_no_solver".into(), format!("{}", metadata_no_solver.gas_info)); res.insert("ap_solution".into(), format!("{}", metadata.ap_change_info)); + res.insert( + "ap_solution_no_solver".into(), + format!("{}", metadata_no_solver.ap_change_info), + ); // Compile again, this time with the no-solver metadata. cairo_lang_sierra_to_casm::compiler::compile(&sierra_program, &metadata_no_solver, true) diff --git a/tests/e2e_test_data/metadata_computation b/tests/e2e_test_data/metadata_computation index 3f59b992a19..38417ff5b45 100644 --- a/tests/e2e_test_data/metadata_computation +++ b/tests/e2e_test_data/metadata_computation @@ -67,6 +67,13 @@ test::bar: OrderedHashMap({Const: 10000}) test::bar: 1 test::foo: 13 +//! > ap_solution_no_solver +#2: 3 +#8: 3 + +test::foo: 13 +test::bar: 1 + //! > sierra_code type felt252 = felt252 [storable: true, drop: true, dup: true, zero_sized: false]; type Unit = Struct [storable: true, drop: true, dup: true, zero_sized: true]; @@ -269,6 +276,14 @@ test::bar: OrderedHashMap({Const: 10000}) test::bar: 1 test::foo: 12 +//! > ap_solution_no_solver +#1: 1 +#13: 4 +#19: 2 + +test::foo: 12 +test::bar: 1 + //! > sierra_code type RangeCheck = RangeCheck [storable: true, drop: false, dup: false, zero_sized: false]; type Unit = Struct [storable: true, drop: true, dup: true, zero_sized: true]; @@ -462,6 +477,14 @@ test::bar: OrderedHashMap({Const: 2000}) test::bar: 1 test::foo: 10 +//! > ap_solution_no_solver +#12: 4 +#18: 6 +#24: 4 + +test::foo: 10 +test::bar: 1 + //! > sierra_code type felt252 = felt252 [storable: true, drop: true, dup: true, zero_sized: false]; type Unit = Struct [storable: true, drop: true, dup: true, zero_sized: true]; @@ -635,6 +658,13 @@ test::bar: OrderedHashMap({Const: 10000}) test::bar: 1 test::foo: 9 +//! > ap_solution_no_solver +#15: 2 +#29: 4 + +test::foo: 9 +test::bar: 1 + //! > sierra_code type felt252 = felt252 [storable: true, drop: true, dup: true, zero_sized: false]; type Unit = Struct [storable: true, drop: true, dup: true, zero_sized: true];