From d0d2bfb8de1fcee6c28362c768efd6a2928d7b56 Mon Sep 17 00:00:00 2001 From: Nathan F Yospe Date: Thu, 14 Mar 2024 18:31:38 -0400 Subject: [PATCH] Updating with StaticVersionType (#188) * Updating with StaticVersionType * harmonizing test temp names, and an important dependency update * Cargo.lock update * dependency update --- Cargo.lock | 32 ++++--- Cargo.toml | 4 +- examples/hello-world/main.rs | 10 +- src/api.rs | 174 ++++++++++++++++++----------------- src/app.rs | 153 ++++++++++++++++++------------ src/lib.rs | 43 ++++----- src/metrics.rs | 6 +- src/request.rs | 7 +- src/route.rs | 110 +++++++++++----------- src/socket.rs | 56 +++++------ src/status.rs | 7 +- 11 files changed, 334 insertions(+), 268 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 848b6962..ac46ea81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -285,7 +285,7 @@ dependencies = [ "polling", "rustix", "slab", - "socket2", + "socket2 0.4.9", "waker-fn", ] @@ -726,24 +726,24 @@ dependencies = [ [[package]] name = "curl" -version = "0.4.44" +version = "0.4.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "509bd11746c7ac09ebd19f0b17782eae80aadee26237658a6b4808afb5c11a22" +checksum = "1e2161dd6eba090ff1594084e95fd67aeccf04382ffea77999ea94ed42ec67b6" dependencies = [ "curl-sys", "libc", "openssl-probe", "openssl-sys", "schannel", - "socket2", - "winapi", + "socket2 0.5.6", + "windows-sys 0.52.0", ] [[package]] name = "curl-sys" -version = "0.4.65+curl-8.2.1" +version = "0.4.72+curl-8.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "961ba061c9ef2fe34bbd12b807152d96f0badd2bebe7b90ce6c8c8b7572a0986" +checksum = "29cbdc8314c447d11e8fd156dcdd031d9e02a7a976163e396b548c03153bc9ea" dependencies = [ "cc", "libc", @@ -752,7 +752,7 @@ dependencies = [ "openssl-sys", "pkg-config", "vcpkg", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -2540,6 +2540,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "socket2" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "spinning_top" version = "0.2.5" @@ -2840,7 +2850,7 @@ dependencies = [ [[package]] name = "tide-disco" -version = "0.4.7" +version = "0.4.8" dependencies = [ "anyhow", "ark-serialize", @@ -3326,8 +3336,8 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "versioned-binary-serialization" -version = "0.1.0" -source = "git+https://github.com/EspressoSystems/versioned-binary-serialization.git?tag=0.1.0#6b5bf0c0b74f8384c940880d5a988d5ec757de3e" +version = "0.1.2" +source = "git+https://github.com/EspressoSystems/versioned-binary-serialization.git?tag=0.1.2#6874f91a3c8d64acc24fe0abe4ad93c35b75eb9d" dependencies = [ "anyhow", "bincode", diff --git a/Cargo.toml b/Cargo.toml index ac6d9853..7baf58c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tide-disco" -version = "0.4.7" +version = "0.4.8" edition = "2021" authors = ["Espresso Systems "] description = "Discoverability for Tide" @@ -54,7 +54,7 @@ tracing-futures = "0.2" tracing-log = "0.2.0" tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] } url = "2.5.0" -versioned-binary-serialization = { git = "https://github.com/EspressoSystems/versioned-binary-serialization.git", tag = "0.1.0" } +versioned-binary-serialization = { git = "https://github.com/EspressoSystems/versioned-binary-serialization.git", tag = "0.1.2" } [target.'cfg(not(windows))'.dependencies] signal-hook-async-std = "0.2.2" diff --git a/examples/hello-world/main.rs b/examples/hello-world/main.rs index 097fd1b1..85657861 100644 --- a/examples/hello-world/main.rs +++ b/examples/hello-world/main.rs @@ -11,6 +11,10 @@ use snafu::Snafu; use std::io; use tide_disco::{Api, App, Error, RequestError, StatusCode}; use tracing::info; +use versioned_binary_serialization::version::StaticVersion; + +type StaticVer01 = StaticVersion<0, 1>; +const STATIC_VER: StaticVer01 = StaticVersion {}; #[derive(Clone, Debug, Deserialize, Serialize, Snafu)] enum HelloError { @@ -39,11 +43,11 @@ impl From for HelloError { } async fn serve(port: u16) -> io::Result<()> { - let mut app = App::<_, HelloError, 0, 1>::with_state(RwLock::new("Hello".to_string())); + let mut app = App::<_, HelloError, StaticVer01>::with_state(RwLock::new("Hello".to_string())); app.with_version(env!("CARGO_PKG_VERSION").parse().unwrap()); let mut api = - Api::, HelloError, 0, 1>::from_file("examples/hello-world/api.toml") + Api::, HelloError, StaticVer01>::from_file("examples/hello-world/api.toml") .unwrap(); api.with_version(env!("CARGO_PKG_VERSION").parse().unwrap()); @@ -75,7 +79,7 @@ async fn serve(port: u16) -> io::Result<()> { .unwrap(); app.register_module("hello", api).unwrap(); - app.serve(format!("0.0.0.0:{}", port)).await + app.serve(format!("0.0.0.0:{}", port), STATIC_VER).await } #[async_std::main] diff --git a/src/api.rs b/src/api.rs index ead4587a..5a317888 100644 --- a/src/api.rs +++ b/src/api.rs @@ -29,6 +29,7 @@ use std::fs; use std::ops::Index; use std::path::{Path, PathBuf}; use tide::http::content::Accept; +use versioned_binary_serialization::version::StaticVersionType; /// An error encountered when parsing or constructing an [Api]. #[derive(Clone, Debug, Snafu)] @@ -252,10 +253,10 @@ mod meta_defaults { /// TOML file and registered as a module of an [App](crate::App). #[derive(Derivative)] #[derivative(Debug(bound = ""))] -pub struct Api { +pub struct Api { meta: Arc, name: String, - routes: HashMap>, + routes: HashMap>, routes_by_path: HashMap>, #[derivative(Debug = "ignore")] health_check: Option>, @@ -265,34 +266,28 @@ pub struct Api { long_description: String, } -impl<'a, State, Error, const MAJOR: u16, const MINOR: u16> IntoIterator - for &'a Api -{ - type Item = &'a Route; - type IntoIter = Values<'a, String, Route>; +impl<'a, State, Error, VER: StaticVersionType> IntoIterator for &'a Api { + type Item = &'a Route; + type IntoIter = Values<'a, String, Route>; fn into_iter(self) -> Self::IntoIter { self.routes.values() } } -impl IntoIterator - for Api -{ - type Item = Route; - type IntoIter = IntoValues>; +impl IntoIterator for Api { + type Item = Route; + type IntoIter = IntoValues>; fn into_iter(self) -> Self::IntoIter { self.routes.into_values() } } -impl Index<&str> - for Api -{ - type Output = Route; +impl Index<&str> for Api { + type Output = Route; - fn index(&self, index: &str) -> &Route { + fn index(&self, index: &str) -> &Route { &self.routes[index] } } @@ -302,22 +297,20 @@ impl Index<&str> /// This type iterates over all of the routes that have a given path. /// [routes_by_path](Api::routes_by_path), in turn, returns an iterator over paths whose items /// contain a [RoutesWithPath] iterator. -pub struct RoutesWithPath<'a, State, Error, const MAJOR: u16, const MINOR: u16> { +pub struct RoutesWithPath<'a, State, Error, VER: StaticVersionType> { routes: std::slice::Iter<'a, String>, - api: &'a Api, + api: &'a Api, } -impl<'a, State, Error, const MAJOR: u16, const MINOR: u16> Iterator - for RoutesWithPath<'a, State, Error, MAJOR, MINOR> -{ - type Item = &'a Route; +impl<'a, State, Error, VER: StaticVersionType> Iterator for RoutesWithPath<'a, State, Error, VER> { + type Item = &'a Route; fn next(&mut self) -> Option { Some(&self.api.routes[self.routes.next()?]) } } -impl Api { +impl Api { /// Parse an API from a TOML specification. pub fn new(api: impl Into) -> Result { let mut api = api.into(); @@ -423,7 +416,7 @@ impl Api impl Iterator)> { + ) -> impl Iterator)> { self.routes_by_path.iter().map(|(path, routes)| { ( path.as_str(), @@ -451,9 +444,9 @@ impl Api) { + /// # use versioned_binary_serialization::version::StaticVersion; + /// # type StaticVer01 = StaticVersion<0, 1>; + /// # fn ex(api: &mut tide_disco::Api<(), (), StaticVer01>) { /// api.with_version(env!("CARGO_PKG_VERSION").parse().unwrap()); /// # } /// ``` @@ -489,12 +482,12 @@ impl Api; /// - /// # fn ex(api: &mut Api) { + /// # fn ex(api: &mut Api) { /// api.at("getstate", |req, state| async { Ok(*state) }.boxed()); /// # } /// ``` @@ -515,12 +508,12 @@ impl Api; - /// const MAJOR: u16 = 0; - /// const MINOR: u16 = 1; + /// type StaticVer01 = StaticVersion<0, 1>; /// - /// # fn ex(api: &mut Api) { + /// # fn ex(api: &mut Api) { /// api.at("increment", |req, state| async { /// let mut guard = state.lock().await; /// *guard += 1; @@ -565,6 +558,7 @@ impl Api BoxFuture<'_, Result>, T: Serialize, State: 'static + Send + Sync, + VER: 'static + Send + Sync, { let route = self.routes.get_mut(name).ok_or(ApiError::UndefinedRoute)?; if route.has_handler() { @@ -601,6 +595,7 @@ impl Api::State) -> BoxFuture<'_, Result>, T: Serialize, State: 'static + Send + Sync + ReadState, + VER: 'static + Send + Sync + StaticVersionType, { assert!(method.is_http() && !method.is_mutable()); let route = self.routes.get_mut(name).ok_or(ApiError::UndefinedRoute)?; @@ -648,12 +643,12 @@ impl Api; - /// const MAJOR: u16 = 0; - /// const MINOR: u16 = 1; + /// type StaticVer01 = StaticVersion<0, 1>; /// - /// # fn ex(api: &mut Api) { + /// # fn ex(api: &mut Api) { /// api.get("getstate", |req, state| async { Ok(*state) }.boxed()); /// # } /// ``` @@ -681,6 +676,7 @@ impl Api::State) -> BoxFuture<'_, Result>, T: Serialize, State: 'static + Send + Sync + ReadState, + VER: 'static + Send + Sync, { self.method_immutable(Method::get(), name, handler) } @@ -698,6 +694,7 @@ impl Api::State) -> BoxFuture<'_, Result>, T: Serialize, State: 'static + Send + Sync + WriteState, + VER: 'static + Send + Sync, { assert!(method.is_http() && method.is_mutable()); let route = self.routes.get_mut(name).ok_or(ApiError::UndefinedRoute)?; @@ -747,12 +744,12 @@ impl Api; - /// const MAJOR: u16 = 0; - /// const MINOR: u16 = 1; + /// type StaticVer01 = StaticVersion<0, 1>; /// - /// # fn ex(api: &mut Api) { + /// # fn ex(api: &mut Api) { /// api.post("increment", |req, state| async { /// *state += 1; /// Ok(*state) @@ -783,6 +780,7 @@ impl Api::State) -> BoxFuture<'_, Result>, T: Serialize, State: 'static + Send + Sync + WriteState, + VER: 'static + Send + Sync, { self.method_mutable(Method::post(), name, handler) } @@ -816,12 +814,12 @@ impl Api; - /// const MAJOR: u16 = 0; - /// const MINOR: u16 = 1; + /// type StaticVer01 = StaticVersion<0, 1>; /// - /// # fn ex(api: &mut Api) { + /// # fn ex(api: &mut Api) { /// api.post("replace", |req, state| async move { /// *state = req.integer_param("new_state")?; /// Ok(()) @@ -852,6 +850,7 @@ impl Api::State) -> BoxFuture<'_, Result>, T: Serialize, State: 'static + Send + Sync + WriteState, + VER: 'static + Send + Sync, { self.method_mutable(Method::put(), name, handler) } @@ -884,12 +883,12 @@ impl Api>; - /// const MAJOR: u16 = 0; - /// const MINOR: u16 = 1; + /// type StaticVer01 = StaticVersion<0, 1>; /// - /// # fn ex(api: &mut Api) { + /// # fn ex(api: &mut Api) { /// api.delete("state", |req, state| async { /// *state = None; /// Ok(()) @@ -920,6 +919,7 @@ impl Api::State) -> BoxFuture<'_, Result>, T: Serialize, State: 'static + Send + Sync + WriteState, + VER: 'static + Send + Sync, { self.method_mutable(Method::delete(), name, handler) } @@ -955,9 +955,10 @@ impl Api) { - /// api.socket("sum", |_req, mut conn: Connection, _state| async move { + /// # fn ex(api: &mut Api<(), ServerError, StaticVersion<0, 1>>) { + /// api.socket("sum", |_req, mut conn: Connection>, _state| async move { /// let mut sum = 0; /// while let Some(amount) = conn.next().await { /// sum += amount?; @@ -994,7 +995,7 @@ impl Api, + socket::Connection, &State, ) -> BoxFuture<'_, Result<(), Error>>, ToClient: 'static + Serialize + ?Sized, @@ -1021,11 +1022,9 @@ impl Api(handler), - ) + self.register_socket_handler(name, socket::stream_handler::<_, _, _, _, VER>(handler)) } fn register_socket_handler( @@ -1086,17 +1085,17 @@ impl Api; /// - /// # fn ex(_api: Api, ServerError, MAJOR, MINOR>) -> Result<(), ApiError> { - /// let mut api: Api, ServerError, MAJOR, MINOR>; + /// # fn ex(_api: Api, ServerError, StaticVer01>) -> Result<(), ApiError> { + /// let mut api: Api, ServerError, StaticVer01>; /// # api = _api; /// api.metrics("metrics", |_req, state| async move { /// state.counter.inc(); @@ -1130,6 +1129,7 @@ impl Api Api(handler)); + self.health_check = Some(route::health_check_handler::<_, _, VER>(handler)); self } @@ -1172,7 +1173,7 @@ impl Api( + route::health_check_response::<_, VER>( &req.accept().unwrap_or_else(|_| { // The healthcheck endpoint is not allowed to fail, so just use the default content // type if we can't parse the Accept header. @@ -1201,11 +1202,12 @@ impl Api( self, f: impl 'static + Clone + Send + Sync + Fn(Error) -> Error2, - ) -> Api + ) -> Api where Error: 'static + Send + Sync, Error2: 'static, State: 'static + Send + Sync, + VER: 'static + Send + Sync, { Api { meta: self.meta, @@ -1268,8 +1270,7 @@ struct ReadHandler { } #[async_trait] -impl Handler - for ReadHandler +impl Handler for ReadHandler where F: 'static + Send @@ -1277,16 +1278,19 @@ where + Fn(RequestParams, &::State) -> BoxFuture<'_, Result>, R: Serialize, State: 'static + Send + Sync + ReadState, + VER: 'static + Send + Sync + StaticVersionType, { async fn handle( &self, req: RequestParams, state: &State, + bind_version: VER, ) -> Result> { let accept = req.accept()?; - response_from_result::<_, _, MAJOR, MINOR>( + response_from_result( &accept, state.read(|state| (self.handler)(req, state)).await, + bind_version, ) } } @@ -1298,8 +1302,7 @@ struct WriteHandler { } #[async_trait] -impl Handler - for WriteHandler +impl Handler for WriteHandler where F: 'static + Send @@ -1307,16 +1310,19 @@ where + Fn(RequestParams, &mut ::State) -> BoxFuture<'_, Result>, R: Serialize, State: 'static + Send + Sync + WriteState, + VER: 'static + Send + Sync + StaticVersionType, { async fn handle( &self, req: RequestParams, state: &State, + bind_version: VER, ) -> Result> { let accept = req.accept()?; - response_from_result::<_, _, MAJOR, MINOR>( + response_from_result( &accept, state.write(|state| (self.handler)(req, state)).await, + bind_version, ) } } @@ -1346,13 +1352,17 @@ mod test { use prometheus::{Counter, Registry}; use std::borrow::Cow; use toml::toml; - use versioned_binary_serialization::{BinarySerializer, Serializer}; + use versioned_binary_serialization::{version::StaticVersion, BinarySerializer, Serializer}; #[cfg(windows)] use async_tungstenite::tungstenite::Error as WsError; #[cfg(windows)] use std::io::ErrorKind; + type StaticVer01 = StaticVersion<0, 1>; + type SerializerV01 = Serializer>; + const VER_0_1: StaticVer01 = StaticVersion {}; + async fn check_stream_closed(mut conn: WebSocketStream) where S: AsyncRead + AsyncWrite + Unpin, @@ -1376,7 +1386,7 @@ mod test { #[async_std::test] async fn test_socket_endpoint() { - let mut app = App::<_, ServerError, 0, 1>::with_state(RwLock::new(())); + let mut app = App::<_, ServerError, StaticVer01>::with_state(RwLock::new(())); let api_toml = toml! { [meta] FORMAT_VERSION = "0.1.0" @@ -1397,7 +1407,7 @@ mod test { let mut api = app.module::("mod", api_toml).unwrap(); api.socket( "echo", - |_req, mut conn: Connection, _state| { + |_req, mut conn: Connection, _state| { async move { while let Some(msg) = conn.next().await { conn.send(&msg?).await?; @@ -1410,7 +1420,7 @@ mod test { .unwrap() .socket( "once", - |_req, mut conn: Connection<_, (), _, 0, 1>, _state| { + |_req, mut conn: Connection<_, (), _, StaticVer01>, _state| { async move { conn.send("msg").boxed().await?; Ok(()) @@ -1421,7 +1431,7 @@ mod test { .unwrap() .socket( "error", - |_req, _conn: Connection<(), (), _, 0, 1>, _state| { + |_req, _conn: Connection<(), (), _, StaticVer01>, _state| { async move { Err(ServerError::catch_all( StatusCode::InternalServerError, @@ -1435,7 +1445,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port))); + spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); wait_for_server(&url, SERVER_STARTUP_RETRIES, SERVER_STARTUP_SLEEP_MS).await; let mut socket_url = url.join("mod/echo").unwrap(); @@ -1459,7 +1469,7 @@ mod test { // Send a binary message. conn.send(Message::Binary( - Serializer::<0, 1>::serialize("goodbye").unwrap(), + SerializerV01::serialize("goodbye").unwrap(), )) .await .unwrap(); @@ -1481,18 +1491,18 @@ mod test { .unwrap(); assert_eq!( conn.next().await.unwrap().unwrap(), - Message::Binary(Serializer::<0, 1>::serialize("hello").unwrap()) + Message::Binary(SerializerV01::serialize("hello").unwrap()) ); // Send a binary message. conn.send(Message::Binary( - Serializer::<0, 1>::serialize("goodbye").unwrap(), + SerializerV01::serialize("goodbye").unwrap(), )) .await .unwrap(); assert_eq!( conn.next().await.unwrap().unwrap(), - Message::Binary(Serializer::<0, 1>::serialize("goodbye").unwrap()) + Message::Binary(SerializerV01::serialize("goodbye").unwrap()) ); // Test a stream that exits normally. @@ -1525,7 +1535,7 @@ mod test { #[async_std::test] async fn test_stream_endpoint() { - let mut app = App::<_, ServerError, 0, 1>::with_state(RwLock::new(())); + let mut app = App::<_, ServerError, StaticVer01>::with_state(RwLock::new(())); let api_toml = toml! { [meta] FORMAT_VERSION = "0.1.0" @@ -1561,7 +1571,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port))); + spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); wait_for_server(&url, SERVER_STARTUP_RETRIES, SERVER_STARTUP_SLEEP_MS).await; // Consume the `nat` stream. @@ -1608,7 +1618,7 @@ mod test { #[async_std::test] async fn test_custom_healthcheck() { - let mut app = App::<_, ServerError, 0, 1>::with_state(HealthStatus::Available); + let mut app = App::<_, ServerError, StaticVer01>::with_state(HealthStatus::Available); let api_toml = toml! { [meta] FORMAT_VERSION = "0.1.0" @@ -1622,7 +1632,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port))); + spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); wait_for_server(&url, SERVER_STARTUP_RETRIES, SERVER_STARTUP_SLEEP_MS).await; let mut res = surf::get(format!("http://localhost:{}/mod/healthcheck", port)) @@ -1652,7 +1662,7 @@ mod test { metrics.register(Box::new(counter.clone())).unwrap(); let state = State { metrics, counter }; - let mut app = App::<_, ServerError, 0, 1>::with_state(RwLock::new(state)); + let mut app = App::<_, ServerError, StaticVer01>::with_state(RwLock::new(state)); let api_toml = toml! { [meta] FORMAT_VERSION = "0.1.0" @@ -1674,7 +1684,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{port}").parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{port}"))); + spawn(app.serve(format!("0.0.0.0:{port}"), VER_0_1)); wait_for_server(&url, SERVER_STARTUP_RETRIES, SERVER_STARTUP_SLEEP_MS).await; for i in 1..5 { diff --git a/src/app.rs b/src/app.rs index 1b97bbc4..3ac33a83 100644 --- a/src/app.rs +++ b/src/app.rs @@ -35,6 +35,7 @@ use tide::{ security::{CorsMiddleware, Origin}, }; use tide_websockets::WebSocket; +use versioned_binary_serialization::version::StaticVersionType; pub use tide::listener::{Listener, ToListener}; @@ -44,9 +45,9 @@ pub use tide::listener::{Listener, ToListener}; /// constructing an [Api] for each module and calling [App::register_module]. Once all of the /// desired modules are registered, the app can be converted into an asynchronous server task using /// [App::serve]. -pub struct App { +pub struct App { // Map from base URL to module API. - apis: HashMap>, + apis: HashMap>, state: Arc, app_version: Option, } @@ -58,8 +59,11 @@ pub enum AppError { ModuleAlreadyExists, } -impl - App +impl< + State: Send + Sync + 'static, + Error: 'static, + VER: Send + Sync + 'static + StaticVersionType, + > App { /// Create a new [App] with a given state. pub fn with_state(state: State) -> Self { @@ -75,7 +79,7 @@ impl, - ) -> Result, AppError> + ) -> Result, AppError> where Error: From, ModuleError: 'static + Send + Sync, @@ -95,7 +99,7 @@ impl( &mut self, base_url: &str, - api: Api, + api: Api, ) -> Result<&mut Self, AppError> where Error: From, @@ -135,9 +139,9 @@ impl) { + /// # use versioned_binary_serialization::version::StaticVersion; + /// # type StaticVer01 = StaticVersion<0, 1>; + /// # fn ex(app: &mut tide_disco::App<(), (), StaticVer01>) { /// app.with_version(env!("CARGO_PKG_VERSION").parse().unwrap()); /// # } /// ``` @@ -216,12 +220,15 @@ lazy_static! { impl< State: Send + Sync + 'static, Error: 'static + crate::Error, - const MAJOR: u16, - const MINOR: u16, - > App + VER: 'static + Send + Sync + StaticVersionType, + > App { /// Serve the [App] asynchronously. - pub async fn serve>>(self, listener: L) -> io::Result<()> { + pub async fn serve>>( + self, + listener: L, + bind_version: VER, + ) -> io::Result<()> { let state = Arc::new(self); let mut server = tide::Server::with_state(state.clone()); for (name, api) in &state.apis { @@ -235,7 +242,7 @@ impl< .at(name) .serve_dir(api.public().unwrap_or_else(|| &DEFAULT_PUBLIC_PATH))?; } - server.with(add_error_body::<_, Error, MAJOR, MINOR>); + server.with(add_error_body::<_, Error, VER>); server.with( CorsMiddleware::new() .allow_methods("GET, POST".parse::().unwrap()) @@ -270,13 +277,24 @@ impl< // all endpoints registered under this pattern, so that a request to this path // with the right headers will return metrics instead of going through the // normal method-based dispatching. - Self::register_metrics(prefix.to_owned(), &mut endpoint, metrics_route); + Self::register_metrics( + prefix.to_owned(), + &mut endpoint, + metrics_route, + bind_version, + ); } // Register the HTTP routes. for route in routes { if let Method::Http(method) = route.method() { - Self::register_route(prefix.to_owned(), &mut endpoint, route, method); + Self::register_route( + prefix.to_owned(), + &mut endpoint, + route, + method, + bind_version, + ); } } } @@ -308,9 +326,9 @@ impl< async move { let api = &req.state().apis[&prefix]; let accept = RequestParams::accept_from_headers(&req)?; - respond_with::<_, _, MAJOR, MINOR>(&accept, api.version()).map_err( - |err| Error::from_route_error::(err).into_tide_error(), - ) + respond_with(&accept, api.version(), bind_version).map_err(|err| { + Error::from_route_error::(err).into_tide_error() + }) } }); } @@ -319,19 +337,19 @@ impl< // Register app-level automatic routes: `healthcheck` and `version`. server .at("healthcheck") - .get(|req: tide::Request>| async move { + .get(move |req: tide::Request>| async move { let state = req.state().clone(); let app_state = &*state.state; let req = request_params(req, &[]).await?; let accept = req.accept()?; let res = state.health(req, app_state).await; - Ok(health_check_response::<_, MAJOR, MINOR>(&accept, res)) + Ok(health_check_response::<_, VER>(&accept, res)) }); server .at("version") - .get(|req: tide::Request>| async move { + .get(move |req: tide::Request>| async move { let accept = RequestParams::accept_from_headers(&req)?; - respond_with::<_, _, MAJOR, MINOR>(&accept, req.state().version()) + respond_with(&accept, req.state().version(), bind_version) .map_err(|err| Error::from_route_error::(err).into_tide_error()) }); @@ -384,8 +402,9 @@ impl< fn register_route( api: String, endpoint: &mut tide::Route>, - route: &Route, + route: &Route, method: http::Method, + bind_version: VER, ) { let name = route.name(); endpoint.method(method, move |req: tide::Request>| { @@ -396,7 +415,7 @@ impl< let state = &*req.state().clone().state; let req = request_params(req, route.params()).await?; route - .handle(req, state) + .handle(req, state, bind_version) .await .map_err(|err| match err { RouteError::AppSpecific(err) => err, @@ -410,14 +429,19 @@ impl< fn register_metrics( api: String, endpoint: &mut tide::Route>, - route: &Route, + route: &Route, + bind_version: VER, ) { let name = route.name(); if route.has_handler() { // If there is a metrics handler, add middleware to the endpoint to intercept the // request and respond with metrics, rather than the usual HTTP dispatching, if the // appropriate headers are set. - endpoint.with(MetricsMiddleware::new(name.clone(), api.clone())); + endpoint.with(MetricsMiddleware::new( + name.clone(), + api.clone(), + bind_version, + )); } // Register a catch-all HTTP handler for the route, which serves the route documentation as @@ -435,7 +459,7 @@ impl< fn register_socket( api: String, endpoint: &mut tide::Route>, - route: &Route, + route: &Route, ) { let name = route.name(); if route.has_handler() { @@ -482,7 +506,7 @@ impl< fn register_fallback( api: String, endpoint: &mut tide::Route>, - route: &Route, + route: &Route, ) { let name = route.name(); endpoint.all(move |req: tide::Request>| { @@ -502,27 +526,28 @@ impl< } } -struct MetricsMiddleware { +struct MetricsMiddleware { route: String, api: String, + ver: VER, } -impl MetricsMiddleware { - fn new(route: String, api: String) -> Self { - Self { route, api } +impl MetricsMiddleware { + fn new(route: String, api: String, ver: VER) -> Self { + Self { route, api, ver } } } -impl - tide::Middleware>> for MetricsMiddleware +impl tide::Middleware>> for MetricsMiddleware where State: Send + Sync + 'static, Error: 'static + crate::Error, + VER: Send + Sync + 'static + StaticVersionType, { fn handle<'a, 'b, 't>( &'a self, - req: tide::Request>>, - next: tide::Next<'b, Arc>>, + req: tide::Request>>, + next: tide::Next<'b, Arc>>, ) -> BoxFuture<'t, tide::Result> where 'a: 't, @@ -531,6 +556,7 @@ where { let route = self.route.clone(); let api = self.api.clone(); + let bind_version = self.ver; async move { if req.method() != http::Method::Get { // Metrics only apply to GET requests. For other requests, proceed with normal @@ -552,7 +578,7 @@ where let state = &*req.state().clone().state; let req = request_params(req, route.params()).await?; route - .handle(req, state) + .handle(req, state, bind_version) .await .map_err(|err| match err { RouteError::AppSpecific(err) => err, @@ -564,8 +590,8 @@ where } } -async fn request_params( - req: tide::Request>>, +async fn request_params( + req: tide::Request>>, params: &[RequestParam], ) -> Result { RequestParams::new(req, params) @@ -619,8 +645,7 @@ pub struct AppVersion { fn add_error_body< T: Clone + Send + Sync + 'static, E: crate::Error, - const MAJOR: u16, - const MINOR: u16, + VER: Send + Sync + 'static + StaticVersionType, >( req: tide::Request, next: tide::Next, @@ -634,7 +659,7 @@ fn add_error_body< // Try to add the error to the response body using a format accepted by the client. If // we cannot do that (for example, if the client requested a format that is incompatible // with a serialized error) just add the error as a string using plaintext. - let (body, content_type) = route::response_body::<_, E, MAJOR, MINOR>(&accept, &error) + let (body, content_type) = route::response_body::<_, E, VER>(&accept, &error) .unwrap_or_else(|_| (error.to_string().into(), mime::PLAIN)); res.set_body(body); res.set_content_type(content_type); @@ -645,50 +670,54 @@ fn add_error_body< }) } -pub struct Module<'a, State, Error, ModuleError, const MAJOR: u16, const MINOR: u16> +pub struct Module<'a, State, Error, ModuleError, VER: StaticVersionType> where State: 'static + Send + Sync, Error: 'static + From, ModuleError: 'static + Send + Sync, + VER: 'static + Send + Sync, { - app: &'a mut App, + app: &'a mut App, base_url: &'a str, // This is only an [Option] so we can [take] out of it during [drop]. - api: Option>, + api: Option>, } -impl<'a, State, Error, ModuleError, const MAJOR: u16, const MINOR: u16> Deref - for Module<'a, State, Error, ModuleError, MAJOR, MINOR> +impl<'a, State, Error, ModuleError, VER: StaticVersionType> Deref + for Module<'a, State, Error, ModuleError, VER> where State: 'static + Send + Sync, Error: 'static + From, ModuleError: 'static + Send + Sync, + VER: 'static + Send + Sync, { - type Target = Api; + type Target = Api; fn deref(&self) -> &Self::Target { self.api.as_ref().unwrap() } } -impl<'a, State, Error, ModuleError, const MAJOR: u16, const MINOR: u16> DerefMut - for Module<'a, State, Error, ModuleError, MAJOR, MINOR> +impl<'a, State, Error, ModuleError, VER: StaticVersionType> DerefMut + for Module<'a, State, Error, ModuleError, VER> where State: 'static + Send + Sync, Error: 'static + From, ModuleError: 'static + Send + Sync, + VER: 'static + Send + Sync, { fn deref_mut(&mut self) -> &mut Self::Target { self.api.as_mut().unwrap() } } -impl<'a, State, Error, ModuleError, const MAJOR: u16, const MINOR: u16> Drop - for Module<'a, State, Error, ModuleError, MAJOR, MINOR> +impl<'a, State, Error, ModuleError, VER: StaticVersionType> Drop + for Module<'a, State, Error, ModuleError, VER> where State: 'static + Send + Sync, Error: 'static + From, ModuleError: 'static + Send + Sync, + VER: 'static + Send + Sync, { fn drop(&mut self) { self.app @@ -710,7 +739,11 @@ mod test { use portpicker::pick_unused_port; use std::borrow::Cow; use toml::toml; - use versioned_binary_serialization::{BinarySerializer, Serializer}; + use versioned_binary_serialization::{version::StaticVersion, BinarySerializer, Serializer}; + + type StaticVer01 = StaticVersion<0, 1>; + type SerializerV01 = Serializer>; + const VER_0_1: StaticVer01 = StaticVersion {}; #[derive(Clone, Copy, Debug)] struct FakeMetrics; @@ -728,7 +761,7 @@ mod test { async fn test_method_dispatch() { use crate::http::Method::*; - let mut app = App::<_, ServerError, 0, 1>::with_state(RwLock::new(FakeMetrics)); + let mut app = App::<_, ServerError, StaticVer01>::with_state(RwLock::new(FakeMetrics)); let api_toml = toml! { [meta] FORMAT_VERSION = "0.1.0" @@ -777,7 +810,7 @@ mod test { .unwrap() .socket( "socket_test", - |_req, mut conn: Connection<_, (), _, 0, 1>, _state| { + |_req, mut conn: Connection<_, (), _, StaticVer01>, _state| { async move { conn.send("SOCKET").await.unwrap(); Ok(()) @@ -793,7 +826,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port))); + spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); wait_for_server(&url, SERVER_STARTUP_RETRIES, SERVER_STARTUP_SLEEP_MS).await; let client: surf::Client = surf::Config::new() @@ -839,7 +872,7 @@ mod test { let msg = conn.next().await.unwrap().unwrap(); let body: String = match msg { Message::Text(m) => serde_json::from_str(&m).unwrap(), - Message::Binary(m) => Serializer::<0, 1>::deserialize(&m).unwrap(), + Message::Binary(m) => SerializerV01::deserialize(&m).unwrap(), m => panic!("expected Text or Binary message, but got {}", m), }; assert_eq!(body, "SOCKET"); @@ -848,7 +881,7 @@ mod test { /// Test route dispatching for routes with patterns containing different parmaeters #[async_std::test] async fn test_param_dispatch() { - let mut app = App::<_, ServerError, 0, 1>::with_state(RwLock::new(())); + let mut app = App::<_, ServerError, StaticVer01>::with_state(RwLock::new(())); let api_toml = toml! { [meta] FORMAT_VERSION = "0.1.0" @@ -874,7 +907,7 @@ mod test { } let port = pick_unused_port().unwrap(); let url: Url = format!("http://localhost:{}", port).parse().unwrap(); - spawn(app.serve(format!("0.0.0.0:{}", port))); + spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1)); wait_for_server(&url, SERVER_STARTUP_RETRIES, SERVER_STARTUP_SLEEP_MS).await; let client: surf::Client = surf::Config::new() diff --git a/src/lib.rs b/src/lib.rs index 01fb5ed6..0709f188 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,16 +47,16 @@ //! # fn main() -> Result<(), tide_disco::api::ApiError> { //! use tide_disco::Api; //! use tide_disco::error::ServerError; +//! use versioned_binary_serialization::version::StaticVersion; //! //! type State = (); //! type Error = ServerError; -//! const MAJOR: u16 = 0; -//! const MINOR: u16 = 1; +//! type StaticVer01 = StaticVersion<0, 1>; //! //! let spec: toml::Value = toml::from_str( //! std::str::from_utf8(&std::fs::read("/path/to/api.toml").unwrap()).unwrap(), //! ).unwrap(); -//! let mut api = Api::::new(spec)?; +//! let mut api = Api::::new(spec)?; //! # Ok(()) //! # } //! ``` @@ -74,11 +74,11 @@ //! //! ```no_run //! # use tide_disco::Api; -//! # const MAJOR: u16 = 0; -//! # const MINOR: u16 = 1; +//! # use versioned_binary_serialization::version::StaticVersion; +//! # type StaticVer01 = StaticVersion<0, 1>; //! # fn main() -> Result<(), tide_disco::api::ApiError> { //! # let spec: toml::Value = toml::from_str(std::str::from_utf8(&std::fs::read("/path/to/api.toml").unwrap()).unwrap()).unwrap(); -//! # let mut api = Api::<(), tide_disco::error::ServerError, MAJOR, MINOR>::new(spec)?; +//! # let mut api = Api::<(), tide_disco::error::ServerError, StaticVer01>::new(spec)?; //! use futures::FutureExt; //! //! api.get("hello", |req, state| async move { Ok("Hello, world!") }.boxed())?; @@ -92,21 +92,22 @@ //! an [App]: //! //! ```no_run +//! # use versioned_binary_serialization::version::StaticVersion; //! # type State = (); //! # type Error = tide_disco::error::ServerError; -//! # const MAJOR: u16 = 0; -//! # const MINOR: u16 = 1; +//! # type StaticVer01 = StaticVersion<0, 1>; //! # #[async_std::main] async fn main() { //! # let spec: toml::Value = toml::from_str(std::str::from_utf8(&std::fs::read("/path/to/api.toml").unwrap()).unwrap()).unwrap(); -//! # let api = tide_disco::Api::::new(spec).unwrap(); +//! # let api = tide_disco::Api::::new(spec).unwrap(); //! use tide_disco::App; +//! use versioned_binary_serialization::version::StaticVersion; //! -//! const MAJOR: u16 = 0; -//! const MINOR: u16 = 1; +//! type StaticVer01 = StaticVersion<0, 1>; +//! const VER_0_1: StaticVer01 = StaticVersion {}; //! -//! let mut app = App::::with_state(()); +//! let mut app = App::::with_state(()); //! app.register_module("api", api); -//! app.serve("http://localhost:8080").await; +//! app.serve("http://localhost:8080", VER_0_1).await; //! # } //! ``` //! @@ -167,7 +168,7 @@ //! implements [Fn], not just static function pointers. Here is what we would _like_ to write: //! //! ```ignore -//! impl Api { +//! impl Api { //! pub fn at(&mut self, route: &str, handler: F) //! where //! F: for<'a> Fn<(RequestParams, &'a State)>, @@ -192,7 +193,7 @@ //! `F`. Here is the actual (partial) signature of [at](Api::at): //! //! ```ignore -//! impl Api { +//! impl Api { //! pub fn at(&mut self, route: &str, handler: F) //! where //! F: for<'a> Fn(RequestParams, &'a State) -> BoxFuture<'a, Result>, @@ -208,14 +209,14 @@ //! use async_std::sync::RwLock; //! use futures::FutureExt; //! use tide_disco::Api; +//! use versioned_binary_serialization::version::StaticVersion; //! //! type State = RwLock; //! type Error = (); //! -//! const MAJOR: u16 = 0; -//! const MINOR: u16 = 1; +//! type StaticVer01 = StaticVersion<0, 1>; //! -//! fn define_routes(api: &mut Api) { +//! fn define_routes(api: &mut Api) { //! api.at("someroute", |_req, state: &State| async { //! Ok(*state.read().await) //! }.boxed()); @@ -230,17 +231,17 @@ //! use async_std::sync::RwLock; //! use futures::FutureExt; //! use tide_disco::{Api, RequestParams}; +//! use versioned_binary_serialization::version::StaticVersion; //! //! type State = RwLock; //! type Error = (); -//! const MAJOR: u16 = 0; -//! const MINOR: u16 = 1; +//! type StaticVer01 = StaticVersion<0, 1>; //! //! async fn handler(_req: RequestParams, state: &State) -> Result { //! Ok(*state.read().await) //! } //! -//! fn register(api: &mut Api) { +//! fn register(api: &mut Api) { //! api.at("someroute", |req, state: &State| handler(req, state).boxed()); //! } //! ``` diff --git a/src/metrics.rs b/src/metrics.rs index 5adcd8ca..a130aed4 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -16,6 +16,7 @@ use derive_more::From; use futures::future::{BoxFuture, FutureExt}; use prometheus::{Encoder, TextEncoder}; use std::{borrow::Cow, error::Error, fmt::Debug}; +use versioned_binary_serialization::version::StaticVersionType; pub trait Metrics { type Error: Debug + Error; @@ -54,18 +55,19 @@ impl Metrics for prometheus::Registry { pub(crate) struct Handler(F); #[async_trait] -impl - route::Handler for Handler +impl route::Handler for Handler where F: 'static + Send + Sync + Fn(RequestParams, &State::State) -> BoxFuture, Error>>, T: 'static + Clone + Metrics, State: 'static + Send + Sync + ReadState, Error: 'static, + VER: 'static + Send + Sync, { async fn handle( &self, req: RequestParams, state: &State, + _: VER, ) -> Result> { let exported = state .read(|state| { diff --git a/src/request.rs b/src/request.rs index 150f26e4..65c2a9b6 100644 --- a/src/request.rs +++ b/src/request.rs @@ -13,7 +13,7 @@ use std::fmt::Display; use strum_macros::EnumString; use tagged_base64::TaggedBase64; use tide::http::{self, content::Accept, mime::Mime, Headers}; -use versioned_binary_serialization::{BinarySerializer, Serializer}; +use versioned_binary_serialization::{version::StaticVersionType, BinarySerializer, Serializer}; #[derive(Clone, Debug, Snafu, Deserialize, Serialize)] pub enum RequestError { @@ -404,7 +404,7 @@ impl RequestParams { /// Deserialize the body of a request. /// /// The Content-Type header is used to determine the serialization format. - pub fn body_auto(&self) -> Result + pub fn body_auto(&self, _: VER) -> Result where T: serde::de::DeserializeOwned, { @@ -413,8 +413,7 @@ impl RequestParams { "application/json" => self.body_json(), "application/octet-stream" => { let bytes = self.body_bytes(); - Serializer::::deserialize(&bytes) - .map_err(|_err| RequestError::Binary {}) + Serializer::::deserialize(&bytes).map_err(|_err| RequestError::Binary {}) } _content_type => Err(RequestError::UnsupportedContentType {}), } diff --git a/src/route.rs b/src/route.rs index 19140c17..40b769b2 100644 --- a/src/route.rs +++ b/src/route.rs @@ -36,7 +36,7 @@ use tide::{ Body, }; use tide_websockets::WebSocketConnection; -use versioned_binary_serialization::{BinarySerializer, Serializer}; +use versioned_binary_serialization::{version::StaticVersionType, BinarySerializer, Serializer}; /// An error returned by a route handler. /// @@ -113,13 +113,14 @@ impl From for RouteError { /// return type of a handler function. The types which are preserved, `State` and `Error`, should be /// the same for all handlers in an API module. #[async_trait] -pub(crate) trait Handler: +pub(crate) trait Handler: 'static + Send + Sync { async fn handle( &self, req: RequestParams, state: &State, + bind_version: VER, ) -> Result>; } @@ -141,69 +142,71 @@ pub(crate) trait Handler: pub(crate) struct FnHandler(F); #[async_trait] -impl Handler - for FnHandler +impl Handler for FnHandler where F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxFuture<'_, Result>, T: Serialize, State: 'static + Send + Sync, + VER: 'static + Send + Sync + StaticVersionType, { async fn handle( &self, req: RequestParams, state: &State, + bind_version: VER, ) -> Result> { let accept = req.accept()?; - response_from_result::<_, _, MAJOR, MINOR>(&accept, (self.0)(req, state).await) + response_from_result(&accept, (self.0)(req, state).await, bind_version) } } -pub(crate) fn response_from_result( +pub(crate) fn response_from_result( accept: &Accept, res: Result, + bind_version: VER, ) -> Result> { res.map_err(RouteError::AppSpecific) - .and_then(|res| respond_with::<_, _, MAJOR, MINOR>(accept, &res)) + .and_then(|res| respond_with(accept, &res, bind_version)) } #[async_trait] impl< - H: ?Sized + Handler, + H: ?Sized + Handler, State: 'static + Send + Sync, Error, - const MAJOR: u16, - const MINOR: u16, - > Handler for Box + VER: 'static + Send + Sync + StaticVersionType, + > Handler for Box { async fn handle( &self, req: RequestParams, state: &State, + bind_version: VER, ) -> Result> { - (**self).handle(req, state).await + (**self).handle(req, state, bind_version).await } } -enum RouteImplementation { +enum RouteImplementation { Http { method: http::Method, - handler: Option>>, + handler: Option>>, }, Socket { handler: Option>, }, Metrics { - handler: Option>>, + handler: Option>>, }, } -impl - RouteImplementation +impl + RouteImplementation { fn map_err( self, f: impl 'static + Send + Sync + Fn(Error) -> Error2, - ) -> RouteImplementation + ) -> RouteImplementation where State: 'static + Send + Sync, Error: 'static + Send + Sync, @@ -213,12 +216,10 @@ impl Self::Http { method, handler } => RouteImplementation::Http { method, handler: handler.map(|h| { - let h: Box> = - Box::new(MapErr::< - Box>, - _, - Error, - >::new(h, f)); + let h: Box> = + Box::new( + MapErr::>, _, Error>::new(h, f), + ); h }), }, @@ -227,12 +228,10 @@ impl }, Self::Metrics { handler } => RouteImplementation::Metrics { handler: handler.map(|h| { - let h: Box> = - Box::new(MapErr::< - Box>, - _, - Error, - >::new(h, f)); + let h: Box> = + Box::new( + MapErr::>, _, Error>::new(h, f), + ); h }), }, @@ -248,14 +247,14 @@ impl /// simply returns information about the route. #[derive(Derivative)] #[derivative(Debug(bound = ""))] -pub struct Route { +pub struct Route { name: String, patterns: Vec, params: Vec, doc: String, meta: Arc, #[derivative(Debug = "ignore")] - handler: RouteImplementation, + handler: RouteImplementation, } #[derive(Clone, Debug, Snafu)] @@ -273,7 +272,7 @@ pub enum RouteParseError { RouteMustBeTable, } -impl Route { +impl Route { /// Parse a [Route] from a TOML specification. /// /// The specification must be a table containing at least the following keys: @@ -396,11 +395,12 @@ impl Route( self, f: impl 'static + Send + Sync + Fn(Error) -> Error2, - ) -> Route + ) -> Route where State: 'static + Send + Sync, Error: 'static + Send + Sync, Error2: 'static, + VER: 'static + Send + Sync, { Route { handler: self.handler.map_err(f), @@ -439,10 +439,10 @@ impl Route Route { +impl Route { pub(crate) fn set_handler( &mut self, - h: impl Handler, + h: impl Handler, ) -> Result<(), RouteError> { match &mut self.handler { RouteImplementation::Http { handler, .. } => { @@ -460,6 +460,7 @@ impl Route BoxFuture<'_, Result>, T: Serialize, State: 'static + Send + Sync, + VER: 'static + Send + Sync, { self.set_handler(FnHandler::from(handler)) } @@ -488,6 +489,7 @@ impl Route { @@ -530,21 +532,22 @@ impl Route Handler - for Route +impl Handler for Route where Error: 'static, State: 'static + Send + Sync, + VER: 'static + Send + Sync, { async fn handle( &self, req: RequestParams, state: &State, + bind_version: VER, ) -> Result> { match &self.handler { RouteImplementation::Http { handler, .. } | RouteImplementation::Metrics { handler, .. } => match handler { - Some(handler) => handler.handle(req, state).await, + Some(handler) => handler.handle(req, state, bind_version).await, None => self.default_handler(), }, RouteImplementation::Socket { .. } => Err(RouteError::IncorrectMethod { @@ -571,22 +574,23 @@ impl MapErr { } #[async_trait] -impl - Handler for MapErr +impl Handler for MapErr where - H: Handler, + H: Handler, F: 'static + Send + Sync + Fn(Error1) -> Error2, State: 'static + Send + Sync, Error1: 'static + Send + Sync, Error2: 'static, + VER: 'static + Send + Sync + StaticVersionType, { async fn handle( &self, req: RequestParams, state: &State, + bind_version: VER, ) -> Result> { self.handler - .handle(req, state) + .handle(req, state, bind_version) .await .map_err(|err| err.map_app_specific(&self.map)) } @@ -601,13 +605,13 @@ where pub(crate) type HealthCheckHandler = Box BoxFuture<'_, tide::Response>>; -pub(crate) fn health_check_response( +pub(crate) fn health_check_response( accept: &Accept, health: H, ) -> tide::Response { let status = health.status(); - let (body, content_type) = response_body::(accept, health) - .unwrap_or_else(|err| { + let (body, content_type) = + response_body::(accept, health).unwrap_or_else(|err| { let msg = format!( "health status was {}, but there was an error generating the response: {}", status, err @@ -624,12 +628,13 @@ pub(crate) fn health_check_response( +pub(crate) fn health_check_handler( handler: impl 'static + Send + Sync + Fn(&State) -> BoxFuture, ) -> HealthCheckHandler where State: 'static + Send + Sync, H: 'static + HealthCheck, + VER: 'static + Send + Sync, { Box::new(move |req, state| { let accept = req.accept().unwrap_or_else(|_| { @@ -642,19 +647,19 @@ where let future = handler(state); async move { let health = future.await; - health_check_response::<_, MAJOR, MINOR>(&accept, health) + health_check_response::<_, VER>(&accept, health) } .boxed() }) } -pub(crate) fn response_body( +pub(crate) fn response_body( accept: &Accept, body: T, ) -> Result<(Body, Mime), RouteError> { let ty = best_response_type(accept, &[mime::JSON, mime::BYTE_STREAM])?; if ty == mime::BYTE_STREAM { - let bytes = Serializer::::serialize(&body).map_err(RouteError::Binary)?; + let bytes = Serializer::::serialize(&body).map_err(RouteError::Binary)?; Ok((bytes.into(), mime::BYTE_STREAM)) } else if ty == mime::JSON { let json = serde_json::to_string(&body).map_err(RouteError::Json)?; @@ -664,11 +669,12 @@ pub(crate) fn response_body } } -pub(crate) fn respond_with( +pub(crate) fn respond_with( accept: &Accept, body: T, + _: VER, ) -> Result> { - let (body, content_type) = response_body::<_, _, MAJOR, MINOR>(accept, body)?; + let (body, content_type) = response_body::<_, _, VER>(accept, body)?; Ok(tide::Response::builder(StatusCode::Ok) .body(body) .content_type(content_type) diff --git a/src/socket.rs b/src/socket.rs index 9c6e1e02..e36540a4 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -29,7 +29,7 @@ use tide_websockets::{ tungstenite::protocol::frame::{coding::CloseCode, CloseFrame}, Message, WebSocketConnection, }; -use versioned_binary_serialization::{BinarySerializer, Serializer}; +use versioned_binary_serialization::{version::StaticVersionType, BinarySerializer, Serializer}; /// An error returned by a socket handler. /// @@ -133,16 +133,17 @@ enum MessageType { /// /// [Connection] implements [Stream], which can be used to receive `FromClient` messages from the /// client, and [Sink] which can be used to send `ToClient` messages to the client. -pub struct Connection { +pub struct Connection { conn: WebSocketConnection, // [Sink] wrapper around `conn` sink: Pin>>>, accept: MessageType, - _phantom: PhantomData ()>, + #[allow(clippy::type_complexity)] + _phantom: PhantomData ()>, } -impl Stream - for Connection +impl Stream + for Connection { type Item = Result>; @@ -154,7 +155,7 @@ impl Poll::Ready(Some(Err(err.into()))), Poll::Ready(Some(Ok(msg))) => Poll::Ready(Some(match msg { Message::Binary(bytes) => { - Serializer::::deserialize(&bytes).map_err(SocketError::from) + Serializer::::deserialize(&bytes).map_err(SocketError::from) } Message::Text(s) => serde_json::from_str(&s).map_err(SocketError::from), _ => Err(SocketError::UnsupportedMessageType), @@ -164,8 +165,8 @@ impl - Sink<&ToClient> for Connection +impl Sink<&ToClient> + for Connection { type Error = SocketError; @@ -175,7 +176,7 @@ impl, item: &ToClient) -> Result<(), Self::Error> { let msg = match self.accept { - MessageType::Binary => Message::Binary(Serializer::::serialize(item)?), + MessageType::Binary => Message::Binary(Serializer::::serialize(item)?), MessageType::Json => Message::Text(serde_json::to_string(item)?), }; self.sink @@ -193,8 +194,8 @@ impl Drop - for Connection +impl Drop + for Connection { fn drop(&mut self) { // This is the idiomatic way to implement [drop] for a type that uses pinning. Since [drop] @@ -211,16 +212,16 @@ impl Dr // `new_unchecked` is okay because we know this value is never used again after being // dropped. inner_drop(unsafe { Pin::new_unchecked(self) }); - fn inner_drop( - _this: Pin<&mut Connection>, + fn inner_drop( + _this: Pin<&mut Connection>, ) { // Any logic goes here. } } } -impl - Connection +impl + Connection { fn new(accept: &Accept, conn: WebSocketConnection) -> Result> { let ty = best_response_type(accept, &[mime::JSON, mime::BYTE_STREAM])?; @@ -275,7 +276,7 @@ pub(crate) type Handler = Box< + Fn(RequestParams, WebSocketConnection, &State) -> BoxFuture>>, >; -pub(crate) fn handler( +pub(crate) fn handler( f: F, ) -> Handler where @@ -284,7 +285,7 @@ where + Sync + Fn( RequestParams, - Connection, + Connection, &State, ) -> BoxFuture>, State: 'static + Send + Sync, @@ -299,13 +300,13 @@ where }) } -struct StreamHandler(F); +struct StreamHandler(F, PhantomData); -impl StreamHandler { +impl StreamHandler { fn handle<'a, State, Error, Msg>( &self, req: RequestParams, - mut conn: Connection, + mut conn: Connection, state: &'a State, ) -> BoxFuture<'a, Result<(), SocketError>> where @@ -313,6 +314,7 @@ impl StreamHandler { State: 'static + Send + Sync, Msg: 'static + Serialize + Send + Sync, Error: 'static + Send, + VER: 'static + Send + Sync, { let mut stream = (self.0)(req, state); async move { @@ -325,35 +327,33 @@ impl StreamHandler { } } -pub(crate) fn stream_handler( - f: F, -) -> Handler +pub(crate) fn stream_handler(f: F) -> Handler where F: 'static + Send + Sync + Fn(RequestParams, &State) -> BoxStream>, State: 'static + Send + Sync, Msg: 'static + Serialize + Send + Sync, Error: 'static + Send + Display, + VER: 'static + Send + Sync + StaticVersionType, { - let handler: StreamHandler = StreamHandler(f); + let handler: StreamHandler = StreamHandler(f, Default::default()); raw_handler(move |req, conn, state| handler.handle(req, conn, state)) } -fn raw_handler( - f: F, -) -> Handler +fn raw_handler(f: F) -> Handler where F: 'static + Send + Sync + Fn( RequestParams, - Connection, + Connection, &State, ) -> BoxFuture>>, State: 'static + Send + Sync, ToClient: 'static + Serialize + ?Sized, FromClient: 'static + DeserializeOwned, Error: 'static + Send + Display, + VER: StaticVersionType, { let close = |conn: WebSocketConnection, res: Result<(), SocketError>| async move { // When the handler finishes, send a close message. If there was an error, include the error diff --git a/src/status.rs b/src/status.rs index 27aeb52a..649758e0 100644 --- a/src/status.rs +++ b/src/status.rs @@ -532,8 +532,9 @@ impl StatusCode { #[cfg(test)] mod test { use super::*; - use versioned_binary_serialization::{BinarySerializer, Serializer}; + use versioned_binary_serialization::{version::StaticVersion, BinarySerializer, Serializer}; + type SerializerV01 = Serializer>; #[test] fn test_status_code() { for code in 0u16.. { @@ -551,8 +552,8 @@ mod test { // Test binary round trip. assert_eq!( status, - Serializer::<0, 1>::deserialize::( - &Serializer::<0, 1>::serialize(&status).unwrap() + SerializerV01::deserialize::( + &SerializerV01::serialize(&status).unwrap() ) .unwrap() );