Skip to content

Commit

Permalink
feat(openapi): add discriminator to openapi
Browse files Browse the repository at this point in the history
  • Loading branch information
zcabter committed Nov 18, 2024
1 parent eb1a5a3 commit f750ddf
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 48 deletions.
4 changes: 2 additions & 2 deletions crates/jstz_cli/src/deploy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ pub async fn exec(

debug!("Receipt: {:?}", receipt);

let address = match receipt.inner {
let address = match receipt.result {
ReceiptResult::Success(ReceiptContent::DeployFunction(deploy)) => deploy.address,
ReceiptResult::Success(_) => {
bail!("Expected a `DeployFunction` receipt, but got something else.")
}
ReceiptResult::Failed { source: err } => {
ReceiptResult::Failed(err) => {
bail_user_error!("Failed to deploy smart function with error {err:?}.")
}
};
Expand Down
4 changes: 2 additions & 2 deletions crates/jstz_cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ pub async fn exec(
let receipt = jstz_client.wait_for_operation_receipt(&hash).await?;

debug!("Receipt: {:?}", receipt);
let (status_code, headers, body) = match receipt.inner {
let (status_code, headers, body) = match receipt.result {
ReceiptResult::Success(ReceiptContent::RunFunction(run_function)) => (
run_function.status_code,
run_function.headers,
Expand All @@ -174,7 +174,7 @@ pub async fn exec(
bail!("Expected a `RunFunction` receipt, but got something else.")
}

ReceiptResult::Failed { source: err } => bail_user_error!("{err}"),
ReceiptResult::Failed(err) => bail_user_error!("{err}"),
};

if let Some(spinner) = spinner.as_mut() {
Expand Down
51 changes: 27 additions & 24 deletions crates/jstz_node/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,10 @@
],
"title": "RunFunction"
}
]
],
"discriminator": {
"propertyName": "_type"
}
},
"DeployFunction": {
"type": "object",
Expand Down Expand Up @@ -662,13 +665,13 @@
"type": "object",
"required": [
"hash",
"inner"
"result"
],
"properties": {
"hash": {
"type": "string"
},
"inner": {
"result": {
"$ref": "#/components/schemas/ReceiptResult"
}
}
Expand Down Expand Up @@ -785,37 +788,37 @@
],
"title": "FaWithdraw"
}
]
],
"discriminator": {
"propertyName": "_type"
}
},
"ReceiptResult": {
"oneOf": [
{
"allOf": [
{
"$ref": "#/components/schemas/ReceiptContent"
"type": "object",
"title": "Success",
"required": [
"inner",
"_type"
],
"properties": {
"_type": {
"type": "string",
"enum": [
"Success"
]
},
{
"type": "object",
"required": [
"_type"
],
"properties": {
"_type": {
"type": "string",
"enum": [
"Success"
]
}
}
"inner": {
"$ref": "#/components/schemas/ReceiptContent"
}
],
"title": "Success"
}
},
{
"type": "object",
"title": "Failure",
"required": [
"source",
"inner",
"_type"
],
"properties": {
Expand All @@ -825,7 +828,7 @@
"Failed"
]
},
"source": {
"inner": {
"type": "string"
}
}
Expand Down
231 changes: 230 additions & 1 deletion crates/jstz_node/src/api_doc.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use utoipa::OpenApi;
use utoipa::{
openapi::{schema::ArrayItems, Discriminator, OneOf, RefOr, Schema},
OpenApi,
};

#[derive(OpenApi)]
#[openapi(info(
Expand All @@ -11,3 +14,229 @@ use utoipa::OpenApi;
contact(name = "Trilitech", email = "[email protected]"),
))]
pub struct ApiDoc;

/// Modify OpenAPI doc after its been generated
pub fn modify(openapi: &mut utoipa::openapi::OpenApi) {
if let Some(components) = &mut openapi.components {
let schemas = &mut components.schemas;
for (_, schema) in schemas.iter_mut() {
modify_with_discrminator(schema);
}
}
}

/// Adds discriminator property to `oneOf` Schemas by checking if all of its variants
/// contain the `_type` discriminator property that was generate by serde + utoipa.
/// Recursively applies the modification to all nodes in the Schema tree
fn modify_with_discrminator(schema: &mut RefOr<Schema>) {
match schema {
RefOr::T(Schema::AllOf(all_of)) => {
for all_of_schema in all_of.items.iter_mut() {
modify_with_discrminator(all_of_schema)
}
}
RefOr::T(Schema::AnyOf(any_of)) => {
for item in any_of.items.iter_mut() {
modify_with_discrminator(item)
}
}
RefOr::T(Schema::OneOf(one_of)) => {
if is_sum_type(one_of) {
add_discriminator(one_of)
}
}
RefOr::T(Schema::Array(array)) => match &mut array.items {
ArrayItems::RefOrSchema(ref_or_schema) => {
modify_with_discrminator(ref_or_schema)
}
ArrayItems::False => (),
},
RefOr::T(Schema::Object(obj)) => {
for (_, property_schema) in obj.properties.iter_mut() {
modify_with_discrminator(property_schema)
}
}
RefOr::T(_) => (),
RefOr::Ref(_) => (),
}
}

/// Checks that all items in `one_of` are allOfs where at least
/// one member of the allOf set is an object with a single property
/// named "_type"
fn is_sum_type(one_of: &OneOf) -> bool {
one_of.items.iter().all(|item| match item {
RefOr::T(Schema::AllOf(all_of)) => {
all_of.items.iter().any(|all_of_item| match all_of_item {
RefOr::T(Schema::Object(obj)) => {
if obj.properties.len() != 1 {
return false;
}
obj.properties.contains_key("_type")
}
_ => false,
})
}
_ => false,
})
}

fn add_discriminator(one_of: &mut OneOf) {
if one_of.discriminator.is_none() {
let discrimator = Discriminator::new("_type");
one_of.discriminator = Some(discrimator)
}
}

#[cfg(test)]
mod test {

use jstz_crypto::public_key_hash::PublicKeyHash;
use jstz_proto::{operation::Content, receipt::ReceiptContent};
use utoipa::{
openapi::{
schema::{ArrayItems, SchemaType},
AllOfBuilder, ArrayBuilder, ComponentsBuilder, ObjectBuilder, OpenApi,
OpenApiBuilder, RefOr, Schema, Type,
},
schema, PartialSchema,
};

use super::modify;

fn unsafe_get_schema(
open_api: &OpenApi,
schema_name: impl Into<String>,
) -> RefOr<Schema> {
open_api
.components
.clone()
.unwrap()
.schemas
.get(schema_name.into().as_str())
.unwrap()
.clone()
}

fn check_contains_discriminator(schema: RefOr<Schema>) {
assert!(matches!(schema, RefOr::T(Schema::OneOf(one_of))
if one_of.discriminator.clone().unwrap().property_name == "_type"))
}

fn check_discriminator(
open_api: &OpenApi,
schema_name: impl Into<String>,
discriminator_should_exist: bool,
) {
let schema = unsafe_get_schema(open_api, schema_name);
if discriminator_should_exist {
check_contains_discriminator(schema)
} else if let RefOr::T(Schema::OneOf(one_of)) = schema {
assert!(one_of.discriminator.is_none())
}
}

#[test]
fn modify_discriminator_one_of() {
let mut open_api = OpenApiBuilder::new()
.components(Some(
ComponentsBuilder::new()
.schema_from::<ReceiptContent>()
.schema_from::<Content>()
.build(),
))
.build();

modify(&mut open_api);
check_discriminator(&open_api, "ReceiptContent", true);
check_discriminator(&open_api, "Content", true);
}

#[test]
fn modify_discriminator_one_of_non_discriminant_type() {
let mut open_api = OpenApiBuilder::new()
.components(Some(
ComponentsBuilder::new()
.schema_from::<PublicKeyHash>()
.build(),
))
.build();
modify(&mut open_api);
check_discriminator(&open_api, "PublicKeyHash", false);
}

#[test]
fn modify_discriminator_on_all_of() {
let mut open_api = OpenApiBuilder::new()
.components(Some(
ComponentsBuilder::new()
.schema(
"Test",
AllOfBuilder::new().item(
ObjectBuilder::new()
.schema_type(SchemaType::Type(Type::String)),
),
)
.build(),
))
.build();
modify(&mut open_api);
check_discriminator(&open_api, "Test", false);
}

#[test]
fn modify_discriminator_on_array() {
let mut open_api = OpenApiBuilder::new()
.components(Some(
ComponentsBuilder::new()
.schema("Test", ArrayBuilder::new().items(ReceiptContent::schema()))
.build(),
))
.build();
modify(&mut open_api);

check_discriminator(&open_api, "Test", false);

let array = unsafe_get_schema(&open_api, "Test");
match array {
RefOr::T(Schema::Array(arr)) => {
if let ArrayItems::RefOrSchema(schema) = arr.items {
// Checks that inner inline schema definitions
// are modified
check_contains_discriminator(*schema)
} else {
panic!("Expected schema")
}
}
_ => panic!("Expected array"),
}
}

#[test]
fn modify_discrminator_on_object() {
let mut open_api = OpenApiBuilder::new()
.components(Some(
ComponentsBuilder::new()
.schema(
"Test",
ObjectBuilder::new()
.property("test", schema!(Vec<String>))
.property("content", Content::schema())
.build(),
)
.build(),
))
.build();
modify(&mut open_api);
check_discriminator(&open_api, "Test", false);

let object = unsafe_get_schema(&open_api, "Test");
match object {
RefOr::T(Schema::Object(obj)) => {
let content = obj.properties.get("content").unwrap();
check_contains_discriminator(content.clone())
}
_ => panic!("Expected object"),
}
}
}
11 changes: 6 additions & 5 deletions crates/jstz_node/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use api_doc::ApiDoc;
use api_doc::{modify, ApiDoc};
use axum::{http, routing::get};
use config::JstzNodeConfig;
use octez::OctezRollupClient;
Expand Down Expand Up @@ -65,8 +65,8 @@ pub async fn run(
.allow_origin(Any)
.allow_headers(Any);

let (router, openapi) = router().with_state(state).layer(cors).split_for_parts();

let (router, mut openapi) = router().with_state(state).layer(cors).split_for_parts();
modify(&mut openapi);
let router = router.merge(Scalar::with_url("/scalar", openapi));

let listener = TcpListener::bind(format!("{}:{}", addr, port)).await?;
Expand All @@ -86,6 +86,7 @@ fn router() -> OpenApiRouter<AppState> {
}

pub fn openapi_json_raw() -> anyhow::Result<String> {
let doc = router().split_for_parts().1.to_pretty_json()?;
Ok(doc)
let mut doc = router().split_for_parts().1;
modify(&mut doc);
Ok(doc.to_pretty_json()?)
}
Loading

0 comments on commit f750ddf

Please sign in to comment.