Skip to content

Commit

Permalink
refactor: Create data layer client wrapper to enable mocking
Browse files Browse the repository at this point in the history
  • Loading branch information
morgsmccauley committed Aug 5, 2024
1 parent c2cccb0 commit 53a26a0
Showing 1 changed file with 66 additions and 19 deletions.
85 changes: 66 additions & 19 deletions coordinator/src/handlers/data_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,76 @@ pub use runner::data_layer::TaskStatus;

use anyhow::Context;
use runner::data_layer::data_layer_client::DataLayerClient;
use runner::data_layer::{DeprovisionRequest, GetTaskStatusRequest, ProvisionRequest};
use runner::data_layer::{
DeprovisionRequest, GetTaskStatusRequest, GetTaskStatusResponse, ProvisionRequest,
StartTaskResponse,
};
use tonic::transport::channel::Channel;
use tonic::{Request, Status};
use tonic::Status;

use crate::indexer_config::IndexerConfig;

type TaskId = String;

const TASK_TIMEOUT_SECONDS: u64 = 600; // 10 minutes

#[cfg(not(test))]
use DataLayerClientWrapperImpl as DataLayerClientWrapper;
#[cfg(test)]
use MockDataLayerClientWrapperImpl as DataLayerClientWrapper;

#[derive(Clone)]
struct DataLayerClientWrapperImpl {
inner: DataLayerClient<Channel>,
}

#[cfg(test)]
impl Clone for MockDataLayerClientWrapperImpl {
fn clone(&self) -> Self {
Self::default()
}
}

#[cfg_attr(test, mockall::automock)]
impl DataLayerClientWrapperImpl {
pub fn new(inner: DataLayerClient<Channel>) -> Self {
Self { inner }
}

pub async fn start_provisioning_task<R>(
&self,
request: R,
) -> std::result::Result<tonic::Response<StartTaskResponse>, tonic::Status>
where
R: tonic::IntoRequest<ProvisionRequest> + 'static,
{
self.inner.clone().start_provisioning_task(request).await
}

pub async fn start_deprovisioning_task<R>(
&self,
request: R,
) -> std::result::Result<tonic::Response<StartTaskResponse>, tonic::Status>
where
R: tonic::IntoRequest<DeprovisionRequest> + 'static,
{
self.inner.clone().start_deprovisioning_task(request).await
}

pub async fn get_task_status<R>(
&self,
request: R,
) -> std::result::Result<tonic::Response<GetTaskStatusResponse>, tonic::Status>
where
R: tonic::IntoRequest<GetTaskStatusRequest> + 'static,
{
self.inner.clone().get_task_status(request).await
}
}

#[derive(Clone)]
pub struct DataLayerHandler {
client: DataLayerClient<Channel>,
client: DataLayerClientWrapper,
}

impl DataLayerHandler {
Expand All @@ -28,7 +85,9 @@ impl DataLayerHandler {
.connect_lazy();
let client = DataLayerClient::new(channel);

Ok(Self { client })
Ok(Self {
client: DataLayerClientWrapper::new(client),
})
}

pub async fn start_provisioning_task(
Expand All @@ -41,11 +100,7 @@ impl DataLayerHandler {
schema: indexer_config.schema.clone(),
};

let response = self
.client
.clone()
.start_provisioning_task(Request::new(request))
.await?;
let response = self.client.start_provisioning_task(request).await?;

Ok(response.into_inner().task_id)
}
Expand All @@ -60,23 +115,15 @@ impl DataLayerHandler {
function_name,
};

let response = self
.client
.clone()
.start_deprovisioning_task(Request::new(request))
.await?;
let response = self.client.start_deprovisioning_task(request).await?;

Ok(response.into_inner().task_id)
}

pub async fn get_task_status(&self, task_id: TaskId) -> anyhow::Result<TaskStatus> {
let request = GetTaskStatusRequest { task_id };

let response = self
.client
.clone()
.get_task_status(Request::new(request))
.await;
let response = self.client.get_task_status(request).await;

if let Err(error) = response {
if error.code() == tonic::Code::NotFound {
Expand Down

0 comments on commit 53a26a0

Please sign in to comment.