diff --git a/Cargo.lock b/Cargo.lock index 5a05f2f41..eee3f6a4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3954,6 +3954,30 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "mock-service-endpoint" +version = "0.9.1" +dependencies = [ + "assert2", + "async-stream", + "bytes", + "futures", + "http 0.2.12", + "http-body-util", + "hyper 1.2.0", + "hyper-util", + "prost", + "restate-service-protocol", + "restate-types", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tracing", + "tracing-subscriber", +] + [[package]] name = "multimap" version = "0.8.3" diff --git a/Cargo.toml b/Cargo.toml index 580703a18..471ef8e89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "crates/*", "crates/codederror/derive", "server", + "tools/mock-service-endpoint", "tools/service-protocol-wireshark-dissector", "tools/xtask", "tools/bifrost-benchpress", diff --git a/crates/bifrost/src/loglets/local_loglet/keys.rs b/crates/bifrost/src/loglets/local_loglet/keys.rs index c903cdb90..d758b41ba 100644 --- a/crates/bifrost/src/loglets/local_loglet/keys.rs +++ b/crates/bifrost/src/loglets/local_loglet/keys.rs @@ -8,7 +8,6 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. 
-use std::fmt::Write; use std::mem::size_of; use bytes::{Buf, BufMut, Bytes, BytesMut}; @@ -17,6 +16,8 @@ use restate_types::logs::SequenceNumber; use crate::loglet::LogletOffset; +pub(crate) const DATA_KEY_PREFIX_LENGTH: usize = size_of::<u8>() + size_of::<u64>(); + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RecordKey { pub log_id: u64, @@ -37,7 +38,7 @@ impl RecordKey { pub fn to_bytes(self) -> Bytes { let mut buf = BytesMut::with_capacity(size_of::<Self>() + 1); - buf.write_char('d').expect("enough key buffer"); + buf.put_u8(b'd'); buf.put_u64(self.log_id); buf.put_u64(self.offset.into()); buf.freeze() @@ -75,7 +76,7 @@ impl MetadataKey { pub fn to_bytes(self) -> Bytes { let mut buf = BytesMut::with_capacity(size_of::<Self>() + 1); // m for metadata - buf.write_char('m').expect("enough key buffer"); + buf.put_u8(b'm'); buf.put_u64(self.log_id); buf.put_u8(self.kind as u8); buf.freeze() diff --git a/crates/bifrost/src/loglets/local_loglet/log_store.rs b/crates/bifrost/src/loglets/local_loglet/log_store.rs index 75819e425..c67ff6b6c 100644 --- a/crates/bifrost/src/loglets/local_loglet/log_store.rs +++ b/crates/bifrost/src/loglets/local_loglet/log_store.rs @@ -16,9 +16,9 @@ use restate_rocksdb::{ use restate_types::arc_util::Updateable; use restate_types::config::{LocalLogletOptions, RocksDbOptions}; use restate_types::storage::{StorageDecodeError, StorageEncodeError}; -use rocksdb::{BoundColumnFamily, DBCompressionType, DB}; +use rocksdb::{BoundColumnFamily, DBCompressionType, SliceTransform, DB}; -use super::keys::{MetadataKey, MetadataKind}; +use super::keys::{MetadataKey, MetadataKind, DATA_KEY_PREFIX_LENGTH}; use super::log_state::{log_state_full_merge, log_state_partial_merge, LogState}; use super::log_store_writer::LogStoreWriter; @@ -138,12 +138,15 @@ fn cf_data_options(mut opts: rocksdb::Options) -> rocksdb::Options { opts.set_compression_per_level(&[ DBCompressionType::None, DBCompressionType::Snappy, - DBCompressionType::Snappy, - DBCompressionType::Snappy, - 
DBCompressionType::Snappy, - DBCompressionType::Snappy, + DBCompressionType::Zstd, + DBCompressionType::Zstd, + DBCompressionType::Zstd, + DBCompressionType::Zstd, DBCompressionType::Zstd, ]); + + opts.set_prefix_extractor(SliceTransform::create_fixed_prefix(DATA_KEY_PREFIX_LENGTH)); + opts.set_memtable_prefix_bloom_ratio(0.2); // most reads are sequential opts.set_advise_random_on_open(false); // @@ -158,7 +161,7 @@ fn cf_metadata_options(mut opts: rocksdb::Options) -> rocksdb::Options { opts.set_num_levels(3); opts.set_compression_per_level(&[ DBCompressionType::None, - DBCompressionType::None, + DBCompressionType::Snappy, DBCompressionType::Zstd, ]); opts.set_max_write_buffer_number(2); diff --git a/crates/core/src/task_center.rs b/crates/core/src/task_center.rs index fe08aaac1..af127cd99 100644 --- a/crates/core/src/task_center.rs +++ b/crates/core/src/task_center.rs @@ -127,9 +127,7 @@ fn tokio_builder(common_opts: &CommonOptions) -> tokio::runtime::Builder { format!("rs:worker-{}", id) }); - if let Some(worker_threads) = common_opts.default_thread_pool_size { - builder.worker_threads(worker_threads); - } + builder.worker_threads(common_opts.default_thread_pool_size()); builder } diff --git a/crates/rocksdb/src/db_manager.rs b/crates/rocksdb/src/db_manager.rs index a0664c8e4..e34380395 100644 --- a/crates/rocksdb/src/db_manager.rs +++ b/crates/rocksdb/src/db_manager.rs @@ -328,6 +328,8 @@ impl RocksDbManager { // https://github.com/facebook/rocksdb/blob/f059c7d9b96300091e07429a60f4ad55dac84859/include/rocksdb/table.h#L275 block_opts.set_format_version(5); block_opts.set_cache_index_and_filter_blocks(true); + block_opts.set_pin_l0_filter_and_index_blocks_in_cache(true); + block_opts.set_block_cache(&self.cache); cf_options.set_block_based_table_factory(&block_opts); diff --git a/crates/types/src/config/common.rs b/crates/types/src/config/common.rs index 3942c4acb..baa13508c 100644 --- a/crates/types/src/config/common.rs +++ b/crates/types/src/config/common.rs 
@@ -87,9 +87,9 @@ pub struct CommonOptions { /// # Default async runtime thread pool /// /// Size of the default thread pool used to perform internal tasks. - /// If not set, it defaults to the number of CPU cores. + /// If not set, it defaults to twice the number of CPU cores. #[builder(setter(strip_option))] - pub default_thread_pool_size: Option<usize>, + default_thread_pool_size: Option<usize>, /// # Tracing Endpoint /// @@ -246,19 +246,20 @@ impl CommonOptions { } pub fn storage_high_priority_bg_threads(&self) -> NonZeroUsize { - self.storage_high_priority_bg_threads.unwrap_or( + NonZeroUsize::new(4).unwrap() + } + + pub fn default_thread_pool_size(&self) -> usize { + 2 * self.default_thread_pool_size.unwrap_or( std::thread::available_parallelism() // Shouldn't really fail, but just in case. - .unwrap_or(NonZeroUsize::new(4).unwrap()), + .unwrap_or(NonZeroUsize::new(4).unwrap()) + .get(), ) } pub fn storage_low_priority_bg_threads(&self) -> NonZeroUsize { - self.storage_low_priority_bg_threads.unwrap_or( - std::thread::available_parallelism() // Shouldn't really fail, but just in case. 
- .unwrap_or(NonZeroUsize::new(4).unwrap()), - ) + NonZeroUsize::new(4).unwrap() } pub fn rocksdb_bg_threads(&self) -> NonZeroU32 { @@ -301,8 +302,8 @@ impl Default for CommonOptions { default_thread_pool_size: None, storage_high_priority_bg_threads: None, storage_low_priority_bg_threads: None, - rocksdb_total_memtables_ratio: 0.5, // (50% of rocksdb-total-memory-size) - rocksdb_total_memory_size: NonZeroUsize::new(4_000_000_000).unwrap(), // 4GB + rocksdb_total_memtables_ratio: 0.6, // (60% of rocksdb-total-memory-size) + rocksdb_total_memory_size: NonZeroUsize::new(6_000_000_000).unwrap(), // 6GB rocksdb_bg_threads: None, rocksdb_high_priority_bg_threads: NonZeroU32::new(2).unwrap(), rocksdb_write_stall_threshold: std::time::Duration::from_secs(3).into(), diff --git a/crates/types/src/config/rocksdb.rs b/crates/types/src/config/rocksdb.rs index ba56b355d..d06df8ee9 100644 --- a/crates/types/src/config/rocksdb.rs +++ b/crates/types/src/config/rocksdb.rs @@ -123,6 +123,8 @@ impl RocksDbOptions { // Assuming 256MB for bifrost's data cf (2 memtables * 128MB default write buffer size) // Assuming 256MB for bifrost's metadata cf (2 memtables * 128MB default write buffer size) let buffer_size = (all_memtables - 512_000_000) / (num_partitions * 3) as usize; + // reduce the buffer_size by 10% for safety + let buffer_size = (buffer_size as f64 * 0.9) as usize; NonZeroUsize::new(buffer_size).unwrap() }) } diff --git a/crates/types/src/config/worker.rs b/crates/types/src/config/worker.rs index d6658d503..08239eda5 100644 --- a/crates/types/src/config/worker.rs +++ b/crates/types/src/config/worker.rs @@ -68,7 +68,7 @@ impl WorkerOptions { impl Default for WorkerOptions { fn default() -> Self { Self { - internal_queue_length: NonZeroUsize::new(64).unwrap(), + internal_queue_length: NonZeroUsize::new(6400).unwrap(), num_timers_in_memory_limit: None, storage: StorageOptions::default(), invoker: Default::default(), @@ -181,7 +181,7 @@ impl Default for InvokerOptions { 
message_size_warning: NonZeroUsize::new(10_000_000).unwrap(), // 10MB message_size_limit: None, tmp_dir: None, - concurrent_invocations_limit: None, + concurrent_invocations_limit: Some(NonZeroUsize::new(10_000).unwrap()), disable_eager_state: false, } } diff --git a/tools/mock-service-endpoint/Cargo.toml b/tools/mock-service-endpoint/Cargo.toml new file mode 100644 index 000000000..1731a99fb --- /dev/null +++ b/tools/mock-service-endpoint/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "mock-service-endpoint" +version.workspace = true +authors.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +publish = false + +[dependencies] +assert2 = { workspace = true } +async-stream = "0.3.5" +bytes = { workspace = true } +futures = { workspace = true } +http = {workspace = true} +http-body-util = "0.1" +hyper = { version = "1", features = ["server"] } +hyper-util = { version = "0.1", features = ["full"] } +restate-service-protocol = { workspace = true, features = ["message", "codec"] } +restate-types = { workspace = true } +prost = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +thiserror = { workspace = true } diff --git a/tools/mock-service-endpoint/src/main.rs b/tools/mock-service-endpoint/src/main.rs new file mode 100644 index 000000000..eacf74257 --- /dev/null +++ b/tools/mock-service-endpoint/src/main.rs @@ -0,0 +1,432 @@ +use std::convert::Infallible; +use std::fmt::{Display, Formatter}; +use std::net::SocketAddr; +use std::str::FromStr; + +use assert2::let_assert; +use async_stream::{stream, try_stream}; +use bytes::Bytes; +use futures::{pin_mut, Stream, StreamExt}; +use http_body_util::{BodyStream, Either, Empty, Full, StreamBody}; +use hyper::body::{Frame, Incoming}; +use hyper::server::conn::http2; +use hyper::service::service_fn; +use 
hyper::{Request, Response}; +use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; +use prost::Message; +use tokio::net::TcpListener; +use tracing::{debug, error, info}; +use tracing_subscriber::filter::LevelFilter; + +use restate_service_protocol::codec::ProtobufRawEntryCodec; +use restate_service_protocol::message::{Decoder, Encoder, EncodingError, ProtocolMessage}; +use restate_types::errors::codes; +use restate_types::journal::raw::{EntryHeader, PlainRawEntry, RawEntryCodecError}; +use restate_types::journal::{Entry, EntryType, InputEntry}; +use restate_types::service_protocol::start_message::StateEntry; +use restate_types::service_protocol::{ + self, get_state_entry_message, output_entry_message, ServiceProtocolVersion, StartMessage, +}; + +#[derive(Debug, thiserror::Error)] +enum FrameError { + #[error(transparent)] + EncodingError(EncodingError), + #[error(transparent)] + Hyper(hyper::Error), + #[error("Stream ended before finished replay")] + UnexpectedEOF, + #[error("Journal does not contain expected messages")] + InvalidJournal, + #[error(transparent)] + RawEntryCodecError(#[from] RawEntryCodecError), + #[error(transparent)] + Serde(#[from] serde_json::Error), +} + +async fn serve( + req: Request, +) -> Result< + Response< + Either, StreamBody, Infallible>>>>, + >, + Infallible, +> { + let (req_head, req_body) = req.into_parts(); + let mut split = req_head.uri.path().rsplit('/'); + let handler_name = if let Some(handler_name) = split.next() { + handler_name + } else { + return Ok(Response::builder() + .status(404) + .body(Either::Left(Empty::new())) + .unwrap()); + }; + if let Some("Counter") = split.next() { + } else { + return Ok(Response::builder() + .status(404) + .body(Either::Left(Empty::new())) + .unwrap()); + }; + if let Some("invoke") = split.next() { + } else { + return Ok(Response::builder() + .status(404) + .body(Either::Left(Empty::new())) + .unwrap()); + }; + + let req_body = BodyStream::new(req_body); + let mut decoder = 
Decoder::new(ServiceProtocolVersion::V1, usize::MAX, None); + let encoder = Encoder::new(ServiceProtocolVersion::V1); + + let incoming = stream! { + for await frame in req_body { + match frame { + Ok(frame) => { + if let Ok(data) = frame.into_data() { + decoder.push(data); + loop { + match decoder.consume_next() { + Ok(Some((_header, message))) => yield Ok(message), + Ok(None) => { + break + }, + Err(err) => yield Err(FrameError::EncodingError(err)), + } + } + } + }, + Err(err) => yield Err(FrameError::Hyper(err)), + }; + } + }; + + let handler: Handler = match handler_name.parse() { + Ok(handler) => handler, + Err(_err) => { + return Ok(Response::builder() + .status(404) + .body(Either::Left(Empty::new())) + .unwrap()); + } + }; + + let outgoing = handler.handle(incoming).map(move |message| match message { + Ok(message) => Ok(Frame::data(encoder.encode(message))), + Err(err) => { + error!("Error handling stream: {err:?}"); + Ok(Frame::data(encoder.encode(error(err)))) + } + }); + + Ok(Response::builder() + .status(200) + .header("content-type", "application/vnd.restate.invocation.v1") + .body(Either::Right(StreamBody::new(outgoing))) + .unwrap()) +} + +enum Handler { + Get, + Add, +} + +#[derive(Debug, thiserror::Error)] +#[error("Invalid handler")] +struct InvalidHandler; + +impl FromStr for Handler { + type Err = InvalidHandler; + + fn from_str(s: &str) -> Result { + match s { + "get" => Ok(Self::Get), + "add" => Ok(Self::Add), + _ => Err(InvalidHandler), + } + } +} + +impl Display for Handler { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Get => write!(f, "get"), + Self::Add => write!(f, "add"), + } + } +} + +impl Handler { + fn handle( + self, + incoming: impl Stream>, + ) -> impl Stream> { + try_stream! 
{ + pin_mut!(incoming); + match (incoming.next().await, incoming.next().await) { + (Some(Ok(ProtocolMessage::Start(start_message))), Some(Ok(ProtocolMessage::UnparsedEntry(input)))) if input.ty() == EntryType::Input => { + let input = input.deserialize_entry_ref::()?; + let_assert!( + Entry::Input(input) = input + ); + + let replay_count = start_message.known_entries as usize - 1; + let mut replayed = Vec::with_capacity(replay_count); + for _ in 0..replay_count { + let message = incoming.next().await.ok_or(FrameError::UnexpectedEOF)??; + replayed.push(message); + } + + debug!("Handling request to {self} with {} known entries", start_message.known_entries); + + match self { + Handler::Get => { + for await message in Self::handle_get(start_message, input, replayed, incoming) { + yield message? + } + }, + Handler::Add => { + for await message in Self::handle_add(start_message, input, replayed, incoming) { + yield message? + } + }, + }; + }, + _ => {Err(FrameError::InvalidJournal)?; return}, + }; + } + } + + fn handle_get( + start_message: StartMessage, + _input: InputEntry, + replayed: Vec, + _incoming: impl Stream>, + ) -> impl Stream> { + try_stream! { + let counter = read_counter(&start_message.state_map); + match replayed.len() { + 0 => { + yield get_state(counter.clone()); + yield output(counter.unwrap_or("0".into())); + yield end(); + }, + 1 => { + yield output(counter.unwrap_or("0".into())); + yield end(); + } + 2=> { + yield end(); + } + _ => {Err(FrameError::InvalidJournal)?; return}, + } + } + } + + fn handle_add( + start_message: StartMessage, + input: InputEntry, + replayed: Vec, + _incoming: impl Stream>, + ) -> impl Stream> { + try_stream! 
{ + let counter = read_counter(&start_message.state_map); + match replayed.len() { + 0 => { + yield get_state(counter.clone()); + + let next_value = match counter { + Some(ref counter) => { + let to_add: i32 = serde_json::from_slice(input.value.as_ref())?; + let current: i32 = serde_json::from_slice(counter.as_ref())?; + + serde_json::to_vec(&(to_add + current))?.into() + } + None => input.value, + }; + + yield set_state(next_value.clone()); + yield output(next_value); + yield end(); + }, + 1 => { + let next_value = match counter { + Some(ref counter) => { + let to_add: i32 = serde_json::from_slice(input.value.as_ref())?; + let current: i32 = serde_json::from_slice(counter.as_ref())?; + + serde_json::to_vec(&(to_add + current))?.into() + } + None => input.value, + }; + + yield set_state(next_value.clone()); + yield output(next_value); + yield end(); + } + 2 => { + let set_value = match &replayed[1] { + ProtocolMessage::UnparsedEntry(set) if set.ty() == EntryType::SetState => { + let set = set.deserialize_entry_ref::()?; + let_assert!( + Entry::SetState(set) = set + ); + set.value.clone() + }, + _ => {Err(FrameError::InvalidJournal)?; return}, + }; + yield output(set_value); + yield end(); + } + 3 => { + yield end(); + } + _ => {Err(FrameError::InvalidJournal)?; return}, + } + } + } +} + +fn read_counter(state_map: &[StateEntry]) -> Option { + let entry = state_map + .iter() + .find(|entry| entry.key.as_ref() == b"counter")?; + Some(entry.value.clone()) +} + +fn get_state(counter: Option) -> ProtocolMessage { + debug!( + "Yielding GetStateEntryMessage with value {}", + LossyDisplay(counter.as_deref()) + ); + + ProtocolMessage::UnparsedEntry(PlainRawEntry::new( + EntryHeader::GetState { is_completed: true }, + service_protocol::GetStateEntryMessage { + name: String::new(), + key: "counter".into(), + result: Some(match counter { + Some(ref counter) => get_state_entry_message::Result::Value(counter.clone()), + None => 
get_state_entry_message::Result::Empty(service_protocol::Empty {}), + }), + } + .encode_to_vec() + .into(), + )) +} + +fn set_state(value: Bytes) -> ProtocolMessage { + debug!( + "Yielding SetStateEntryMessage with value {}", + LossyDisplay(Some(&value)) + ); + + ProtocolMessage::UnparsedEntry(PlainRawEntry::new( + EntryHeader::SetState, + service_protocol::SetStateEntryMessage { + name: String::new(), + key: "counter".into(), + value: value.clone(), + } + .encode_to_vec() + .into(), + )) +} + +fn output(value: Bytes) -> ProtocolMessage { + debug!( + "Yielding OutputEntryMessage with result {}", + LossyDisplay(Some(&value)) + ); + + ProtocolMessage::UnparsedEntry(PlainRawEntry::new( + EntryHeader::Output, + service_protocol::OutputEntryMessage { + name: String::new(), + result: Some(output_entry_message::Result::Value(value)), + } + .encode_to_vec() + .into(), + )) +} + +fn end() -> ProtocolMessage { + debug!("Yielding EndMessage"); + + ProtocolMessage::End(service_protocol::EndMessage {}) +} + +fn error(err: FrameError) -> ProtocolMessage { + let code = match err { + FrameError::EncodingError(_) => codes::PROTOCOL_VIOLATION, + FrameError::Hyper(_) => codes::INTERNAL, + FrameError::UnexpectedEOF => codes::PROTOCOL_VIOLATION, + FrameError::InvalidJournal => codes::JOURNAL_MISMATCH, + FrameError::RawEntryCodecError(_) => codes::PROTOCOL_VIOLATION, + FrameError::Serde(_) => codes::INTERNAL, + }; + ProtocolMessage::Error(service_protocol::ErrorMessage { + code: code.into(), + description: err.to_string(), + message: String::new(), + related_entry_index: None, + related_entry_name: None, + related_entry_type: None, + }) +} + +struct LossyDisplay<'a>(Option<&'a [u8]>); +impl<'a> Display for LossyDisplay<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.0 { + Some(bytes) => write!(f, "{}", String::from_utf8_lossy(bytes)), + None => write!(f, ""), + } + } +} + +#[tokio::main] +pub async fn main() -> Result<(), Box> { + let format = 
tracing_subscriber::fmt::format().compact(); + + tracing_subscriber::fmt() + .event_format(format) + .with_env_filter( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .init(); + + let addr: SocketAddr = ([127, 0, 0, 1], 9080).into(); + + let listener = TcpListener::bind(addr).await?; + info!("Listening on http://{}", addr); + loop { + let (tcp, _) = listener.accept().await?; + let io = TokioIo::new(tcp); + + tokio::task::spawn(async move { + if let Err(err) = http2::Builder::new(TokioExecutor::new()) + .timer(TokioTimer::new()) + .serve_connection(io, service_fn(|req| async { + if req.uri().path() == "/discover" { + return Ok(Response::builder() + .header("content-type", "application/vnd.restate.endpointmanifest.v1+json") + .body(Either::Left(Full::new(Bytes::from( + r#"{"protocolMode":"BIDI_STREAM","minProtocolVersion":1,"maxProtocolVersion":2,"services":[{"name":"Counter","ty":"VIRTUAL_OBJECT","handlers":[{"name":"add","input":{"required":false,"contentType":"application/json"},"output":{"setContentTypeIfEmpty":false,"contentType":"application/json"},"ty":"EXCLUSIVE"},{"name":"get","input":{"required":false,"contentType":"application/json"},"output":{"setContentTypeIfEmpty":false,"contentType":"application/json"},"ty":"EXCLUSIVE"}]}]}"# + )))).unwrap()); + } + + let (head, body) = serve(req).await?.into_parts(); + Result::<_, Infallible>::Ok(Response::from_parts(head, Either::Right(body))) + })) + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +}