Skip to content

Commit

Permalink
Merge pull request #207 from EspressoSystems/feat/app-dynamic-ver
Browse files Browse the repository at this point in the history
Runtime binding for app-level format version
  • Loading branch information
jbearer authored Apr 2, 2024
2 parents abd79fc + 8170c7d commit a18b5a6
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 80 deletions.
7 changes: 4 additions & 3 deletions examples/hello-world/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use snafu::Snafu;
use std::io;
use tide_disco::{Api, App, Error, RequestError, StatusCode};
use tracing::info;
use vbs::version::StaticVersion;
use vbs::version::{StaticVersion, StaticVersionType};

type StaticVer01 = StaticVersion<0, 1>;

Expand Down Expand Up @@ -42,7 +42,7 @@ impl From<RequestError> for HelloError {
}

async fn serve(port: u16) -> io::Result<()> {
let mut app = App::<_, HelloError, StaticVer01>::with_state(RwLock::new("Hello".to_string()));
let mut app = App::<_, HelloError>::with_state(RwLock::new("Hello".to_string()));
app.with_version(env!("CARGO_PKG_VERSION").parse().unwrap());

let mut api =
Expand Down Expand Up @@ -78,7 +78,8 @@ async fn serve(port: u16) -> io::Result<()> {
.unwrap();

app.register_module("hello", api).unwrap();
app.serve(format!("0.0.0.0:{}", port)).await
app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance())
.await
}

#[async_std::main]
Expand Down
7 changes: 4 additions & 3 deletions examples/versions/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
use futures::FutureExt;
use std::io;
use tide_disco::{error::ServerError, Api, App};
use vbs::version::StaticVersion;
use vbs::version::{StaticVersion, StaticVersionType};

type StaticVer01 = StaticVersion<0, 1>;

async fn serve(port: u16) -> io::Result<()> {
let mut app = App::<_, ServerError, StaticVer01>::with_state(());
let mut app = App::<_, ServerError>::with_state(());
app.with_version(env!("CARGO_PKG_VERSION").parse().unwrap());

let mut v1 =
Expand All @@ -31,7 +31,8 @@ async fn serve(port: u16) -> io::Result<()> {
.unwrap()
.register_module("api", v2)
.unwrap();
app.serve(format!("0.0.0.0:{}", port)).await
app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance())
.await
}

