diff --git a/client/rpc-spec-v2/src/chain_head/api.rs b/client/rpc-spec-v2/src/chain_head/api.rs index c002b75efe037..682cd690dd10c 100644 --- a/client/rpc-spec-v2/src/chain_head/api.rs +++ b/client/rpc-spec-v2/src/chain_head/api.rs @@ -119,4 +119,31 @@ pub trait ChainHeadApi { /// This method is unstable and subject to change in the future. #[method(name = "chainHead_unstable_unpin", blocking)] fn chain_head_unstable_unpin(&self, follow_subscription: String, hash: Hash) -> RpcResult<()>; + + /// Resumes a storage fetch started with `chainHead_storage` after it has generated an + /// `operationWaitingForContinue` event. + /// + /// # Unstable + /// + /// This method is unstable and subject to change in the future. + #[method(name = "chainHead_unstable_continue", blocking)] + fn chain_head_unstable_continue( + &self, + follow_subscription: String, + operation_id: String, + ) -> RpcResult<()>; + + /// Stops an operation started with chainHead_unstable_body, chainHead_unstable_call, or + /// chainHead_unstable_storage. If the operation was still in progress, this interrupts it. If + /// the operation was already finished, this call has no effect. + /// + /// # Unstable + /// + /// This method is unstable and subject to change in the future. + #[method(name = "chainHead_unstable_stopOperation", blocking)] + fn chain_head_unstable_stop_operation( + &self, + follow_subscription: String, + operation_id: String, + ) -> RpcResult<()>; } diff --git a/client/rpc-spec-v2/src/chain_head/chain_head.rs b/client/rpc-spec-v2/src/chain_head/chain_head.rs index 79cf251f18068..bae7c84df0ed9 100644 --- a/client/rpc-spec-v2/src/chain_head/chain_head.rs +++ b/client/rpc-spec-v2/src/chain_head/chain_head.rs @@ -61,6 +61,9 @@ pub struct ChainHeadConfig { pub subscription_max_pinned_duration: Duration, /// The maximum number of ongoing operations per subscription. pub subscription_max_ongoing_operations: usize, + /// The maximum number of items reported by the `chainHead_storage` before + /// pagination is required. + pub operation_max_storage_items: usize, } /// Maximum pinned blocks across all connections. @@ -78,12 +81,17 @@ const MAX_PINNED_DURATION: Duration = Duration::from_secs(60); /// Note: The lower limit imposed by the spec is 16. const MAX_ONGOING_OPERATIONS: usize = 16; +/// The maximum number of items the `chainHead_storage` can return +/// before paginations is required. +const MAX_STORAGE_ITER_ITEMS: usize = 5; + impl Default for ChainHeadConfig { fn default() -> Self { ChainHeadConfig { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: MAX_PINNED_DURATION, subscription_max_ongoing_operations: MAX_ONGOING_OPERATIONS, + operation_max_storage_items: MAX_STORAGE_ITER_ITEMS, } } } @@ -100,6 +108,9 @@ pub struct ChainHead, Block: BlockT, Client> { subscriptions: Arc>, /// The hexadecimal encoded hash of the genesis block. genesis_hash: String, + /// The maximum number of items reported by the `chainHead_storage` before + /// pagination is required. + operation_max_storage_items: usize, /// Phantom member to pin the block type. _phantom: PhantomData, } @@ -124,6 +135,7 @@ impl, Block: BlockT, Client> ChainHead { config.subscription_max_ongoing_operations, backend, )), + operation_max_storage_items: config.operation_max_storage_items, genesis_hash, _phantom: PhantomData, } @@ -232,7 +244,7 @@ where follow_subscription: String, hash: Block::Hash, ) -> RpcResult { - let block_guard = match self.subscriptions.lock_block(&follow_subscription, hash, 1) { + let mut block_guard = match self.subscriptions.lock_block(&follow_subscription, hash, 1) { Ok(block) => block, Err(SubscriptionManagementError::SubscriptionAbsent) | Err(SubscriptionManagementError::ExceededLimits) => return Ok(MethodResponse::LimitReached), @@ -243,6 +255,8 @@ where Err(_) => return Err(ChainHeadRpcError::InvalidBlock.into()), }; + let operation_id = block_guard.operation().operation_id(); + let event = match self.client.block(hash) { Ok(Some(signed_block)) => { let extrinsics = signed_block @@ -252,7 +266,7 @@ where .map(|extrinsic| hex_string(&extrinsic.encode())) .collect(); FollowEvent::::OperationBodyDone(OperationBodyDone { - operation_id: block_guard.operation_id(), + operation_id: operation_id.clone(), value: extrinsics, }) }, @@ -268,16 +282,13 @@ where return Err(ChainHeadRpcError::InvalidBlock.into()) }, Err(error) => FollowEvent::::OperationError(OperationError { - operation_id: block_guard.operation_id(), + operation_id: operation_id.clone(), error: error.to_string(), }), }; let _ = block_guard.response_sender().unbounded_send(event); - Ok(MethodResponse::Started(MethodResponseStarted { - operation_id: block_guard.operation_id(), - discarded_items: None, - })) + Ok(MethodResponse::Started(MethodResponseStarted { operation_id, discarded_items: None })) } fn chain_head_unstable_header( @@ -337,7 +348,7 @@ where .transpose()? .map(ChildInfo::new_default_from_vec); - let block_guard = + let mut block_guard = match self.subscriptions.lock_block(&follow_subscription, hash, items.len()) { Ok(block) => block, Err(SubscriptionManagementError::SubscriptionAbsent) | @@ -349,17 +360,21 @@ where Err(_) => return Err(ChainHeadRpcError::InvalidBlock.into()), }; - let storage_client = ChainHeadStorage::::new(self.client.clone()); - let operation_id = block_guard.operation_id(); + let mut storage_client = ChainHeadStorage::::new( + self.client.clone(), + self.operation_max_storage_items, + ); + let operation = block_guard.operation(); + let operation_id = operation.operation_id(); // The number of operations we are allowed to execute. - let num_operations = block_guard.num_reserved(); + let num_operations = operation.num_reserved(); let discarded = items.len().saturating_sub(num_operations); let mut items = items; items.truncate(num_operations); let fut = async move { - storage_client.generate_events(block_guard, hash, items, child_trie); + storage_client.generate_events(block_guard, hash, items, child_trie).await; }; self.executor @@ -379,7 +394,7 @@ where ) -> RpcResult { let call_parameters = Bytes::from(parse_hex_param(call_parameters)?); - let block_guard = match self.subscriptions.lock_block(&follow_subscription, hash, 1) { + let mut block_guard = match self.subscriptions.lock_block(&follow_subscription, hash, 1) { Ok(block) => block, Err(SubscriptionManagementError::SubscriptionAbsent) | Err(SubscriptionManagementError::ExceededLimits) => { @@ -401,28 +416,26 @@ where .into()) } + let operation_id = block_guard.operation().operation_id(); let event = self .client .executor() .call(hash, &function, &call_parameters, CallContext::Offchain) .map(|result| { FollowEvent::::OperationCallDone(OperationCallDone { - operation_id: block_guard.operation_id(), + operation_id: operation_id.clone(), output: hex_string(&result), }) }) .unwrap_or_else(|error| { FollowEvent::::OperationError(OperationError { - operation_id: block_guard.operation_id(), + operation_id: operation_id.clone(), error: error.to_string(), }) }); let _ = block_guard.response_sender().unbounded_send(event); - Ok(MethodResponse::Started(MethodResponseStarted { - operation_id: block_guard.operation_id(), - discarded_items: None, - })) + Ok(MethodResponse::Started(MethodResponseStarted { operation_id, discarded_items: None })) } fn chain_head_unstable_unpin( @@ -443,4 +456,35 @@ where Err(_) => Err(ChainHeadRpcError::InvalidBlock.into()), } } + + fn chain_head_unstable_continue( + &self, + follow_subscription: String, + operation_id: String, + ) -> RpcResult<()> { + let Some(operation) = self.subscriptions.get_operation(&follow_subscription, &operation_id) else { + return Ok(()) + }; + + if !operation.submit_continue() { + // Continue called without generating a `WaitingForContinue` event. + Err(ChainHeadRpcError::InvalidContinue.into()) + } else { + Ok(()) + } + } + + fn chain_head_unstable_stop_operation( + &self, + follow_subscription: String, + operation_id: String, + ) -> RpcResult<()> { + let Some(operation) = self.subscriptions.get_operation(&follow_subscription, &operation_id) else { + return Ok(()) + }; + + operation.stop_operation(); + + Ok(()) + } } diff --git a/client/rpc-spec-v2/src/chain_head/chain_head_storage.rs b/client/rpc-spec-v2/src/chain_head/chain_head_storage.rs index 393e4489c8c07..5e1f38f9a9978 100644 --- a/client/rpc-spec-v2/src/chain_head/chain_head_storage.rs +++ b/client/rpc-spec-v2/src/chain_head/chain_head_storage.rs @@ -18,7 +18,7 @@ //! Implementation of the `chainHead_storage` method. -use std::{marker::PhantomData, sync::Arc}; +use std::{collections::VecDeque, marker::PhantomData, sync::Arc}; use sc_client_api::{Backend, ChildInfo, StorageKey, StorageProvider}; use sc_utils::mpsc::TracingUnboundedSender; @@ -37,10 +37,6 @@ use super::{ FollowEvent, }; -/// The maximum number of items the `chainHead_storage` can return -/// before paginations is required. -const MAX_ITER_ITEMS: usize = 10; - /// The query type of an interation. enum IterQueryType { /// Iterating over (key, value) pairs. @@ -53,16 +49,34 @@ enum IterQueryType { pub struct ChainHeadStorage { /// Substrate client. client: Arc, - _phantom: PhantomData<(Block, BE)>, + /// Queue of operations that may require pagination. + iter_operations: VecDeque, + /// The maximum number of items reported by the `chainHead_storage` before + /// pagination is required. + operation_max_storage_items: usize, + _phandom: PhantomData<(BE, Block)>, } impl ChainHeadStorage { /// Constructs a new [`ChainHeadStorage`]. - pub fn new(client: Arc) -> Self { - Self { client, _phantom: PhantomData } + pub fn new(client: Arc, operation_max_storage_items: usize) -> Self { + Self { + client, + iter_operations: VecDeque::new(), + operation_max_storage_items, + _phandom: PhantomData, + } } } +/// Query to iterate over storage. +struct QueryIter { + /// The next key from which the iteration should continue. + next_key: StorageKey, + /// The type of the query (either value or hash). + ty: IterQueryType, +} + /// Checks if the provided key (main or child key) is valid /// for queries. /// @@ -77,7 +91,7 @@ fn is_key_queryable(key: &[u8]) -> bool { type QueryResult = Result, String>; /// The result of iterating over keys. -type QueryIterResult = Result, String>; +type QueryIterResult = Result<(Vec, Option), String>; impl ChainHeadStorage where @@ -131,64 +145,118 @@ where .unwrap_or_else(|error| QueryResult::Err(error.to_string())) } - /// Handle iterating over (key, value) or (key, hash) pairs. - fn query_storage_iter( + /// Iterate over at most `operation_max_storage_items` keys. + /// + /// Returns the storage result with a potential next key to resume iteration. + fn query_storage_iter_pagination( &self, + query: QueryIter, hash: Block::Hash, - key: &StorageKey, child_key: Option<&ChildInfo>, - ty: IterQueryType, ) -> QueryIterResult { - let keys_iter = if let Some(child_key) = child_key { - self.client.child_storage_keys(hash, child_key.to_owned(), Some(key), None) + let QueryIter { next_key, ty } = query; + + let mut keys_iter = if let Some(child_key) = child_key { + self.client + .child_storage_keys(hash, child_key.to_owned(), Some(&next_key), None) } else { - self.client.storage_keys(hash, Some(key), None) + self.client.storage_keys(hash, Some(&next_key), None) } - .map_err(|error| error.to_string())?; + .map_err(|err| err.to_string())?; + + let mut ret = Vec::with_capacity(self.operation_max_storage_items); + for _ in 0..self.operation_max_storage_items { + let Some(key) = keys_iter.next() else { + break + }; - let mut ret = Vec::with_capacity(MAX_ITER_ITEMS); - let mut keys_iter = keys_iter.take(MAX_ITER_ITEMS); - while let Some(key) = keys_iter.next() { let result = match ty { IterQueryType::Value => self.query_storage_value(hash, &key, child_key), IterQueryType::Hash => self.query_storage_hash(hash, &key, child_key), }?; - if let Some(result) = result { - ret.push(result); + if let Some(value) = result { + ret.push(value); } } - QueryIterResult::Ok(ret) + // Save the next key if any to continue the iteration. + let maybe_next_query = keys_iter.next().map(|next_key| QueryIter { next_key, ty }); + Ok((ret, maybe_next_query)) } - /// Generate the block events for the `chainHead_storage` method. - pub fn generate_events( - &self, - block_guard: BlockGuard, + /// Iterate over (key, hash) and (key, value) generating the `WaitingForContinue` event if + /// necessary. + async fn generate_storage_iter_events( + &mut self, + mut block_guard: BlockGuard, hash: Block::Hash, - items: Vec>, child_key: Option, ) { - /// Build and send the opaque error back to the `chainHead_follow` method. - fn send_error( - sender: &TracingUnboundedSender>, - operation_id: String, - error: String, - ) { - let _ = - sender.unbounded_send(FollowEvent::::OperationError(OperationError { - operation_id, - error, - })); + let sender = block_guard.response_sender(); + let operation = block_guard.operation(); + + while let Some(query) = self.iter_operations.pop_front() { + if operation.was_stopped() { + return + } + + let result = self.query_storage_iter_pagination(query, hash, child_key.as_ref()); + let (events, maybe_next_query) = match result { + QueryIterResult::Ok(result) => result, + QueryIterResult::Err(error) => { + send_error::(&sender, operation.operation_id(), error.to_string()); + return + }, + }; + + if !events.is_empty() { + // Send back the results of the iteration produced so far. + let _ = sender.unbounded_send(FollowEvent::::OperationStorageItems( + OperationStorageItems { operation_id: operation.operation_id(), items: events }, + )); + } + + if let Some(next_query) = maybe_next_query { + let _ = + sender.unbounded_send(FollowEvent::::OperationWaitingForContinue( + OperationId { operation_id: operation.operation_id() }, + )); + + // The operation might be continued or cancelled only after the + // `OperationWaitingForContinue` is generated above. + operation.wait_for_continue().await; + + // Give a chance for the other items to advance next time. + self.iter_operations.push_back(next_query); + } + } + + if operation.was_stopped() { + return } + let _ = + sender.unbounded_send(FollowEvent::::OperationStorageDone(OperationId { + operation_id: operation.operation_id(), + })); + } + + /// Generate the block events for the `chainHead_storage` method. + pub async fn generate_events( + &mut self, + mut block_guard: BlockGuard, + hash: Block::Hash, + items: Vec>, + child_key: Option, + ) { let sender = block_guard.response_sender(); + let operation = block_guard.operation(); if let Some(child_key) = child_key.as_ref() { if !is_key_queryable(child_key.storage_key()) { let _ = sender.unbounded_send(FollowEvent::::OperationStorageDone( - OperationId { operation_id: block_guard.operation_id() }, + OperationId { operation_id: operation.operation_id() }, )); return } @@ -206,7 +274,7 @@ where Ok(Some(value)) => storage_results.push(value), Ok(None) => continue, Err(error) => { - send_error::(&sender, block_guard.operation_id(), error); + send_error::(&sender, operation.operation_id(), error); return }, } @@ -216,34 +284,16 @@ where Ok(Some(value)) => storage_results.push(value), Ok(None) => continue, Err(error) => { - send_error::(&sender, block_guard.operation_id(), error); + send_error::(&sender, operation.operation_id(), error); return }, }, - StorageQueryType::DescendantsValues => match self.query_storage_iter( - hash, - &item.key, - child_key.as_ref(), - IterQueryType::Value, - ) { - Ok(values) => storage_results.extend(values), - Err(error) => { - send_error::(&sender, block_guard.operation_id(), error); - return - }, - }, - StorageQueryType::DescendantsHashes => match self.query_storage_iter( - hash, - &item.key, - child_key.as_ref(), - IterQueryType::Hash, - ) { - Ok(values) => storage_results.extend(values), - Err(error) => { - send_error::(&sender, block_guard.operation_id(), error); - return - }, - }, + StorageQueryType::DescendantsValues => self + .iter_operations + .push_back(QueryIter { next_key: item.key, ty: IterQueryType::Value }), + StorageQueryType::DescendantsHashes => self + .iter_operations + .push_back(QueryIter { next_key: item.key, ty: IterQueryType::Hash }), _ => continue, }; } @@ -251,15 +301,24 @@ where if !storage_results.is_empty() { let _ = sender.unbounded_send(FollowEvent::::OperationStorageItems( OperationStorageItems { - operation_id: block_guard.operation_id(), + operation_id: operation.operation_id(), items: storage_results, }, )); } - let _ = - sender.unbounded_send(FollowEvent::::OperationStorageDone(OperationId { - operation_id: block_guard.operation_id(), - })); + self.generate_storage_iter_events(block_guard, hash, child_key).await } } + +/// Build and send the opaque error back to the `chainHead_follow` method. +fn send_error( + sender: &TracingUnboundedSender>, + operation_id: String, + error: String, +) { + let _ = sender.unbounded_send(FollowEvent::::OperationError(OperationError { + operation_id, + error, + })); +} diff --git a/client/rpc-spec-v2/src/chain_head/subscription/inner.rs b/client/rpc-spec-v2/src/chain_head/subscription/inner.rs index 9f42be4a2f7f6..d6f64acd63f5f 100644 --- a/client/rpc-spec-v2/src/chain_head/subscription/inner.rs +++ b/client/rpc-spec-v2/src/chain_head/subscription/inner.rs @@ -17,12 +17,13 @@ // along with this program. If not, see . use futures::channel::oneshot; +use parking_lot::Mutex; use sc_client_api::Backend; use sc_utils::mpsc::{tracing_unbounded, TracingUnboundedReceiver, TracingUnboundedSender}; use sp_runtime::traits::Block as BlockT; use std::{ collections::{hash_map::Entry, HashMap}, - sync::Arc, + sync::{atomic::AtomicBool, Arc}, time::{Duration, Instant}, }; @@ -154,12 +155,184 @@ struct PermitOperations { _permit: tokio::sync::OwnedSemaphorePermit, } -impl PermitOperations { +/// The state of one operation. +/// +/// This is directly exposed to users via `chain_head_unstable_continue` and +/// `chain_head_unstable_stop_operation`. +#[derive(Clone)] +pub struct OperationState { + /// The shared operation state that holds information about the + /// `waitingForContinue` event and cancellation. + shared_state: Arc, + /// Send notifications when the user calls `chainHead_continue` method. + send_continue: tokio::sync::mpsc::Sender<()>, +} + +impl OperationState { + /// Returns true if `chainHead_continue` is called after the + /// `waitingForContinue` event was emitted for the associated + /// operation ID. + pub fn submit_continue(&self) -> bool { + // `waitingForContinue` not generated. + if !self.shared_state.requested_continue.load(std::sync::atomic::Ordering::Acquire) { + return false + } + + // Has enough capacity for 1 message. + // Can fail if the `stop_operation` propagated the stop first. + self.send_continue.try_send(()).is_ok() + } + + /// Stops the operation if `waitingForContinue` event was emitted for the associated + /// operation ID. + /// + /// Returns nothing in accordance with `chainHead_unstable_stopOperation`. + pub fn stop_operation(&self) { + // `waitingForContinue` not generated. + if !self.shared_state.requested_continue.load(std::sync::atomic::Ordering::Acquire) { + return + } + + self.shared_state + .operation_stopped + .store(true, std::sync::atomic::Ordering::Release); + + // Send might not have enough capacity if `submit_continue` was sent first. + // However, the `operation_stopped` boolean was set. + let _ = self.send_continue.try_send(()); + } +} + +/// The shared operation state between the backend [`RegisteredOperation`] and frontend +/// [`RegisteredOperation`]. +struct SharedOperationState { + /// True if the `chainHead` generated `waitingForContinue` event. + requested_continue: AtomicBool, + /// True if the operation was cancelled by the user. + operation_stopped: AtomicBool, +} + +impl SharedOperationState { + /// Constructs a new [`SharedOperationState`]. + /// + /// This is efficiently cloned under a single heap allocation. + fn new() -> Arc { + Arc::new(SharedOperationState { + requested_continue: AtomicBool::new(false), + operation_stopped: AtomicBool::new(false), + }) + } +} + +/// The registered operation passed to the `chainHead` methods. +/// +/// This is used internally by the `chainHead` methods. +pub struct RegisteredOperation { + /// The shared operation state that holds information about the + /// `waitingForContinue` event and cancellation. + shared_state: Arc, + /// Receive notifications when the user calls `chainHead_continue` method. + recv_continue: tokio::sync::mpsc::Receiver<()>, + /// The operation ID of the request. + operation_id: String, + /// Track the operations ID of this subscription. + operations: Arc>>, + /// Permit a number of items to be executed by this operation. + permit: PermitOperations, +} + +impl RegisteredOperation { + /// Wait until the user calls `chainHead_continue` or the operation + /// is cancelled via `chainHead_stopOperation`. + pub async fn wait_for_continue(&mut self) { + self.shared_state + .requested_continue + .store(true, std::sync::atomic::Ordering::Release); + + // The sender part of this channel is around for as long as this object exists, + // because it is stored in the `OperationState` of the `operations` field. + // The sender part is removed from tracking when this object is dropped. + let _ = self.recv_continue.recv().await; + + self.shared_state + .requested_continue + .store(false, std::sync::atomic::Ordering::Release); + } + + /// Returns true if the current operation was stopped. + pub fn was_stopped(&self) -> bool { + self.shared_state.operation_stopped.load(std::sync::atomic::Ordering::Acquire) + } + + /// Get the operation ID. + pub fn operation_id(&self) -> String { + self.operation_id.clone() + } + /// Returns the number of reserved elements for this permit. /// /// This can be smaller than the number of items requested via [`LimitOperations::reserve()`]. - fn num_reserved(&self) -> usize { - self.num_ops + pub fn num_reserved(&self) -> usize { + self.permit.num_ops + } +} + +impl Drop for RegisteredOperation { + fn drop(&mut self) { + let mut operations = self.operations.lock(); + operations.remove(&self.operation_id); + } +} + +/// The ongoing operations of a subscription. +struct Operations { + /// The next operation ID to be generated. + next_operation_id: usize, + /// Limit the number of ongoing operations. + limits: LimitOperations, + /// Track the operations ID of this subscription. + operations: Arc>>, +} + +impl Operations { + /// Constructs a new [`Operations`]. + fn new(max_operations: usize) -> Self { + Operations { + next_operation_id: 0, + limits: LimitOperations::new(max_operations), + operations: Default::default(), + } + } + + /// Register a new operation. + pub fn register_operation(&mut self, to_reserve: usize) -> Option { + let permit = self.limits.reserve_at_most(to_reserve)?; + + let operation_id = self.next_operation_id(); + + // At most one message can be sent. + let (send_continue, recv_continue) = tokio::sync::mpsc::channel(1); + let shared_state = SharedOperationState::new(); + + let state = OperationState { send_continue, shared_state: shared_state.clone() }; + + // Cloned operations for removing the current ID on drop. + let operations = self.operations.clone(); + operations.lock().insert(operation_id.clone(), state); + + Some(RegisteredOperation { shared_state, operation_id, recv_continue, operations, permit }) + } + + /// Get the associated operation state with the ID. + pub fn get_operation(&self, id: &str) -> Option { + self.operations.lock().get(id).map(|state| state.clone()) + } + + /// Generate the next operation ID for this subscription. + fn next_operation_id(&mut self) -> String { + let op_id = self.next_operation_id; + self.next_operation_id += 1; + op_id.to_string() } } @@ -180,10 +353,8 @@ struct SubscriptionState { /// /// This object is cloned between methods. response_sender: TracingUnboundedSender>, - /// Limit the number of ongoing operations. - limits: LimitOperations, - /// The next operation ID. - next_operation_id: usize, + /// The ongoing operations of a subscription. + operations: Operations, /// Track the block hashes available for this subscription. /// /// This implementation assumes: @@ -296,18 +467,16 @@ impl SubscriptionState { timestamp } - /// Generate the next operation ID for this subscription. - fn next_operation_id(&mut self) -> usize { - let op_id = self.next_operation_id; - self.next_operation_id = self.next_operation_id.wrapping_add(1); - op_id + /// Register a new operation. + /// + /// The registered operation can execute at least one item and at most the requested items. + fn register_operation(&mut self, to_reserve: usize) -> Option { + self.operations.register_operation(to_reserve) } - /// Reserves capacity to execute at least one operation and at most the requested items. - /// - /// For more details see [`PermitOperations`]. - fn reserve_at_most(&self, to_reserve: usize) -> Option { - self.limits.reserve_at_most(to_reserve) + /// Get the associated operation state with the ID. + pub fn get_operation(&self, id: &str) -> Option { + self.operations.get_operation(id) } } @@ -318,8 +487,7 @@ pub struct BlockGuard> { hash: Block::Hash, with_runtime: bool, response_sender: TracingUnboundedSender>, - operation_id: String, - permit_operations: PermitOperations, + operation: RegisteredOperation, backend: Arc, } @@ -337,22 +505,14 @@ impl> BlockGuard { hash: Block::Hash, with_runtime: bool, response_sender: TracingUnboundedSender>, - operation_id: usize, - permit_operations: PermitOperations, + operation: RegisteredOperation, backend: Arc, ) -> Result { backend .pin_block(hash) .map_err(|err| SubscriptionManagementError::Custom(err.to_string()))?; - Ok(Self { - hash, - with_runtime, - response_sender, - operation_id: operation_id.to_string(), - permit_operations, - backend, - }) + Ok(Self { hash, with_runtime, response_sender, operation, backend }) } /// The `with_runtime` flag of the subscription. @@ -365,16 +525,9 @@ impl> BlockGuard { self.response_sender.clone() } - /// The operation ID of this method. - pub fn operation_id(&self) -> String { - self.operation_id.clone() - } - - /// Returns the number of reserved elements for this permit. - /// - /// This can be smaller than the number of items requested. - pub fn num_reserved(&self) -> usize { - self.permit_operations.num_reserved() + /// Get the details of the registered operation. + pub fn operation(&mut self) -> &mut RegisteredOperation { + &mut self.operation } } @@ -445,9 +598,8 @@ impl> SubscriptionsInner { with_runtime, tx_stop: Some(tx_stop), response_sender, - limits: LimitOperations::new(self.max_ongoing_operations), - next_operation_id: 0, blocks: Default::default(), + operations: Operations::new(self.max_ongoing_operations), }; entry.insert(state); @@ -631,21 +783,24 @@ impl> SubscriptionsInner { return Err(SubscriptionManagementError::BlockHashAbsent) } - let Some(permit_operations) = sub.reserve_at_most(to_reserve) else { + let Some(operation) = sub.register_operation(to_reserve) else { // Error when the server cannot execute at least one operation. return Err(SubscriptionManagementError::ExceededLimits) }; - let operation_id = sub.next_operation_id(); BlockGuard::new( hash, sub.with_runtime, sub.response_sender.clone(), - operation_id, - permit_operations, + operation, self.backend.clone(), ) } + + pub fn get_operation(&mut self, sub_id: &str, id: &str) -> Option { + let state = self.subs.get(sub_id)?; + state.get_operation(id) + } } #[cfg(test)] @@ -758,8 +913,7 @@ mod tests { with_runtime: false, tx_stop: None, response_sender, - next_operation_id: 0, - limits: LimitOperations::new(MAX_OPERATIONS_PER_SUB), + operations: Operations::new(MAX_OPERATIONS_PER_SUB), blocks: Default::default(), }; @@ -788,9 +942,8 @@ mod tests { with_runtime: false, tx_stop: None, response_sender, - next_operation_id: 0, - limits: LimitOperations::new(MAX_OPERATIONS_PER_SUB), blocks: Default::default(), + operations: Operations::new(MAX_OPERATIONS_PER_SUB), }; let hash = H256::random(); @@ -1107,12 +1260,12 @@ mod tests { // One operation is reserved. let permit_one = ops.reserve_at_most(1).unwrap(); - assert_eq!(permit_one.num_reserved(), 1); + assert_eq!(permit_one.num_ops, 1); // Request 2 operations, however there is capacity only for one. let permit_two = ops.reserve_at_most(2).unwrap(); // Number of reserved permits is smaller than provided. - assert_eq!(permit_two.num_reserved(), 1); + assert_eq!(permit_two.num_ops, 1); // Try to reserve operations when there's no space. let permit = ops.reserve_at_most(1); @@ -1123,6 +1276,6 @@ mod tests { // Can reserve again let permit_three = ops.reserve_at_most(1).unwrap(); - assert_eq!(permit_three.num_reserved(), 1); + assert_eq!(permit_three.num_ops, 1); } } diff --git a/client/rpc-spec-v2/src/chain_head/subscription/mod.rs b/client/rpc-spec-v2/src/chain_head/subscription/mod.rs index 39618ecfc1b3e..b25b1a4913b49 100644 --- a/client/rpc-spec-v2/src/chain_head/subscription/mod.rs +++ b/client/rpc-spec-v2/src/chain_head/subscription/mod.rs @@ -25,6 +25,8 @@ mod error; mod inner; use self::inner::SubscriptionsInner; + +pub use self::inner::OperationState; pub use error::SubscriptionManagementError; pub use inner::{BlockGuard, InsertedSubscriptionData}; @@ -126,4 +128,10 @@ impl> SubscriptionManagement { let mut inner = self.inner.write(); inner.lock_block(sub_id, hash, to_reserve) } + + /// Get the operation state. + pub fn get_operation(&self, sub_id: &str, operation_id: &str) -> Option { + let mut inner = self.inner.write(); + inner.get_operation(sub_id, operation_id) + } } diff --git a/client/rpc-spec-v2/src/chain_head/tests.rs b/client/rpc-spec-v2/src/chain_head/tests.rs index 4bda06d3cf01c..00ed9089058ee 100644 --- a/client/rpc-spec-v2/src/chain_head/tests.rs +++ b/client/rpc-spec-v2/src/chain_head/tests.rs @@ -25,7 +25,7 @@ use sp_core::{ Blake2Hasher, Hasher, }; use sp_version::RuntimeVersion; -use std::{collections::HashSet, sync::Arc, time::Duration}; +use std::{collections::HashSet, fmt::Debug, sync::Arc, time::Duration}; use substrate_test_runtime::Transfer; use substrate_test_runtime_client::{ prelude::*, runtime, runtime::RuntimeApi, Backend, BlockBuilderExt, Client, @@ -37,12 +37,14 @@ type Block = substrate_test_runtime_client::runtime::Block; const MAX_PINNED_BLOCKS: usize = 32; const MAX_PINNED_SECS: u64 = 60; const MAX_OPERATIONS: usize = 16; +const MAX_PAGINATION_LIMIT: usize = 5; const CHAIN_GENESIS: [u8; 32] = [0; 32]; const INVALID_HASH: [u8; 32] = [1; 32]; const KEY: &[u8] = b":mock"; const VALUE: &[u8] = b"hello world"; const CHILD_STORAGE_KEY: &[u8] = b"child"; const CHILD_VALUE: &[u8] = b"child value"; +const DOES_NOT_PRODUCE_EVENTS_SECONDS: u64 = 10; async fn get_next_event(sub: &mut RpcSubscription) -> T { let (event, _sub_id) = tokio::time::timeout(std::time::Duration::from_secs(60), sub.next()) @@ -53,6 +55,13 @@ async fn get_next_event(sub: &mut RpcSubscriptio event } +async fn does_not_produce_event( + sub: &mut RpcSubscription, + duration: std::time::Duration, +) { + tokio::time::timeout(duration, sub.next::()).await.unwrap_err(); +} + async fn run_with_timeout(future: F) -> ::Output { tokio::time::timeout(std::time::Duration::from_secs(60 * 10), future) .await @@ -84,6 +93,7 @@ async fn setup_api() -> ( global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -127,6 +137,7 @@ async fn follow_subscription_produces_blocks() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -188,6 +199,7 @@ async fn follow_with_runtime() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -299,6 +311,7 @@ async fn get_genesis() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -508,6 +521,7 @@ async fn call_runtime_without_flag() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -743,11 +757,16 @@ async fn get_storage_multi_query_iter() { assert_matches!( get_next_event::>(&mut block_sub).await, FollowEvent::OperationStorageItems(res) if res.operation_id == operation_id && - res.items.len() == 2 && + res.items.len() == 1 && + res.items[0].key == key && + res.items[0].result == StorageResultType::Hash(expected_hash) + ); + assert_matches!( + get_next_event::>(&mut block_sub).await, + FollowEvent::OperationStorageItems(res) if res.operation_id == operation_id && + res.items.len() == 1 && res.items[0].key == key && - res.items[1].key == key && - res.items[0].result == StorageResultType::Hash(expected_hash) && - res.items[1].result == StorageResultType::Value(expected_value) + res.items[0].result == StorageResultType::Value(expected_value) ); assert_matches!( get_next_event::>(&mut block_sub).await, @@ -788,11 +807,16 @@ async fn get_storage_multi_query_iter() { assert_matches!( get_next_event::>(&mut block_sub).await, FollowEvent::OperationStorageItems(res) if res.operation_id == operation_id && - res.items.len() == 2 && + res.items.len() == 1 && res.items[0].key == key && - res.items[1].key == key && - res.items[0].result == StorageResultType::Hash(expected_hash) && - res.items[1].result == StorageResultType::Value(expected_value) + res.items[0].result == StorageResultType::Hash(expected_hash) + ); + assert_matches!( + get_next_event::>(&mut block_sub).await, + FollowEvent::OperationStorageItems(res) if res.operation_id == operation_id && + res.items.len() == 1 && + res.items[0].key == key && + res.items[0].result == StorageResultType::Value(expected_value) ); assert_matches!( get_next_event::>(&mut block_sub).await, @@ -1137,6 +1161,7 @@ async fn separate_operation_ids_for_subscriptions() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -1217,6 +1242,7 @@ async fn follow_generates_initial_blocks() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -1348,6 +1374,7 @@ async fn follow_exceeding_pinned_blocks() { global_max_pinned_blocks: 2, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -1402,6 +1429,7 @@ async fn follow_with_unpin() { global_max_pinned_blocks: 2, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -1486,6 +1514,7 @@ async fn follow_prune_best_block() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -1646,6 +1675,7 @@ async fn follow_forks_pruned_block() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -1763,6 +1793,7 @@ async fn follow_report_multiple_pruned_block() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -1971,6 +2002,7 @@ async fn pin_block_references() { global_max_pinned_blocks: 3, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -2084,6 +2116,7 @@ async fn follow_finalized_before_new_block() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -2184,6 +2217,7 @@ async fn ensure_operation_limits_works() { global_max_pinned_blocks: MAX_PINNED_BLOCKS, subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), subscription_max_ongoing_operations: 1, + operation_max_storage_items: MAX_PAGINATION_LIMIT, }, ) .into_rpc(); @@ -2259,3 +2293,275 @@ async fn ensure_operation_limits_works() { FollowEvent::OperationCallDone(done) if done.operation_id == operation_id && done.output == "0x0000000000000000" ); } + +#[tokio::test] +async fn check_continue_operation() { + let child_info = ChildInfo::new_default(CHILD_STORAGE_KEY); + let builder = TestClientBuilder::new().add_extra_child_storage( + &child_info, + KEY.to_vec(), + CHILD_VALUE.to_vec(), + ); + let backend = builder.backend(); + let mut client = Arc::new(builder.build()); + + // Configure the chainHead with maximum 1 item before asking for pagination. + let api = ChainHead::new( + client.clone(), + backend, + Arc::new(TaskExecutor::default()), + CHAIN_GENESIS, + ChainHeadConfig { + global_max_pinned_blocks: MAX_PINNED_BLOCKS, + subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), + subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: 1, + }, + ) + .into_rpc(); + + let mut sub = api.subscribe("chainHead_unstable_follow", [true]).await.unwrap(); + let sub_id = sub.subscription_id(); + let sub_id = serde_json::to_string(&sub_id).unwrap(); + + // Import a new block with storage changes. + let mut builder = client.new_block(Default::default()).unwrap(); + builder.push_storage_change(b":m".to_vec(), Some(b"a".to_vec())).unwrap(); + builder.push_storage_change(b":mo".to_vec(), Some(b"ab".to_vec())).unwrap(); + builder.push_storage_change(b":moc".to_vec(), Some(b"abc".to_vec())).unwrap(); + builder.push_storage_change(b":mock".to_vec(), Some(b"abcd".to_vec())).unwrap(); + let block = builder.build().unwrap().block; + let block_hash = format!("{:?}", block.header.hash()); + client.import(BlockOrigin::Own, block.clone()).await.unwrap(); + + // Ensure the imported block is propagated and pinned for this subscription. + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::Initialized(_) + ); + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::NewBlock(_) + ); + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::BestBlockChanged(_) + ); + + let invalid_hash = hex_string(&INVALID_HASH); + + // Invalid subscription ID must produce no results. + let _res: () = api + .call("chainHead_unstable_continue", ["invalid_sub_id", &invalid_hash]) + .await + .unwrap(); + + // Invalid operation ID must produce no results. + let _res: () = api.call("chainHead_unstable_continue", [&sub_id, &invalid_hash]).await.unwrap(); + + // Valid call with storage at the key. + let response: MethodResponse = api + .call( + "chainHead_unstable_storage", + rpc_params![ + &sub_id, + &block_hash, + vec![StorageQuery { + key: hex_string(b":m"), + query_type: StorageQueryType::DescendantsValues + }] + ], + ) + .await + .unwrap(); + let operation_id = match response { + MethodResponse::Started(started) => started.operation_id, + MethodResponse::LimitReached => panic!("Expected started response"), + }; + + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::OperationStorageItems(res) if res.operation_id == operation_id && + res.items.len() == 1 && + res.items[0].key == hex_string(b":m") && + res.items[0].result == StorageResultType::Value(hex_string(b"a")) + ); + + // Pagination event. + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::OperationWaitingForContinue(res) if res.operation_id == operation_id + ); + + does_not_produce_event::>( + &mut sub, + std::time::Duration::from_secs(DOES_NOT_PRODUCE_EVENTS_SECONDS), + ) + .await; + let _res: () = api.call("chainHead_unstable_continue", [&sub_id, &operation_id]).await.unwrap(); + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::OperationStorageItems(res) if res.operation_id == operation_id && + res.items.len() == 1 && + res.items[0].key == hex_string(b":mo") && + res.items[0].result == StorageResultType::Value(hex_string(b"ab")) + ); + + // Pagination event. + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::OperationWaitingForContinue(res) if res.operation_id == operation_id + ); + + does_not_produce_event::>( + &mut sub, + std::time::Duration::from_secs(DOES_NOT_PRODUCE_EVENTS_SECONDS), + ) + .await; + let _res: () = api.call("chainHead_unstable_continue", [&sub_id, &operation_id]).await.unwrap(); + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::OperationStorageItems(res) if res.operation_id == operation_id && + res.items.len() == 1 && + res.items[0].key == hex_string(b":moc") && + res.items[0].result == StorageResultType::Value(hex_string(b"abc")) + ); + + // Pagination event. + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::OperationWaitingForContinue(res) if res.operation_id == operation_id + ); + does_not_produce_event::>( + &mut sub, + std::time::Duration::from_secs(DOES_NOT_PRODUCE_EVENTS_SECONDS), + ) + .await; + let _res: () = api.call("chainHead_unstable_continue", [&sub_id, &operation_id]).await.unwrap(); + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::OperationStorageItems(res) if res.operation_id == operation_id && + res.items.len() == 1 && + res.items[0].key == hex_string(b":mock") && + res.items[0].result == StorageResultType::Value(hex_string(b"abcd")) + ); + + // Finished. + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::OperationStorageDone(done) if done.operation_id == operation_id + ); +} + +#[tokio::test] +async fn stop_storage_operation() { + let child_info = ChildInfo::new_default(CHILD_STORAGE_KEY); + let builder = TestClientBuilder::new().add_extra_child_storage( + &child_info, + KEY.to_vec(), + CHILD_VALUE.to_vec(), + ); + let backend = builder.backend(); + let mut client = Arc::new(builder.build()); + + // Configure the chainHead with maximum 1 item before asking for pagination. + let api = ChainHead::new( + client.clone(), + backend, + Arc::new(TaskExecutor::default()), + CHAIN_GENESIS, + ChainHeadConfig { + global_max_pinned_blocks: MAX_PINNED_BLOCKS, + subscription_max_pinned_duration: Duration::from_secs(MAX_PINNED_SECS), + subscription_max_ongoing_operations: MAX_OPERATIONS, + operation_max_storage_items: 1, + }, + ) + .into_rpc(); + + let mut sub = api.subscribe("chainHead_unstable_follow", [true]).await.unwrap(); + let sub_id = sub.subscription_id(); + let sub_id = serde_json::to_string(&sub_id).unwrap(); + + // Import a new block with storage changes. + let mut builder = client.new_block(Default::default()).unwrap(); + builder.push_storage_change(b":m".to_vec(), Some(b"a".to_vec())).unwrap(); + builder.push_storage_change(b":mo".to_vec(), Some(b"ab".to_vec())).unwrap(); + let block = builder.build().unwrap().block; + let block_hash = format!("{:?}", block.header.hash()); + client.import(BlockOrigin::Own, block.clone()).await.unwrap(); + + // Ensure the imported block is propagated and pinned for this subscription. + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::Initialized(_) + ); + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::NewBlock(_) + ); + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::BestBlockChanged(_) + ); + + let invalid_hash = hex_string(&INVALID_HASH); + + // Invalid subscription ID must produce no results. + let _res: () = api + .call("chainHead_unstable_stopOperation", ["invalid_sub_id", &invalid_hash]) + .await + .unwrap(); + + // Invalid operation ID must produce no results. + let _res: () = api + .call("chainHead_unstable_stopOperation", [&sub_id, &invalid_hash]) + .await + .unwrap(); + + // Valid call with storage at the key. + let response: MethodResponse = api + .call( + "chainHead_unstable_storage", + rpc_params![ + &sub_id, + &block_hash, + vec![StorageQuery { + key: hex_string(b":m"), + query_type: StorageQueryType::DescendantsValues + }] + ], + ) + .await + .unwrap(); + let operation_id = match response { + MethodResponse::Started(started) => started.operation_id, + MethodResponse::LimitReached => panic!("Expected started response"), + }; + + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::OperationStorageItems(res) if res.operation_id == operation_id && + res.items.len() == 1 && + res.items[0].key == hex_string(b":m") && + res.items[0].result == StorageResultType::Value(hex_string(b"a")) + ); + + // Pagination event. + assert_matches!( + get_next_event::>(&mut sub).await, + FollowEvent::OperationWaitingForContinue(res) if res.operation_id == operation_id + ); + + // Stop the operation. + let _res: () = api + .call("chainHead_unstable_stopOperation", [&sub_id, &operation_id]) + .await + .unwrap(); + + does_not_produce_event::>( + &mut sub, + std::time::Duration::from_secs(DOES_NOT_PRODUCE_EVENTS_SECONDS), + ) + .await; +}