Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

feat(fee): calculates messages size field #1290

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions crates/blockifier/src/execution/call_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use starknet_api::state::StorageKey;
use starknet_api::transaction::{EventContent, L2ToL1Payload};

use crate::execution::entry_point::CallEntryPoint;
use crate::fee::gas_usage::get_message_segment_length;
use crate::state::cached_state::StorageEntry;
use crate::transaction::errors::TransactionExecutionError;
use crate::transaction::objects::TransactionExecutionResult;
Expand All @@ -28,6 +29,29 @@ pub struct OrderedEvent {
pub event: EventContent,
}

#[derive(Debug, Default, Eq, PartialEq)]
pub struct MessageL1CostInfo {
pub l2_to_l1_payload_lengths: Vec<usize>,
pub message_segment_length: usize,
}

impl MessageL1CostInfo {
pub fn calculate<'a>(
call_infos: impl Iterator<Item = &'a CallInfo>,
l1_handler_payload_size: Option<usize>,
) -> TransactionExecutionResult<Self> {
let mut l2_to_l1_payload_lengths = Vec::new();
for call_info in call_infos {
l2_to_l1_payload_lengths.extend(call_info.get_sorted_l2_to_l1_payload_lengths()?);
}

let message_segment_length =
get_message_segment_length(&l2_to_l1_payload_lengths, l1_handler_payload_size);

Ok(Self { l2_to_l1_payload_lengths, message_segment_length })
}
}

