diff --git a/crates/jstzd/src/task/jstzd.rs b/crates/jstzd/src/task/jstzd.rs index 73554eab4..2d4e3be8c 100644 --- a/crates/jstzd/src/task/jstzd.rs +++ b/crates/jstzd/src/task/jstzd.rs @@ -1,6 +1,6 @@ use super::{octez_baker::OctezBaker, octez_node::OctezNode, utils::retry, Task}; use anyhow::Result; -use async_dropper_simple::AsyncDrop; +use async_dropper_simple::{AsyncDrop, AsyncDropper}; use async_trait::async_trait; use axum::{ extract::{Path, State}, @@ -167,75 +167,96 @@ impl Jstzd { } } -#[derive(Clone)] -pub struct JstzdServer { - jstzd_server_port: u16, +#[derive(Clone, Default)] +pub struct JstzdServerInner { state: Arc>, } +#[derive(Default)] struct ServerState { - jstzd_config: JstzdConfig, + jstzd_config: Option, jstzd_config_json: serde_json::Map, jstzd: Option, server_handle: Option>, } #[async_trait] -impl AsyncDrop for JstzdServer { +impl AsyncDrop for JstzdServerInner { async fn async_drop(&mut self) { let mut lock = self.state.write().await; let _ = shutdown(&mut lock).await; } } +pub struct JstzdServer { + port: u16, + inner: Arc>, +} + impl JstzdServer { pub fn new(config: JstzdConfig, port: u16) -> Self { Self { - jstzd_server_port: port, - state: Arc::new(RwLock::new(ServerState { - jstzd_config_json: serde_json::to_value(&config) - .unwrap() - .as_object() - .unwrap() - .to_owned(), - jstzd_config: config, - jstzd: None, - server_handle: None, + port, + inner: Arc::new(AsyncDropper::new(JstzdServerInner { + state: Arc::new(RwLock::new(ServerState { + jstzd_config_json: serde_json::to_value(&config) + .unwrap() + .as_object() + .unwrap() + .to_owned(), + jstzd_config: Some(config), + jstzd: None, + server_handle: None, + })), })), } } pub async fn health_check(&self) -> bool { - let lock = self.state.read().await; + let lock = self.inner.state.read().await; health_check(&lock).await } pub async fn stop(&mut self) -> Result<()> { - let mut lock = self.state.write().await; + let mut lock = self.inner.state.write().await; shutdown(&mut lock).await } pub async fn run(&mut self) -> Result<()> { - let jstzd = Jstzd::spawn(self.state.read().await.jstzd_config.clone()).await?; - self.state.write().await.jstzd.replace(jstzd); + let jstzd = Jstzd::spawn( + self.inner + .state + .read() + .await + .jstzd_config + .as_ref() + .ok_or(anyhow::anyhow!( + // shouldn't really reach this branch since jstzd config is required at instantiation + // unless someone calls `run` after calling `stop` + "cannot run jstzd server without jstzd config" + ))? + .clone(), + ) + .await?; + self.inner.state.write().await.jstzd.replace(jstzd); let router = Router::new() .route("/health", get(health_check_handler)) .route("/shutdown", put(shutdown_handler)) .route("/config/:config_type", get(config_handler)) .route("/config/", get(all_config_handler)) - .with_state(self.state.clone()); - let listener = TcpListener::bind(("0.0.0.0", self.jstzd_server_port)).await?; + .with_state(self.inner.state.clone()); + let listener = TcpListener::bind(("0.0.0.0", self.port)).await?; let handle = tokio::spawn(async { axum::serve(listener, router).await.unwrap(); }); - self.state.write().await.server_handle.replace(handle); + self.inner.state.write().await.server_handle.replace(handle); Ok(()) } pub async fn baker_healthy(&self) -> bool { - if let Some(v) = &self.state.read().await.jstzd { + if let Some(v) = &self.inner.state.read().await.jstzd { v.baker.read().await.health_check().await.unwrap_or(false) } else { false @@ -267,6 +288,8 @@ async fn shutdown(state: &mut ServerState) -> Result<()> { if let Some(server) = state.server_handle.take() { server.abort(); } + state.jstzd_config.take(); + state.jstzd_config_json.clear(); Ok(()) } diff --git a/crates/jstzd/tests/jstzd_test.rs b/crates/jstzd/tests/jstzd_test.rs index b3ffe380e..0bb56b1b9 100644 --- a/crates/jstzd/tests/jstzd_test.rs +++ b/crates/jstzd/tests/jstzd_test.rs @@ -39,6 +39,12 @@ async fn jstzd_test() { .unwrap(); ensure_jstzd_components_are_down(&jstzd, &rpc_endpoint, jstzd_port).await; + + // calling `run` after calling `stop` should fail because all states should have been cleared + assert_eq!( + jstzd.run().await.unwrap_err().to_string(), + "cannot run jstzd server without jstzd config" + ); } async fn create_jstzd_server(