diff --git a/core/lib/dal/src/protocol_versions_dal.rs b/core/lib/dal/src/protocol_versions_dal.rs index 3382d8c836e5..fcc756e30069 100644 --- a/core/lib/dal/src/protocol_versions_dal.rs +++ b/core/lib/dal/src/protocol_versions_dal.rs @@ -190,6 +190,43 @@ impl ProtocolVersionsDal<'_, '_> { ProtocolVersionId::try_from(row.id as u16).map_err(|err| sqlx::Error::Decode(err.into())) } + /// Returns base system contracts' hashes. Prefer `load_base_system_contracts_by_version_id` if + /// you also want to load the contracts themselves AND expect the contracts to be in the DB + /// already. + pub async fn get_base_system_contract_hashes_by_version_id( + &mut self, + version_id: u16, + ) -> anyhow::Result> { + let row = sqlx::query!( + r#" + SELECT + bootloader_code_hash, + default_account_code_hash, + evm_emulator_code_hash + FROM + protocol_versions + WHERE + id = $1 + "#, + i32::from(version_id) + ) + .instrument("get_base_system_contract_hashes_by_version_id") + .with_arg("version_id", &version_id) + .fetch_optional(self.storage) + .await + .context("cannot fetch system contract hashes")?; + + Ok(if let Some(row) = row { + Some(BaseSystemContractsHashes { + bootloader: H256::from_slice(&row.bootloader_code_hash), + default_aa: H256::from_slice(&row.default_account_code_hash), + evm_emulator: row.evm_emulator_code_hash.as_deref().map(H256::from_slice), + }) + } else { + None + }) + } + pub async fn load_base_system_contracts_by_version_id( &mut self, version_id: u16, @@ -207,7 +244,9 @@ impl ProtocolVersionsDal<'_, '_> { "#, i32::from(version_id) ) - .fetch_optional(self.storage.conn()) + .instrument("load_base_system_contracts_by_version_id") + .with_arg("version_id", &version_id) + .fetch_optional(self.storage) .await .context("cannot fetch system contract hashes")?; diff --git a/core/node/node_sync/src/external_io.rs b/core/node/node_sync/src/external_io.rs index 5e3a5ce9f46e..a0be233a002e 100644 --- a/core/node/node_sync/src/external_io.rs +++ b/core/node/node_sync/src/external_io.rs @@ -104,6 +104,63 @@ impl ExternalIO { } }) } + + async fn ensure_protocol_version_is_saved( + &self, + protocol_version: ProtocolVersionId, + ) -> anyhow::Result<()> { + let base_system_contract_hashes = self + .pool + .connection_tagged("sync_layer") + .await? + .protocol_versions_dal() + .get_base_system_contract_hashes_by_version_id(protocol_version as u16) + .await?; + if base_system_contract_hashes.is_some() { + return Ok(()); + } + tracing::info!("Fetching protocol version {protocol_version:?} from the main node"); + + let protocol_version = self + .main_node_client + .fetch_protocol_version(protocol_version) + .await + .context("failed to fetch protocol version from the main node")? + .context("protocol version is missing on the main node")?; + let minor = protocol_version + .minor_version() + .context("Missing minor protocol version")?; + let bootloader_code_hash = protocol_version + .bootloader_code_hash() + .context("Missing bootloader code hash")?; + let default_account_code_hash = protocol_version + .default_account_code_hash() + .context("Missing default account code hash")?; + let evm_emulator_code_hash = protocol_version.evm_emulator_code_hash(); + let l2_system_upgrade_tx_hash = protocol_version.l2_system_upgrade_tx_hash(); + self.pool + .connection_tagged("sync_layer") + .await? + .protocol_versions_dal() + .save_protocol_version( + ProtocolSemanticVersion { + minor: minor + .try_into() + .context("cannot convert protocol version")?, + patch: VersionPatch(0), + }, + protocol_version.timestamp, + Default::default(), // verification keys are unused for EN + BaseSystemContractsHashes { + bootloader: bootloader_code_hash, + default_aa: default_account_code_hash, + evm_emulator: evm_emulator_code_hash, + }, + l2_system_upgrade_tx_hash, + ) + .await?; + Ok(()) + } } impl IoSealCriteria for ExternalIO { @@ -254,6 +311,8 @@ impl StateKeeperIO for ExternalIO { cursor.next_l2_block ); + self.ensure_protocol_version_is_saved(params.protocol_version) + .await?; self.pool .connection_tagged("sync_layer") .await? @@ -261,7 +320,7 @@ impl StateKeeperIO for ExternalIO { .insert_l1_batch(UnsealedL1BatchHeader { number: cursor.l1_batch, timestamp: params.first_l2_block.timestamp, - protocol_version: None, + protocol_version: Some(params.protocol_version), fee_address: params.operator_address, fee_input: params.fee_input, }) @@ -351,63 +410,21 @@ impl StateKeeperIO for ExternalIO { .connection_tagged("sync_layer") .await? .protocol_versions_dal() - .load_base_system_contracts_by_version_id(protocol_version as u16) - .await - .context("failed loading base system contracts")?; - - if let Some(contracts) = base_system_contracts { - return Ok(contracts); - } - tracing::info!("Fetching protocol version {protocol_version:?} from the main node"); - - let protocol_version = self - .main_node_client - .fetch_protocol_version(protocol_version) - .await - .context("failed to fetch protocol version from the main node")? - .context("protocol version is missing on the main node")?; - let minor = protocol_version - .minor_version() - .context("Missing minor protocol version")?; - let bootloader_code_hash = protocol_version - .bootloader_code_hash() - .context("Missing bootloader code hash")?; - let default_account_code_hash = protocol_version - .default_account_code_hash() - .context("Missing default account code hash")?; - let evm_emulator_code_hash = protocol_version.evm_emulator_code_hash(); - let l2_system_upgrade_tx_hash = protocol_version.l2_system_upgrade_tx_hash(); - self.pool - .connection_tagged("sync_layer") + .get_base_system_contract_hashes_by_version_id(protocol_version as u16) .await? - .protocol_versions_dal() - .save_protocol_version( - ProtocolSemanticVersion { - minor: minor - .try_into() - .context("cannot convert protocol version")?, - patch: VersionPatch(0), - }, - protocol_version.timestamp, - Default::default(), // verification keys are unused for EN - BaseSystemContractsHashes { - bootloader: bootloader_code_hash, - default_aa: default_account_code_hash, - evm_emulator: evm_emulator_code_hash, - }, - l2_system_upgrade_tx_hash, - ) - .await?; + .with_context(|| { + format!("Cannot load base system contracts' hashes for {protocol_version:?}. They should already be present") + })?; let bootloader = self - .get_base_system_contract(bootloader_code_hash, cursor.next_l2_block) + .get_base_system_contract(base_system_contracts.bootloader, cursor.next_l2_block) .await .with_context(|| format!("cannot fetch bootloader code for {protocol_version:?}"))?; let default_aa = self - .get_base_system_contract(default_account_code_hash, cursor.next_l2_block) + .get_base_system_contract(base_system_contracts.default_aa, cursor.next_l2_block) .await .with_context(|| format!("cannot fetch default AA code for {protocol_version:?}"))?; - let evm_emulator = if let Some(hash) = evm_emulator_code_hash { + let evm_emulator = if let Some(hash) = base_system_contracts.evm_emulator { Some( self.get_base_system_contract(hash, cursor.next_l2_block) .await @@ -459,3 +476,97 @@ impl StateKeeperIO for ExternalIO { Ok(hash) } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use zksync_dal::{ConnectionPool, CoreDal}; + use zksync_node_genesis::{insert_genesis_batch, GenesisParams}; + use zksync_state_keeper::{io::L1BatchParams, L2BlockParams, StateKeeperIO}; + use zksync_types::{ + api, fee_model::BatchFeeInput, L1BatchNumber, L2BlockNumber, L2ChainId, ProtocolVersionId, + H256, + }; + + use crate::{sync_action::SyncAction, testonly::MockMainNodeClient, ActionQueue, ExternalIO}; + + #[tokio::test] + async fn insert_batch_with_protocol_version() { + // Whenever ExternalIO inserts an unsealed batch into DB it should populate it with protocol + // version and make sure that it is present in the DB (i.e. fetch it from main node if not). + let pool = ConnectionPool::test_pool().await; + let mut conn = pool.connection().await.unwrap(); + insert_genesis_batch(&mut conn, &GenesisParams::mock()) + .await + .unwrap(); + let (actions_sender, action_queue) = ActionQueue::new(); + let mut client = MockMainNodeClient::default(); + let next_protocol_version = api::ProtocolVersion { + minor_version: Some(ProtocolVersionId::next() as u16), + timestamp: 1, + bootloader_code_hash: Some(H256::repeat_byte(1)), + default_account_code_hash: Some(H256::repeat_byte(1)), + evm_emulator_code_hash: Some(H256::repeat_byte(1)), + ..api::ProtocolVersion::default() + }; + client.insert_protocol_version(next_protocol_version.clone()); + let mut external_io = ExternalIO::new( + pool.clone(), + action_queue, + Box::new(client), + L2ChainId::default(), + ) + .unwrap(); + + let (cursor, _) = external_io.initialize().await.unwrap(); + let params = L1BatchParams { + protocol_version: ProtocolVersionId::next(), + validation_computational_gas_limit: u32::MAX, + operator_address: Default::default(), + fee_input: BatchFeeInput::pubdata_independent(2, 3, 4), + first_l2_block: L2BlockParams { + timestamp: 1, + virtual_blocks: 1, + }, + }; + actions_sender + .push_action_unchecked(SyncAction::OpenBatch { + params: params.clone(), + number: L1BatchNumber(1), + first_l2_block_number: L2BlockNumber(1), + }) + .await + .unwrap(); + let fetched_params = external_io + .wait_for_new_batch_params(&cursor, Duration::from_secs(10)) + .await + .unwrap() + .unwrap(); + assert_eq!(fetched_params, params); + + // Verify that the next protocol version is in DB + let fetched_protocol_version = conn + .protocol_versions_dal() + .get_protocol_version_with_latest_patch(ProtocolVersionId::next()) + .await + .unwrap() + .unwrap(); + assert_eq!( + fetched_protocol_version.version.minor as u16, + next_protocol_version.minor_version.unwrap() + ); + + // Verify that the unsealed batch has protocol version + let unsealed_batch = conn + .blocks_dal() + .get_unsealed_l1_batch() + .await + .unwrap() + .unwrap(); + assert_eq!( + unsealed_batch.protocol_version, + Some(fetched_protocol_version.version.minor) + ); + } +}