From d0a0bf669e505e8e978b18a12f3182a070c13971 Mon Sep 17 00:00:00 2001 From: Tao Zhu Date: Thu, 24 Mar 2022 01:02:16 -0500 Subject: [PATCH] cache computed fee-per-cu --- core/src/unprocessed_packet_batches.rs | 135 ++++++++++++++++++++----- 1 file changed, 111 insertions(+), 24 deletions(-) diff --git a/core/src/unprocessed_packet_batches.rs b/core/src/unprocessed_packet_batches.rs index 73c80e81ac3445..5f2fb3e889abca 100644 --- a/core/src/unprocessed_packet_batches.rs +++ b/core/src/unprocessed_packet_batches.rs @@ -5,6 +5,7 @@ use { solana_program_runtime::compute_budget::ComputeBudget, solana_runtime::bank::Bank, solana_sdk::{ + clock::Slot, hash::Hash, message::{ v0::{self}, @@ -22,6 +23,20 @@ use { }, }; +/// FeePerCu is valid by up to X slots +#[derive(Debug, Default)] +struct FeePerCu { + fee_per_cu: u64, + slot: Slot, +} + +impl FeePerCu { + fn too_old(&self, slot: &Slot) -> bool { + const MAX_SLOT_AGE: Slot = 1; + slot - &self.slot >= MAX_SLOT_AGE + } +} + /// Holds deserialized messages, as well as computed message_hash and other things needed to create /// SanitizedTransaction #[derive(Debug, Default)] @@ -34,6 +49,8 @@ pub struct DeserializedPacket { #[allow(dead_code)] is_simple_vote: bool, + + fee_per_cu: Option, } /// Defines the type of entry in `UnprocessedPacketBatches`, it holds original packet_batch @@ -128,7 +145,7 @@ impl UnprocessedPacketBatches { /// prioritize unprocessed packets by their fee/CU then by sender's stakes pub fn prioritize_by_fee_then_stakes( - &self, + &mut self, working_bank: Option>, ) -> Vec { let (stakes, locators) = self.get_stakes_and_locators(); @@ -194,14 +211,14 @@ impl UnprocessedPacketBatches { /// Index `locators` by their transaction's fee-per-cu value; For transactions /// have same fee-per-cu, their relative order remains same (eg. in sender_stake order). fn prioritize_by_fee_per_cu( - &self, + &mut self, locators: &Vec, bank: Option>, ) -> Vec { let mut fee_buckets = BTreeMap::>::new(); for locator in locators { // if unable to compute fee-per-cu for the packet, put it to the `0` bucket - let fee_per_cu = self.compute_fee_per_cu(locator, &bank).unwrap_or(0); + let fee_per_cu = self.get_computed_fee_per_cu(locator, &bank).unwrap_or(0); let bucket = fee_buckets .entry(fee_per_cu) @@ -215,29 +232,71 @@ impl UnprocessedPacketBatches { .collect() } - /// Computes `(addition_fee + base_fee / requested_cu)` for packet referenced by `PacketLocator` - fn compute_fee_per_cu(&self, locator: &PacketLocator, bank: &Option>) -> Option { - if let Some(bank) = bank { - let deserialized_packet_batch = self.get(locator.batch_index)?; - let deserialized_packet = deserialized_packet_batch - .unprocessed_packets - .get(&locator.packet_index)?; - let sanitized_message = Self::sanitize_message( - &deserialized_packet.versioned_transaction.message, - bank.as_ref(), - )?; - let total_fee = bank.get_fee_for_message(&sanitized_message)?; - - // TODO refactor `bank.get_fee_for_message()` to return both fee and CUs to avoid - // calling ComputeBudget twice. - let mut compute_budget = ComputeBudget::default(); - let _ = compute_budget - .process_message(&sanitized_message, false) - .ok()?; - - Some(total_fee / compute_budget.max_units) + /// get cached fee_per_cu for transaction referenced by `locator`, if cached value is + /// too old for current `bank`, or no cached value, then (re)compute and cache. + fn get_computed_fee_per_cu( + &mut self, + locator: &PacketLocator, + bank: &Option>, + ) -> Option { + if bank.is_none() { + return None; + } + let bank = bank.as_ref().unwrap(); + let deserialized_packet = self.locate_packet_mut(locator)?; + if let Some(cached_fee_per_cu) = + Self::get_cached_fee_per_cu(&deserialized_packet, &bank.slot()) + { + Some(cached_fee_per_cu) } else { + let computed_fee_per_cu = Self::compute_fee_per_cu(&deserialized_packet, bank); + if let Some(computed_fee_per_cu) = computed_fee_per_cu { + deserialized_packet.fee_per_cu = Some(FeePerCu { + fee_per_cu: computed_fee_per_cu, + slot: bank.slot(), + }); + } + computed_fee_per_cu + } + } + + #[allow(dead_code)] + fn locate_packet(&self, locator: &PacketLocator) -> Option<&DeserializedPacket> { + let deserialized_packet_batch = self.get(locator.batch_index)?; + deserialized_packet_batch + .unprocessed_packets + .get(&locator.packet_index) + } + + fn locate_packet_mut(&mut self, locator: &PacketLocator) -> Option<&mut DeserializedPacket> { + let deserialized_packet_batch = self.get_mut(locator.batch_index)?; + deserialized_packet_batch + .unprocessed_packets + .get_mut(&locator.packet_index) + } + + /// Computes `(addition_fee + base_fee / requested_cu)` for packet referenced by `PacketLocator` + fn compute_fee_per_cu(deserialized_packet: &DeserializedPacket, bank: &Bank) -> Option { + let sanitized_message = + Self::sanitize_message(&deserialized_packet.versioned_transaction.message, bank)?; + let total_fee = bank.get_fee_for_message(&sanitized_message)?; + + // TODO refactor `bank.get_fee_for_message()` to return both fee and CUs to avoid + // calling ComputeBudget twice. + let mut compute_budget = ComputeBudget::default(); + let _ = compute_budget + .process_message(&sanitized_message, false) + .ok()?; + + Some(total_fee / compute_budget.max_units) + } + + fn get_cached_fee_per_cu(deserialized_packet: &DeserializedPacket, slot: &Slot) -> Option { + let cached_fee_per_cu = deserialized_packet.fee_per_cu.as_ref()?; + if cached_fee_per_cu.too_old(slot) { None + } else { + Some(cached_fee_per_cu.fee_per_cu) } } @@ -400,6 +459,7 @@ impl DeserializedPacketBatch { versioned_transaction, message_hash, is_simple_vote, + fee_per_cu: None, }) } else { None @@ -786,4 +846,31 @@ mod tests { assert_eq!(expected_locators, prioritized_locators); } } + + #[test] + fn test_get_cached_fee_per_cu() { + let mut deserialized_packet = DeserializedPacket::default(); + let slot: Slot = 100; + + // assert default deserialized_packet has no cached fee-per-cu + assert!( + UnprocessedPacketBatches::get_cached_fee_per_cu(&deserialized_packet, &slot).is_none() + ); + + // cache fee-per-cu with slot 100 + let fee_per_cu = 1_000u64; + deserialized_packet.fee_per_cu = Some(FeePerCu { fee_per_cu, slot }); + + // assert cache fee-per-cu is available for same slot + assert_eq!( + fee_per_cu, + UnprocessedPacketBatches::get_cached_fee_per_cu(&deserialized_packet, &slot).unwrap() + ); + + // assert cached value became too old + assert!( + UnprocessedPacketBatches::get_cached_fee_per_cu(&deserialized_packet, &(slot + 1)) + .is_none() + ); + } }