#[async_std::main]
Expand Down
42 changes: 33 additions & 9 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use snafu::{OptionExt, ResultExt, Snafu};
use std::{
borrow::Cow,
collections::hash_map::{Entry, HashMap, IntoValues, Values},
convert::Infallible,
fmt::Display,
fs,
marker::PhantomData,
Expand Down Expand Up @@ -291,11 +292,23 @@ pub(crate) struct ApiInner<State, Error> {
/// error type. However, it will always be set after `Api::into_inner` is called.
#[derivative(Debug = "ignore")]
error_handler: Option<Arc<dyn ErrorHandler<Error>>>,
/// Response handler encapsulating the serialization format version for version requests
#[derivative(Debug = "ignore")]
version_handler: Arc<dyn VersionHandler>,
public: Option<PathBuf>,
short_description: String,
long_description: String,
}

pub(crate) trait VersionHandler:
Send + Sync + Fn(&Accept, ApiVersion) -> Result<tide::Response, RouteError<Infallible>>
{
}
impl<F> VersionHandler for F where
F: Send + Sync + Fn(&Accept, ApiVersion) -> Result<tide::Response, RouteError<Infallible>>
{
}

impl<'a, State, Error> IntoIterator for &'a ApiInner<State, Error> {
type Item = &'a Route<State, Error>;
type IntoIter = Values<'a, String, Route<State, Error>>;
Expand Down Expand Up @@ -405,6 +418,10 @@ impl<State, Error> ApiInner<State, Error> {
pub(crate) fn error_handler(&self) -> Arc<dyn ErrorHandler<Error>> {
self.error_handler.clone().unwrap()
}

pub(crate) fn version_handler(&self) -> Arc<dyn VersionHandler> {
self.version_handler.clone()
}
}

impl<State, Error, VER> Api<State, Error, VER>
Expand Down Expand Up @@ -496,6 +513,9 @@ where
health_check: Box::new(Self::default_health_check),
api_version: None,
error_handler: None,
version_handler: Arc::new(|accept, version| {
respond_with(accept, version, VER::instance())
}),
public: None,
short_description,
long_description,
Expand Down Expand Up @@ -1290,6 +1310,7 @@ where
health_check: self.inner.health_check,
api_version: self.inner.api_version,
error_handler: None,
version_handler: self.inner.version_handler,
public: self.inner.public,
short_description: self.inner.short_description,
long_description: self.inner.long_description,
Expand Down Expand Up @@ -1436,7 +1457,10 @@ mod test {
use prometheus::{Counter, Registry};
use std::borrow::Cow;
use toml::toml;
use vbs::{version::StaticVersion, BinarySerializer, Serializer};
use vbs::{
version::{StaticVersion, StaticVersionType},
BinarySerializer, Serializer,
};

#[cfg(windows)]
use async_tungstenite::tungstenite::Error as WsError;
Expand Down Expand Up @@ -1471,7 +1495,7 @@ mod test {
async fn test_socket_endpoint() {
setup_test();

let mut app = App::<_, ServerError, StaticVer01>::with_state(RwLock::new(()));
let mut app = App::<_, ServerError>::with_state(RwLock::new(()));
let api_toml = toml! {
[meta]
FORMAT_VERSION = "0.1.0"
Expand Down Expand Up @@ -1532,7 +1556,7 @@ mod test {
}
let port = pick_unused_port().unwrap();
let url: Url = format!("http://localhost:{}", port).parse().unwrap();
spawn(app.serve(format!("0.0.0.0:{}", port)));
spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));

// Create a client that accepts JSON messages.
let mut conn = test_ws_client_with_headers(
Expand Down Expand Up @@ -1616,7 +1640,7 @@ mod test {
async fn test_stream_endpoint() {
setup_test();

let mut app = App::<_, ServerError, StaticVer01>::with_state(RwLock::new(()));
let mut app = App::<_, ServerError>::with_state(RwLock::new(()));
let api_toml = toml! {
[meta]
FORMAT_VERSION = "0.1.0"
Expand Down Expand Up @@ -1654,7 +1678,7 @@ mod test {
}
let port = pick_unused_port().unwrap();
let url: Url = format!("http://localhost:{}", port).parse().unwrap();
spawn(app.serve(format!("0.0.0.0:{}", port)));
spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));

// Consume the `nat` stream.
let mut conn = test_ws_client(url.join("mod/nat").unwrap()).await;
Expand Down Expand Up @@ -1693,7 +1717,7 @@ mod test {
async fn test_custom_healthcheck() {
setup_test();

let mut app = App::<_, ServerError, StaticVer01>::with_state(HealthStatus::Available);
let mut app = App::<_, ServerError>::with_state(HealthStatus::Available);
let api_toml = toml! {
[meta]
FORMAT_VERSION = "0.1.0"
Expand All @@ -1709,7 +1733,7 @@ mod test {
}
let port = pick_unused_port().unwrap();
let url: Url = format!("http://localhost:{}", port).parse().unwrap();
spawn(app.serve(format!("0.0.0.0:{}", port)));
spawn(app.serve(format!("0.0.0.0:{}", port), StaticVer01::instance()));
let client = Client::new(url).await;

let res = client.get("/mod/healthcheck").send().await.unwrap();
Expand Down Expand Up @@ -1738,7 +1762,7 @@ mod test {
metrics.register(Box::new(counter.clone())).unwrap();
let state = State { metrics, counter };

let mut app = App::<_, ServerError, StaticVer01>::with_state(RwLock::new(state));
let mut app = App::<_, ServerError>::with_state(RwLock::new(state));
let api_toml = toml! {
[meta]
FORMAT_VERSION = "0.1.0"
Expand All @@ -1762,7 +1786,7 @@ mod test {
}
let port = pick_unused_port().unwrap();
let url: Url = format!("http://localhost:{port}").parse().unwrap();
spawn(app.serve(format!("0.0.0.0:{port}")));
spawn(app.serve(format!("0.0.0.0:{port}"), StaticVer01::instance()));
let client = Client::new(url).await;

for i in 1..5 {
Expand Down
Loading

0 comments on commit a18b5a6

Please sign in to comment.