#[cfg_attr(test, derive(Clone))]
#[derive(Debug, Default, Eq, PartialEq)]
pub struct MessageToL1 {
Expand Down Expand Up @@ -99,9 +123,9 @@ impl CallInfo {

/// Returns a list of Starknet L2ToL1Payload length collected during the execution, sorted
/// by the order in which they were sent.
pub fn get_sorted_l2_to_l1_payloads_length(&self) -> TransactionExecutionResult<Vec<usize>> {
pub fn get_sorted_l2_to_l1_payload_lengths(&self) -> TransactionExecutionResult<Vec<usize>> {
let n_messages = self.into_iter().map(|call| call.execution.l2_to_l1_messages.len()).sum();
let mut starknet_l2_to_l1_payloads_length: Vec<Option<usize>> = vec![None; n_messages];
let mut starknet_l2_to_l1_payload_lengths: Vec<Option<usize>> = vec![None; n_messages];

for call_info in self.into_iter() {
for ordered_message_content in &call_info.execution.l2_to_l1_messages {
Expand All @@ -113,12 +137,12 @@ impl CallInfo {
max_order: n_messages,
});
}
starknet_l2_to_l1_payloads_length[message_order] =
starknet_l2_to_l1_payload_lengths[message_order] =
Some(ordered_message_content.message.payload.0.len());
}
}

starknet_l2_to_l1_payloads_length.into_iter().enumerate().try_fold(
starknet_l2_to_l1_payload_lengths.into_iter().enumerate().try_fold(
Vec::new(),
|mut acc, (i, option)| match option {
Some(value) => {
Expand Down
42 changes: 11 additions & 31 deletions crates/blockifier/src/fee/gas_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use starknet_api::transaction::Fee;
use super::fee_utils::{calculate_tx_l1_gas_usages, get_fee_by_l1_gas_usage};
use crate::abi::constants;
use crate::block_context::BlockContext;
use crate::execution::call_info::CallInfo;
use crate::execution::call_info::{CallInfo, MessageL1CostInfo};
use crate::fee::eth_gas_constants;
use crate::fee::os_resources::OS_RESOURCES;
use crate::state::cached_state::StateChangesCount;
Expand All @@ -19,23 +19,6 @@ use crate::transaction::objects::{
#[path = "gas_usage_test.rs"]
pub mod test;

// TODO(Ayelet, 10/1/2024): Use to calculate message segment length in transaction_executer's
// execute
fn calculate_l2_to_l1_payloads_length_and_message_segment_length<'a>(
call_infos: impl Iterator<Item = &'a CallInfo>,
l1_handler_payload_size: Option<usize>,
) -> TransactionExecutionResult<(Vec<usize>, usize)> {
let mut l2_to_l1_payloads_length = Vec::new();
for call_info in call_infos {
l2_to_l1_payloads_length.extend(call_info.get_sorted_l2_to_l1_payloads_length()?);
}

let message_segment_length =
get_message_segment_length(&l2_to_l1_payloads_length, l1_handler_payload_size);

Ok((l2_to_l1_payloads_length, message_segment_length))
}

pub fn calculate_tx_gas_and_blob_gas_usage<'a>(
call_infos: impl Iterator<Item = &'a CallInfo>,
state_changes_count: StateChangesCount,
Expand Down Expand Up @@ -75,29 +58,26 @@ pub fn calculate_tx_gas_usage_messages<'a>(
call_infos: impl Iterator<Item = &'a CallInfo>,
l1_handler_payload_size: Option<usize>,
) -> TransactionExecutionResult<usize> {
let (l2_to_l1_payloads_length, residual_message_segment_length) =
calculate_l2_to_l1_payloads_length_and_message_segment_length(
call_infos,
l1_handler_payload_size,
)?;
let MessageL1CostInfo { l2_to_l1_payload_lengths, message_segment_length } =
MessageL1CostInfo::calculate(call_infos, l1_handler_payload_size)?;

let n_l2_to_l1_messages = l2_to_l1_payloads_length.len();
let n_l2_to_l1_messages = l2_to_l1_payload_lengths.len();
let n_l1_to_l2_messages = usize::from(l1_handler_payload_size.is_some());

let starknet_gas_usage =
// Starknet's updateState gets the message segment as an argument.
residual_message_segment_length * eth_gas_constants::GAS_PER_MEMORY_WORD
message_segment_length * eth_gas_constants::GAS_PER_MEMORY_WORD
// Starknet's updateState increases a (storage) counter for each L2-to-L1 message.
+ n_l2_to_l1_messages * eth_gas_constants::GAS_PER_ZERO_TO_NONZERO_STORAGE_SET
// Starknet's updateState decreases a (storage) counter for each L1-to-L2 consumed message.
// (Note that we will probably get a refund of 15,000 gas for each consumed message but we
// ignore it since refunded gas cannot be used for the current transaction execution).
+ n_l1_to_l2_messages * eth_gas_constants::GAS_PER_COUNTER_DECREASE
+ get_consumed_message_to_l2_emissions_cost(l1_handler_payload_size)
+ get_log_message_to_l1_emissions_cost(&l2_to_l1_payloads_length);
+ get_log_message_to_l1_emissions_cost(&l2_to_l1_payload_lengths);

let sharp_gas_usage_without_data =
residual_message_segment_length * eth_gas_constants::SHARP_GAS_PER_MEMORY_WORD;
message_segment_length * eth_gas_constants::SHARP_GAS_PER_MEMORY_WORD;

Ok(starknet_gas_usage + sharp_gas_usage_without_data)
}
Expand Down Expand Up @@ -152,12 +132,12 @@ pub fn get_onchain_data_cost(state_changes_count: StateChangesCount) -> usize {
/// a transaction with the given parameters to a batch. Note that constant cells - such as the one
/// that holds the segment size - are not counted.
pub fn get_message_segment_length(
l2_to_l1_payloads_length: &[usize],
l2_to_l1_payload_lengths: &[usize],
l1_handler_payload_size: Option<usize>,
) -> usize {
// Add L2-to-L1 message segment length; for each message, the OS outputs the following:
// to_address, from_address, payload_size, payload.
let mut message_segment_length = l2_to_l1_payloads_length
let mut message_segment_length = l2_to_l1_payload_lengths
.iter()
.map(|payload_length| constants::L2_TO_L1_MSG_HEADER_SIZE + payload_length)
.sum();
Expand Down Expand Up @@ -189,8 +169,8 @@ pub fn get_consumed_message_to_l2_emissions_cost(l1_handler_payload_size: Option
}

/// Returns the cost of LogMessageToL1 event emissions caused by the given messages payload length.
pub fn get_log_message_to_l1_emissions_cost(l2_to_l1_payloads_length: &[usize]) -> usize {
l2_to_l1_payloads_length
pub fn get_log_message_to_l1_emissions_cost(l2_to_l1_payload_lengths: &[usize]) -> usize {
l2_to_l1_payload_lengths
.iter()
.map(|length| {
get_event_emission_cost(
Expand Down
12 changes: 6 additions & 6 deletions crates/blockifier/src/fee/gas_usage_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ fn test_calculate_tx_gas_usage_basic() {
call_infos.push(call_info);
}

// l2_to_l1_payloads_length is [0, 1, 2, 3]
// l2_to_l1_payload_lengths is [0, 1, 2, 3]
let call_infos_iter = call_infos.iter();
let l2_to_l1_payloads_length: Vec<usize> = call_infos_iter
let l2_to_l1_payload_lengths: Vec<usize> = call_infos_iter
.clone()
.flat_map(|call_info| call_info.get_sorted_l2_to_l1_payloads_length().unwrap())
.flat_map(|call_info| call_info.get_sorted_l2_to_l1_payload_lengths().unwrap())
.collect();

let l2_to_l1_state_changes_count = StateChangesCount {
Expand All @@ -170,11 +170,11 @@ fn test_calculate_tx_gas_usage_basic() {
l2_to_l1_messages_gas_and_blob_gas_usage;

// Manual calculation.
let message_segment_length = get_message_segment_length(&l2_to_l1_payloads_length, None);
let n_l2_to_l1_messages = l2_to_l1_payloads_length.len();
let message_segment_length = get_message_segment_length(&l2_to_l1_payload_lengths, None);
let n_l2_to_l1_messages = l2_to_l1_payload_lengths.len();
let manual_starknet_gas_usage = message_segment_length * eth_gas_constants::GAS_PER_MEMORY_WORD
+ n_l2_to_l1_messages * eth_gas_constants::GAS_PER_ZERO_TO_NONZERO_STORAGE_SET
+ get_log_message_to_l1_emissions_cost(&l2_to_l1_payloads_length);
+ get_log_message_to_l1_emissions_cost(&l2_to_l1_payload_lengths);
let manual_sharp_gas_usage = message_segment_length
* eth_gas_constants::SHARP_GAS_PER_MEMORY_WORD
+ get_onchain_data_cost(l2_to_l1_state_changes_count);
Expand Down
3 changes: 1 addition & 2 deletions crates/blockifier/src/transaction/transaction_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ impl<S: StateReader> ExecutableTransaction<S> for L1HandlerTransaction {
let mut remaining_gas = Transaction::initial_gas();
let execute_call_info =
self.run_execute(state, &mut execution_resources, &mut context, &mut remaining_gas)?;
// The calldata includes the "from" field, which is not a part of the payload.
let l1_handler_payload_size = self.tx.calldata.0.len() - 1;
let l1_handler_payload_size = self.payload_size();

let ActualCost { actual_fee, actual_resources } =
ActualCost::builder_for_l1_handler(block_context, tx_context, l1_handler_payload_size)
Expand Down
5 changes: 5 additions & 0 deletions crates/blockifier/src/transaction/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,11 @@ impl L1HandlerTransaction {
max_fee: Fee::default(),
})
}

pub fn payload_size(&self) -> usize {
// The calldata includes the "from" field, which is not a part of the payload.
self.tx.calldata.0.len() - 1
}
}

impl HasRelatedFeeType for L1HandlerTransaction {
Expand Down
21 changes: 18 additions & 3 deletions crates/native_blockifier/src/transaction_executor.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::collections::{HashMap, HashSet};
use std::vec::IntoIter;

use blockifier::block_context::BlockContext;
use blockifier::block_execution::pre_process_block;
use blockifier::execution::call_info::CallInfo;
use blockifier::execution::call_info::{CallInfo, MessageL1CostInfo};
use blockifier::execution::entry_point::ExecutionResources;
use blockifier::fee::actual_cost::ActualCost;
use blockifier::state::cached_state::{
Expand Down Expand Up @@ -74,7 +75,12 @@ impl<S: StateReader> TransactionExecutor<S> {
charge_fee: bool,
) -> NativeBlockifierResult<(PyTransactionExecutionInfo, PyBouncerInfo)> {
let tx: Transaction = py_tx(tx, raw_contract_class)?;

let l1_handler_payload_size: usize =
if let Transaction::L1HandlerTransaction(l1_handler_tx) = &tx {
l1_handler_tx.payload_size()
} else {
0
};
let mut tx_executed_class_hashes = HashSet::<ClassHash>::new();
let mut tx_visited_storage_entries = HashSet::<StorageEntry>::new();
let mut transactional_state = CachedState::create_transactional(&mut self.state);
Expand All @@ -88,9 +94,18 @@ impl<S: StateReader> TransactionExecutor<S> {
// TODO(Elin, 01/06/2024): consider traversing the calls to collect data once.
tx_executed_class_hashes.extend(tx_execution_info.get_executed_class_hashes());
tx_visited_storage_entries.extend(tx_execution_info.get_visited_storage_entries());
let call_infos: IntoIter<&CallInfo> =
[&tx_execution_info.validate_call_info, &tx_execution_info.execute_call_info]
.iter()
.filter_map(|&call_info| call_info.as_ref())
.collect::<Vec<&CallInfo>>()
.into_iter();
let MessageL1CostInfo { l2_to_l1_payload_lengths: _, message_segment_length } =
MessageL1CostInfo::calculate(call_infos, Some(l1_handler_payload_size))?;

// TODO(Elin, 01/06/2024): consider moving Bouncer logic to a function.
let py_tx_execution_info = PyTransactionExecutionInfo::from(tx_execution_info);

let mut additional_os_resources = get_casm_hash_calculation_resources(
&mut transactional_state,
&self.executed_class_hashes,
Expand All @@ -101,7 +116,7 @@ impl<S: StateReader> TransactionExecutor<S> {
&tx_visited_storage_entries,
)?;
let py_bouncer_info = PyBouncerInfo {
message_segment_length: 0,
message_segment_length,
state_diff_size: 0,
additional_os_resources: PyVmExecutionResources::from(additional_os_resources),
};
Expand Down
Loading