diff --git a/examples/simple-source/src/main.rs b/examples/simple-source/src/main.rs
index e1328bd..9127211 100644
--- a/examples/simple-source/src/main.rs
+++ b/examples/simple-source/src/main.rs
@@ -60,11 +60,9 @@ pub(crate) mod simple_source {
             self.yet_to_ack.write().unwrap().extend(message_offsets)
         }
 
-        async fn ack(&self, offsets: Vec<Offset>) {
-            for offset in offsets {
-                let x = &String::from_utf8(offset.offset).unwrap();
-                self.yet_to_ack.write().unwrap().remove(x);
-            }
+        async fn ack(&self, offset: Offset) {
+            let x = &String::from_utf8(offset.offset).unwrap();
+            self.yet_to_ack.write().unwrap().remove(x);
         }
 
         async fn pending(&self) -> usize {
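The example's `ack` now receives a single `Offset` per call instead of a `Vec<Offset>`. For context, a sketch of how the new signature pairs with `read` in a complete `Sourcer` implementation; the `InMemorySource` type, its payloads, and the `partitions` return type are illustrative assumptions based on the surrounding example, not code from this diff:

```rust
use std::collections::HashSet;
use std::sync::RwLock;

use numaflow::source::{Message, Offset, SourceReadRequest, Sourcer};
use tokio::sync::mpsc::Sender;

#[derive(Default)]
struct InMemorySource {
    yet_to_ack: RwLock<HashSet<String>>,
}

#[tonic::async_trait]
impl Sourcer for InMemorySource {
    async fn read(&self, request: SourceReadRequest, transmitter: Sender<Message>) {
        for i in 0..request.count {
            let offset = format!("offset-{i}");
            transmitter
                .send(Message {
                    value: b"payload".to_vec(),
                    offset: Offset {
                        offset: offset.clone().into_bytes(),
                        partition_id: 0,
                    },
                    event_time: chrono::Utc::now(),
                    keys: vec![],
                })
                .await
                .expect("send should succeed while the batch writer is alive");
            self.yet_to_ack.write().unwrap().insert(offset);
        }
    }

    // ack is now invoked once per offset rather than once per batch
    async fn ack(&self, offset: Offset) {
        self.yet_to_ack
            .write()
            .unwrap()
            .remove(&String::from_utf8(offset.offset).unwrap());
    }

    async fn pending(&self) -> usize {
        self.yet_to_ack.read().unwrap().len()
    }

    async fn partitions(&self) -> Option<Vec<i32>> {
        Some(vec![0])
    }
}
```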
diff --git a/proto/source.proto b/proto/source.proto
index 4352028..dcaf253 100644
--- a/proto/source.proto
+++ b/proto/source.proto
@@ -7,16 +7,17 @@ package source.v1;
 
 service Source {
   // Read returns a stream of datum responses.
-  // The size of the returned ReadResponse is less than or equal to the num_records specified in ReadRequest.
-  // If the request timeout is reached on server side, the returned ReadResponse will contain all the datum that have been read (which could be an empty list).
-  rpc ReadFn(ReadRequest) returns (stream ReadResponse);
+  // The size of the returned ReadResponse is less than or equal to the num_records specified in each ReadRequest.
+  // If the request timeout is reached on the server side, the returned ReadResponse will contain all the data that have been read (which could be an empty list).
+  // The server will continue to read and respond to subsequent ReadRequests until the client closes the stream.
+  rpc ReadFn(stream ReadRequest) returns (stream ReadResponse);
 
-  // AckFn acknowledges a list of datum offsets.
+  // AckFn acknowledges a stream of datum offsets.
   // When AckFn is called, it implicitly indicates that the datum stream has been processed by the source vertex.
   // The caller (numa) expects the AckFn to be successful, and it does not expect any errors.
   // If there are some irrecoverable errors when the callee (UDSource) is processing the AckFn request,
-  // then it is best to crash because there are no other retry mechanisms possible.
-  rpc AckFn(AckRequest) returns (AckResponse);
+  // then it is best to crash because there are no other retry mechanisms possible.
+  rpc AckFn(stream AckRequest) returns (AckResponse);
 
   // PendingFn returns the number of pending records at the user defined source.
   rpc PendingFn(google.protobuf.Empty) returns (PendingResponse);
@@ -60,9 +61,35 @@ message ReadResponse {
     // We add this optional field to support the use case where the user defined source can provide keys for the datum.
     // e.g. Kafka and Redis Stream message usually include information about the keys.
     repeated string keys = 4;
+    // Optional list of headers associated with the datum.
+    // Headers are the metadata associated with the datum.
+    // e.g. Kafka and Redis Stream message usually include information about the headers.
+    map<string, string> headers = 5;
+  }
+  message Status {
+    // Code to indicate the status of the response.
+    enum Code {
+      SUCCESS = 0;
+      FAILURE = 1;
+    }
+
+    // Error to indicate the error type. If the code is FAILURE, then the error field will be populated.
+    enum Error {
+      UNACKED = 0;
+      OTHER = 1;
+    }
+
+    // End of transmission flag.
+    bool eot = 1;
+    Code code = 2;
+    Error error = 3;
+    optional string msg = 4;
   }
   // Required field holding the result.
   Result result = 1;
+  // Status of the response. Holds the end of transmission flag and the status code.
+  //
+  Status status = 2;
 }
 
 /*
@@ -71,11 +98,8 @@ message ReadResponse {
 */
 message AckRequest {
   message Request {
-    // Required field holding a list of offsets to be acknowledged.
-    // The offsets must be strictly corresponding to the previously read batch,
-    // meaning the offsets must be in the same order as the datum responses in the ReadResponse.
-    // By enforcing ordering, we can save deserialization effort on the server side, assuming the server keeps a local copy of the raw/un-serialized offsets.
-    repeated Offset offsets = 1;
+    // Required field holding the offset to be acked.
+    Offset offset = 1;
   }
   // Required field holding the request. The list will be ordered and will have the same order as the original Read response.
   Request request = 1;
@@ -146,4 +170,4 @@ message Offset {
   // It is useful for sources that have multiple partitions. e.g. Kafka.
   // If the partition_id is not specified, it is assumed that the source has a single partition.
   int32 partition_id = 2;
-}
\ No newline at end of file
+}
diff --git a/src/error.rs b/src/error.rs
index e33102f..5a3818a 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,5 +1,7 @@
 use thiserror::Error;
 
+pub type Result<T> = std::result::Result<T, Error>;
+
 #[derive(Error, Debug, Clone)]
 pub enum ErrorKind {
     #[error("User Defined Error: {0}")]
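With both RPCs now streaming, a client drives `ReadFn` by writing `ReadRequest`s and then reading `ReadResponse`s until one arrives whose `Status` has `eot` set. A sketch of that consuming loop against the generated `proto` module, the same way the tests at the bottom of this diff use it; `read_one_batch` and its parameters are hypothetical names:

```rust
use tokio::sync::mpsc::Receiver;
use tokio_stream::wrappers::ReceiverStream;
use tonic::transport::Channel;

// Hypothetical helper: send one stream of ReadRequests and collect a single
// EOT-terminated batch of results from the response stream.
async fn read_one_batch(
    client: &mut proto::source_client::SourceClient<Channel>,
    read_rx: Receiver<proto::ReadRequest>,
) -> Result<Vec<proto::read_response::Result>, tonic::Status> {
    let mut stream = client
        .read_fn(tonic::Request::new(ReceiverStream::new(read_rx)))
        .await?
        .into_inner();

    let mut batch = Vec::new();
    while let Some(resp) = stream.message().await? {
        // a response carrying Status { eot: true, .. } closes the current batch
        if resp.status.map_or(false, |s| s.eot) {
            break;
        }
        if let Some(result) = resp.result {
            batch.push(result);
        }
    }
    Ok(batch)
}
```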
diff --git a/src/source.rs b/src/source.rs
index af3cf3e..94cc841 100644
--- a/src/source.rs
+++ b/src/source.rs
@@ -4,17 +4,24 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::time::Duration;
 
+use crate::error::Error::SourceError;
+use crate::error::{Error, ErrorKind};
 use crate::shared::{self, prost_timestamp_from_utc};
+use crate::source::proto::{AckRequest, AckResponse, ReadRequest, ReadResponse};
 use chrono::{DateTime, Utc};
-use tokio::sync::mpsc::{self, Sender};
+use tokio::sync::mpsc::{self, Receiver, Sender};
 use tokio::sync::oneshot;
+use tokio::task::JoinHandle;
 use tokio_stream::wrappers::ReceiverStream;
 use tokio_util::sync::CancellationToken;
-use tonic::{async_trait, Request, Response, Status};
+use tonic::{async_trait, Request, Response, Status, Streaming};
+use tracing::{error, info};
 
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/source.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcer-server-info";
+// TODO: use batch-size, blocked by https://github.com/numaproj/numaflow/issues/2026
+const DEFAULT_CHANNEL_SIZE: usize = 1000;
 
 /// Source Proto definitions.
 pub mod proto {
@@ -23,10 +30,11 @@ pub mod proto {
 
 struct SourceService<T> {
     handler: Arc<T>,
-    _shutdown_tx: Sender<()>,
-    _cancellation_token: CancellationToken,
+    shutdown_tx: Sender<()>,
+    cancellation_token: CancellationToken,
 }
 
+// FIXME: remove async_trait
 #[async_trait]
 /// Trait representing a [user defined source](https://numaflow.numaproj.io/user-guide/sources/overview/).
 ///
@@ -46,8 +54,8 @@ struct SourceService<T> {
 pub trait Sourcer {
     /// Reads the messages from the source and sends them to the transmitter.
     async fn read(&self, request: SourceReadRequest, transmitter: Sender<Message>);
-    /// Acknowledges the messages that have been processed by the user-defined source.
-    async fn ack(&self, offsets: Vec<Offset>);
+    /// Acknowledges the message that has been processed by the user-defined source.
+    async fn ack(&self, offset: Offset);
     /// Returns the number of messages that are yet to be processed by the user-defined source.
     async fn pending(&self) -> usize;
     /// Returns the partitions associated with the source. This will be used by the platform to determine
@@ -73,29 +81,21 @@ pub struct Offset {
     pub partition_id: i32,
 }
 
-#[async_trait]
-impl<T> proto::source_server::Source for SourceService<T>
+impl<T> SourceService<T>
 where
     T: Sourcer + Send + Sync + 'static,
 {
-    type ReadFnStream = ReceiverStream<Result<proto::ReadResponse, Status>>;
-
-    async fn read_fn(
-        &self,
-        request: Request<proto::ReadRequest>,
-    ) -> Result<Response<Self::ReadFnStream>, Status> {
-        let sr = request.into_inner().request.unwrap();
-
-        // tx,rx pair for sending data over to user-defined source
-        let (stx, mut srx) = mpsc::channel::<Message>(sr.num_records as usize);
-        // tx,rx pair for gRPC response
-        let (tx, rx) =
-            mpsc::channel::<Result<proto::ReadResponse, Status>>(sr.num_records as usize);
-
-        // start the ud-source rx asynchronously and start populating the gRPC response, so it can be streamed to the gRPC client (numaflow).
-        tokio::spawn(async move {
-            while let Some(resp) = srx.recv().await {
-                tx.send(Ok(proto::ReadResponse {
+    /// Writes a read batch returned by the user-defined handler to the client (numaflow).
+    async fn write_a_batch(
+        grpc_resp_tx: Sender<Result<ReadResponse, Status>>,
+        mut udsource_rx: Receiver<Message>,
+    ) -> crate::error::Result<()> {
+        // even though we use bi-di streaming, the user-defined source sees this as half-duplex,
+        // server-side streaming. this means that the while loop below will terminate
+        // after every batch of reads has been returned.
+        while let Some(resp) = udsource_rx.recv().await {
+            grpc_resp_tx
+                .send(Ok(ReadResponse {
                     result: Some(proto::read_response::Result {
                         payload: resp.value,
                         offset: Some(proto::Offset {
@@ -104,26 +104,139 @@ where
                         }),
                         event_time: prost_timestamp_from_utc(resp.event_time),
                         keys: resp.keys,
+                        headers: Default::default(),
                     }),
+                    status: None,
                 }))
                 .await
-                .expect("receiver dropped");
+                .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?;
+        }
+
+        // send end of transmission on success
+        grpc_resp_tx
+            .send(Ok(ReadResponse {
+                result: None,
+                status: Some(proto::read_response::Status {
+                    eot: true,
+                    code: 0,
+                    error: 0,
+                    msg: None,
+                }),
+            }))
+            .await
+            .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?;
+
+        Ok(())
+    }
+
+    /// Invokes the user-defined source handler to get a read batch and streams it to the numaflow
+    /// (client).
+    async fn forward_a_batch(
+        handler_fn: Arc<T>,
+        grpc_resp_tx: Sender<Result<ReadResponse, Status>>,
+        request: proto::read_request::Request,
+    ) -> crate::error::Result<()> {
+        // tx,rx pair for sending data over to user-defined source
+        let (stx, srx) = mpsc::channel::<Message>(DEFAULT_CHANNEL_SIZE);
+
+        // spawn the rx side so that when the handler is invoked, we can stream the handler's read data
+        // to the gRPC response stream.
+        let grpc_writer_handle: JoinHandle<crate::error::Result<()>> =
+            tokio::spawn(async move { Self::write_a_batch(grpc_resp_tx, srx).await });
+
+        // invoke the handler; it will stream its data to the tx passed in, and the task above
+        // will stream that data to the client.
+        handler_fn
+            .read(
+                SourceReadRequest {
+                    count: request.num_records as usize,
+                    timeout: Duration::from_millis(request.timeout_in_ms as u64),
+                },
+                stx,
+            )
+            .await;
+
+        // wait for the spawned grpc writer to end
+        grpc_writer_handle
+            .await
+            .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?
+            .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?;
+
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl<T> proto::source_server::Source for SourceService<T>
+where
+    T: Sourcer + Send + Sync + 'static,
+{
+    type ReadFnStream = ReceiverStream<Result<ReadResponse, Status>>;
+
+    async fn read_fn(
+        &self,
+        request: Request<Streaming<ReadRequest>>,
+    ) -> Result<Response<Self::ReadFnStream>, Status> {
+        let mut sr = request.into_inner();
+
+        // we have to call the handler over and over for each ReadRequest
+        let handler_fn = Arc::clone(&self.handler);
+
+        // tx (read from client), rx (write to client) pair for the gRPC response
+        let (tx, rx) = mpsc::channel::<Result<ReadResponse, Status>>(DEFAULT_CHANNEL_SIZE);
+
+        // this tx ends up writing to the client side
+        let grpc_tx = tx.clone();
+
+        let cln_token = self.cancellation_token.clone();
+
+        // this is the top-level stream consumer, and this task will only exit when the stream is
+        // closed (which will happen when the server and client are shutting down).
+        let grpc_read_handle: JoinHandle<crate::error::Result<()>> = tokio::spawn(async move {
+            loop {
+                tokio::select! {
+                    // for each ReadRequest message, the handler will be called and a batch of messages
+                    // will be sent over to the client.
+                    read_request = sr.message() => {
+                        let read_request = read_request
+                            .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?
+                            .ok_or_else(|| SourceError(ErrorKind::InternalError("Stream closed".to_string())))?;
+
+                        let request = read_request.request.ok_or_else(|| SourceError(ErrorKind::InternalError("Request field is empty".to_string())))?;
+
+                        // start the ud-source rx asynchronously and start populating the gRPC
+                        // response, so it can be streamed to the gRPC client (numaflow).
+                        let grpc_resp_tx = grpc_tx.clone();
+
+                        // let's forward a batch for this request
+                        Self::forward_a_batch(handler_fn.clone(), grpc_resp_tx, request).await?
+                    }
+                    _ = cln_token.cancelled() => {
+                        info!("Cancellation token triggered, shutting down");
+                        break;
+                    }
+                }
             }
+            Ok(())
         });
 
-        let handler_fn = Arc::clone(&self.handler);
-        // we want to start streaming to the server as soon as possible
+        let shutdown_tx = self.shutdown_tx.clone();
+        // spawn so we can return the recv stream to the client.
         tokio::spawn(async move {
-            // user-defined source read handler
-            handler_fn
-                .read(
-                    SourceReadRequest {
-                        count: sr.num_records as usize,
-                        timeout: Duration::from_millis(sr.timeout_in_ms as u64),
-                    },
-                    stx,
-                )
-                .await
+            // wait for the grpc read handle; if there are any errors, we set the gRPC Status to failure,
+            // which will close the stream with failure.
+            if let Err(e) = grpc_read_handle.await {
+                error!("shutting down the gRPC channel, {}", e);
+                tx.send(Err(Status::internal(e.to_string())))
+                    .await
+                    .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))
+                    .expect("writing error to grpc response channel should never fail");
+
+                // if there are any failures, we propagate them so that the server can shut down.
+                shutdown_tx
+                    .send(())
+                    .await
+                    .expect("write to shutdown channel should never fail");
+            }
         });
 
         Ok(Response::new(ReceiverStream::new(rx)))
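The new read path wires three parties together: the client's `ReadRequest` stream, the user-defined handler writing into an mpsc channel, and `write_a_batch` draining that channel into the gRPC response channel before appending an end-of-transmission marker. The hand-off reduces to the following self-contained pattern; the `Out` enum, channel sizes, and payloads are illustrative stand-ins, not the crate's types:

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
enum Out {
    Item(u64),
    Eot, // end-of-transmission marker, playing the role of ReadResponse.status.eot
}

#[tokio::main]
async fn main() {
    let (handler_tx, mut handler_rx) = mpsc::channel::<u64>(16);
    let (grpc_tx, mut grpc_rx) = mpsc::channel::<Out>(16);

    // the write_a_batch role: drain exactly one batch, then emit EOT and return
    let writer = tokio::spawn(async move {
        while let Some(v) = handler_rx.recv().await {
            grpc_tx.send(Out::Item(v)).await.unwrap();
        }
        grpc_tx.send(Out::Eot).await.unwrap();
    });

    // the handler role: produce a batch, then drop the sender so the writer's
    // recv() loop ends, just as the user-defined read() returning drops `stx`
    for v in 0..3 {
        handler_tx.send(v).await.unwrap();
    }
    drop(handler_tx);

    writer.await.unwrap();
    while let Some(out) = grpc_rx.recv().await {
        println!("{out:?}"); // Item(0), Item(1), Item(2), Eot
    }
}
```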
@@ -131,31 +244,29 @@ where
 
     async fn ack_fn(
         &self,
-        request: Request<proto::AckRequest>,
-    ) -> Result<Response<proto::AckResponse>, Status> {
-        let ar: proto::AckRequest = request.into_inner();
-
-        let success_response = Response::new(proto::AckResponse {
+        request: Request<Streaming<AckRequest>>,
+    ) -> Result<Response<AckResponse>, Status> {
+        let mut ack_stream = request.into_inner();
+        while let Some(ack_request) = ack_stream.message().await? {
+            // if the request is not there, send back an invalid-argument status
+            let Some(request) = ack_request.request else {
+                return Err(Status::invalid_argument("request is empty"));
+            };
+
+            let Some(offset) = request.offset else {
+                return Err(Status::invalid_argument("offset is not present"));
+            };
+
+            self.handler
+                .ack(Offset {
+                    offset: offset.clone().offset,
+                    partition_id: offset.partition_id,
+                })
+                .await;
+        }
+        Ok(Response::new(AckResponse {
             result: Some(proto::ack_response::Result { success: Some(()) }),
-        });
-
-        let Some(request) = ar.request else {
-            return Ok(success_response);
-        };
-
-        // invoke the user-defined source's ack handler
-        let offsets = request
-            .offsets
-            .into_iter()
-            .map(|so| Offset {
-                offset: so.offset,
-                partition_id: so.partition_id,
-            })
-            .collect();
-
-        self.handler.ack(offsets).await;
-
-        Ok(success_response)
+        }))
     }
 
     async fn pending_fn(&self, _: Request<()>) -> Result<Response<proto::PendingResponse>, Status> {
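`ack_fn` is now client-side streaming: the server acks offsets one by one as they arrive and returns a single `AckResponse` only after the client half-closes the stream. A sketch of the calling side, mirroring the test further down; `ack_offsets` is a hypothetical helper:

```rust
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tonic::transport::Channel;

// Hypothetical helper: ack a set of offsets over a single AckFn stream.
async fn ack_offsets(
    client: &mut proto::source_client::SourceClient<Channel>,
    offsets: Vec<proto::Offset>,
) -> Result<(), tonic::Status> {
    // size the channel so every request can be buffered before the RPC starts
    let (ack_tx, ack_rx) = mpsc::channel(offsets.len().max(1));
    for offset in offsets {
        ack_tx
            .send(proto::AckRequest {
                request: Some(proto::ack_request::Request {
                    offset: Some(offset),
                }),
            })
            .await
            .expect("receiver is alive until the RPC consumes it");
    }
    // dropping the sender half-closes the stream; only then does the server reply
    drop(ack_tx);

    let resp = client
        .ack_fn(tonic::Request::new(ReceiverStream::new(ack_rx)))
        .await?
        .into_inner();
    assert!(resp.result.is_some(), "expected a success result");
    Ok(())
}
```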
@@ -276,8 +387,8 @@ impl<T> Server<T> {
 
         let source_service = SourceService {
             handler: Arc::new(handler),
-            _shutdown_tx: internal_shutdown_tx,
-            _cancellation_token: cln_token.clone(),
+            shutdown_tx: internal_shutdown_tx,
+            cancellation_token: cln_token.clone(),
         };
 
         let source_svc = proto::source_server::SourceServer::new(source_service)
@@ -318,18 +429,19 @@ impl<T> Drop for Server<T> {
 #[cfg(test)]
 mod tests {
     use super::{proto, Message, Offset, SourceReadRequest};
+    use crate::source;
     use chrono::Utc;
     use std::collections::{HashMap, HashSet};
+    use std::error::Error;
+    use std::time::Duration;
     use std::vec;
-    use std::{error::Error, time::Duration};
-
-    use crate::source;
     use tempfile::TempDir;
     use tokio::net::UnixStream;
     use tokio::sync::mpsc::Sender;
-    use tokio::sync::oneshot;
-    use tokio_stream::StreamExt;
+    use tokio::sync::{mpsc, oneshot};
+    use tokio_stream::wrappers::ReceiverStream;
     use tonic::transport::Uri;
+    use tonic::Request;
     use tower::service_fn;
     use uuid::Uuid;
@@ -377,13 +489,11 @@ mod tests {
             self.yet_to_ack.write().unwrap().extend(message_offsets)
         }
 
-        async fn ack(&self, offsets: Vec<Offset>) {
-            for offset in offsets {
-                self.yet_to_ack
-                    .write()
-                    .unwrap()
-                    .remove(&String::from_utf8(offset.offset).unwrap());
-            }
+        async fn ack(&self, offset: Offset) {
+            self.yet_to_ack
+                .write()
+                .unwrap()
+                .remove(&String::from_utf8(offset.offset).unwrap());
         }
 
         async fn pending(&self) -> usize {
@@ -433,83 +543,69 @@ mod tests {
             .await?;
 
         let mut client = proto::source_client::SourceClient::new(channel);
-        let request = tonic::Request::new(proto::ReadRequest {
+
+        // Test read_fn with bidirectional streaming
+        let (read_tx, read_rx) = mpsc::channel(4);
+        let read_request = proto::ReadRequest {
             request: Some(proto::read_request::Request {
                 num_records: 5,
-                timeout_in_ms: 500,
+                timeout_in_ms: 1000,
             }),
-        });
+        };
+        read_tx.send(read_request).await.unwrap();
+        drop(read_tx); // Close the sender to indicate no more requests
 
-        let resp = client.read_fn(request).await?;
-        let resp = resp.into_inner();
-        let result: Vec<proto::read_response::Result> = resp
-            .map(|item| item.unwrap().result.unwrap())
-            .collect()
-            .await;
-        let response_values: Vec<usize> = result
-            .iter()
-            .map(|item| {
-                usize::from_le_bytes(
-                    item.payload
-                        .clone()
-                        .try_into()
-                        .expect("expected Vec length to be 8"),
-                )
-            })
-            .collect();
-        assert_eq!(response_values, vec![8, 8, 8, 8, 8]);
-
-        let pending_before_ack = client
-            .pending_fn(tonic::Request::new(()))
-            .await
-            .unwrap()
-            .into_inner();
-        assert_eq!(
-            pending_before_ack.result.unwrap().count,
-            5,
-            "Expected pending messages to be 5 before ACK"
-        );
-
-        let offsets_to_ack: Vec<proto::Offset> = result
-            .iter()
-            .map(|item| item.clone().offset.unwrap())
-            .collect();
-        let ack_request = tonic::Request::new(proto::AckRequest {
-            request: Some(proto::ack_request::Request {
-                offsets: offsets_to_ack,
-            }),
-        });
-        let resp = client.ack_fn(ack_request).await.unwrap().into_inner();
-        assert!(
-            resp.result.unwrap().success.is_some(),
-            "Expected acknowledgement request to be successful"
-        );
-
-        let pending_before_ack = client
-            .pending_fn(tonic::Request::new(()))
-            .await
-            .unwrap()
+        let mut response_stream = client
+            .read_fn(Request::new(ReceiverStream::new(read_rx)))
+            .await?
             .into_inner();
-        assert_eq!(
-            pending_before_ack.result.unwrap().count,
-            0,
-            "Expected pending messages to be 0 after ACK"
-        );
-
-        let partitions = client
-            .partitions_fn(tonic::Request::new(()))
-            .await
-            .unwrap()
+        let mut response_values = Vec::new();
+
+        while let Some(response) = response_stream.message().await? {
+            if let Some(status) = response.status {
+                if status.eot {
+                    break;
+                }
+            }
+
+            if let Some(result) = response.result {
+                response_values.push(result);
+            }
+        }
+        assert_eq!(response_values.len(), 5);
+
+        // Test pending_fn
+        let pending_before_ack = client.pending_fn(Request::new(())).await?.into_inner();
+        assert_eq!(pending_before_ack.result.unwrap().count, 5);
+
+        // Test ack_fn with client-side streaming
+        let (ack_tx, ack_rx) = mpsc::channel(10);
+        for resp in response_values.iter() {
+            let ack_request = proto::AckRequest {
+                request: Some(proto::ack_request::Request {
+                    offset: Some(proto::Offset {
+                        offset: resp.offset.clone().unwrap().offset,
+                        partition_id: resp.offset.clone().unwrap().partition_id,
+                    }),
+                }),
+            };
+            ack_tx.send(ack_request).await.unwrap();
+        }
+        drop(ack_tx); // Close the sender to indicate no more requests
+
+        let ack_response = client
+            .ack_fn(Request::new(ReceiverStream::new(ack_rx)))
+            .await?
             .into_inner();
-        assert_eq!(
-            partitions.result.unwrap().partitions,
-            vec![2],
-            "Expected number of partitions to be 2"
-        );
-
-        shutdown_tx
-            .send(())
-            .expect("Sending shutdown signal to gRPC server");
+        assert!(ack_response.result.unwrap().success.is_some());
+
+        let pending_after_ack = client.pending_fn(Request::new(())).await?.into_inner();
+        assert_eq!(pending_after_ack.result.unwrap().count, 0);
+
+        let partitions = client.partitions_fn(Request::new(())).await?.into_inner();
+        assert_eq!(partitions.result.unwrap().partitions, vec![2]);
+
+        shutdown_tx.send(()).unwrap();
         tokio::time::sleep(Duration::from_millis(50)).await;
         assert!(task.is_finished(), "gRPC server is still running");
         Ok(())
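End to end, a user-defined source only has to adapt its `ack` signature; starting the server is unchanged by this diff. A minimal sketch, reusing the hypothetical `InMemorySource` from the note near the top; the exact `Server::start` signature is an assumption based on the crate's examples, not something this diff shows:

```rust
use numaflow::source;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // serves the Sourcer on the default UDS socket (/var/run/numaflow/source.sock)
    // until the process is signalled to stop
    source::Server::new(InMemorySource::default()).start().await
}
```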