diff --git a/coordinator/src/handlers/data_layer.rs b/coordinator/src/handlers/data_layer.rs index ff2e2657..f88d0090 100644 --- a/coordinator/src/handlers/data_layer.rs +++ b/coordinator/src/handlers/data_layer.rs @@ -6,9 +6,12 @@ pub use runner::data_layer::TaskStatus; use anyhow::Context; use runner::data_layer::data_layer_client::DataLayerClient; -use runner::data_layer::{DeprovisionRequest, GetTaskStatusRequest, ProvisionRequest}; +use runner::data_layer::{ + DeprovisionRequest, GetTaskStatusRequest, GetTaskStatusResponse, ProvisionRequest, + StartTaskResponse, +}; use tonic::transport::channel::Channel; -use tonic::{Request, Status}; +use tonic::Status; use crate::indexer_config::IndexerConfig; @@ -16,9 +19,63 @@ type TaskId = String; const TASK_TIMEOUT_SECONDS: u64 = 600; // 10 minutes +#[cfg(not(test))] +use DataLayerClientWrapperImpl as DataLayerClientWrapper; +#[cfg(test)] +use MockDataLayerClientWrapperImpl as DataLayerClientWrapper; + +#[derive(Clone)] +struct DataLayerClientWrapperImpl { + inner: DataLayerClient, +} + +#[cfg(test)] +impl Clone for MockDataLayerClientWrapperImpl { + fn clone(&self) -> Self { + Self::default() + } +} + +#[cfg_attr(test, mockall::automock)] +impl DataLayerClientWrapperImpl { + pub fn new(inner: DataLayerClient) -> Self { + Self { inner } + } + + pub async fn start_provisioning_task( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().start_provisioning_task(request).await + } + + pub async fn start_deprovisioning_task( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().start_deprovisioning_task(request).await + } + + pub async fn get_task_status( + &self, + request: R, + ) -> std::result::Result, tonic::Status> + where + R: tonic::IntoRequest + 'static, + { + self.inner.clone().get_task_status(request).await + } +} + #[derive(Clone)] pub struct DataLayerHandler { - client: DataLayerClient, + client: DataLayerClientWrapper, } impl DataLayerHandler { @@ -28,7 +85,9 @@ impl DataLayerHandler { .connect_lazy(); let client = DataLayerClient::new(channel); - Ok(Self { client }) + Ok(Self { + client: DataLayerClientWrapper::new(client), + }) } pub async fn start_provisioning_task( @@ -41,11 +100,7 @@ impl DataLayerHandler { schema: indexer_config.schema.clone(), }; - let response = self - .client - .clone() - .start_provisioning_task(Request::new(request)) - .await?; + let response = self.client.start_provisioning_task(request).await?; Ok(response.into_inner().task_id) } @@ -60,11 +115,7 @@ impl DataLayerHandler { function_name, }; - let response = self - .client - .clone() - .start_deprovisioning_task(Request::new(request)) - .await?; + let response = self.client.start_deprovisioning_task(request).await?; Ok(response.into_inner().task_id) } @@ -72,11 +123,7 @@ impl DataLayerHandler { pub async fn get_task_status(&self, task_id: TaskId) -> anyhow::Result { let request = GetTaskStatusRequest { task_id }; - let response = self - .client - .clone() - .get_task_status(Request::new(request)) - .await; + let response = self.client.get_task_status(request).await; if let Err(error) = response { if error.code() == tonic::Code::NotFound {