From f66fc89d4ecda88a5cb257285109d1fdbaf54e6e Mon Sep 17 00:00:00 2001 From: Sahil Yeole <73148455+beelchester@users.noreply.github.com> Date: Mon, 26 Aug 2024 22:13:34 +0530 Subject: [PATCH 1/2] feat(2690): make llm models configurable (#2716) Co-authored-by: Tushar Mathur Co-authored-by: Sandipsinh Dilipsinh Rathod <62684960+ssddOnTop@users.noreply.github.com> Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Mehul Mathur --- src/cli/generator/config.rs | 37 +- src/cli/generator/generator.rs | 26 +- src/cli/llm/infer_type_name.rs | 15 +- src/cli/llm/mod.rs | 1 - src/cli/llm/model.rs | 73 --- src/cli/llm/wizard.rs | 5 +- ..._fixtures__generator__gen_deezer.json.snap | 419 ++++++++++++++++++ ...__generator__gen_jsonplaceholder.json.snap | 81 ++++ 8 files changed, 546 insertions(+), 111 deletions(-) delete mode 100644 src/cli/llm/model.rs create mode 100644 tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_deezer.json.snap create mode 100644 tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_jsonplaceholder.json.snap diff --git a/src/cli/generator/config.rs b/src/cli/generator/config.rs index 08d9fd40da..dab568cb42 100644 --- a/src/cli/generator/config.rs +++ b/src/cli/generator/config.rs @@ -25,8 +25,18 @@ pub struct Config { #[serde(skip_serializing_if = "Option::is_none")] pub preset: Option, pub schema: Schema, - #[serde(skip_serializing_if = "TemplateString::is_empty")] - pub secret: TemplateString, + #[serde(skip_serializing_if = "Option::is_none")] + pub llm: Option, +} + +#[derive(Deserialize, Serialize, Debug, Default, PartialEq, Clone)] +#[serde(rename_all = "camelCase")] +#[serde(deny_unknown_fields)] +pub struct LLMConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub secret: Option, } #[derive(Clone, Deserialize, Serialize, Debug, Default)] @@ -273,13 +283,17 @@ impl Config { .collect::>>>()?; let output = self.output.resolve(parent_dir)?; + let llm = self.llm.map(|llm| { + let secret = llm.secret.map(|s| s.resolve(&reader_context)); + LLMConfig { model: llm.model, secret } + }); Ok(Config { inputs, output, schema: self.schema, preset: self.preset, - secret: self.secret.resolve(&reader_context), + llm, }) } } @@ -419,7 +433,7 @@ mod tests { fn test_raise_error_unknown_field_at_root_level() { let json = r#"{"input": "value"}"#; let expected_error = - "unknown field `input`, expected one of `inputs`, `output`, `preset`, `schema`, `secret` at line 1 column 8"; + "unknown field `input`, expected one of `inputs`, `output`, `preset`, `schema`, `llm` at line 1 column 8"; assert_deserialization_error(json, expected_error); } @@ -492,7 +506,7 @@ mod tests { } #[test] - fn test_secret() { + fn test_llm_config() { let mut env_vars = HashMap::new(); let token = "eyJhbGciOiJIUzI1NiIsInR5"; env_vars.insert("TAILCALL_SECRET".to_owned(), token.to_owned()); @@ -506,12 +520,17 @@ mod tests { headers: Default::default(), }; - let config = - Config::default().secret(TemplateString::parse("{{.env.TAILCALL_SECRET}}").unwrap()); + let config = Config::default().llm(Some(LLMConfig { + model: Some("gpt-3.5-turbo".to_string()), + secret: Some(TemplateString::parse("{{.env.TAILCALL_SECRET}}").unwrap()), + })); let resolved_config = config.into_resolved("", reader_ctx).unwrap(); - let actual = resolved_config.secret; - let expected = TemplateString::from("eyJhbGciOiJIUzI1NiIsInR5"); + let actual = resolved_config.llm; + let expected = Some(LLMConfig { + model: Some("gpt-3.5-turbo".to_string()), + secret: Some(TemplateString::from(token)), + }); assert_eq!(actual, expected); } diff --git a/src/cli/generator/generator.rs b/src/cli/generator/generator.rs index b54ce98224..f34ce02616 100644 --- a/src/cli/generator/generator.rs +++ b/src/cli/generator/generator.rs @@ -6,7 +6,7 @@ use hyper::HeaderMap; use inquire::Confirm; use pathdiff::diff_paths; -use super::config::{Config, Resolved, Source}; +use super::config::{Config, LLMConfig, Resolved, Source}; use super::source::ConfigSource; use crate::cli::llm::InferTypeName; use crate::core::config::transformer::{Preset, RenameTypes}; @@ -164,7 +164,7 @@ impl Generator { let query_type = config.schema.query.clone(); let mutation_type_name = config.schema.mutation.clone(); - let secret = config.secret.clone(); + let llm = config.llm.clone(); let preset = config.preset.clone().unwrap_or_default(); let preset: Preset = preset.validate_into().to_result()?; let input_samples = self.resolve_io(config).await?; @@ -180,19 +180,15 @@ impl Generator { let mut config = config_gen.mutation(mutation_type_name).generate(true)?; if infer_type_names { - let key = if !secret.is_empty() { - Some(secret.to_string()) - } else { - None - }; - - let mut llm_gen = InferTypeName::new(key); - let suggested_names = llm_gen.generate(config.config()).await?; - let cfg = RenameTypes::new(suggested_names.iter()) - .transform(config.config().to_owned()) - .to_result()?; - - config = ConfigModule::from(cfg); + if let Some(LLMConfig { model: Some(model), secret }) = llm { + let mut llm_gen = InferTypeName::new(model, secret.map(|s| s.to_string())); + let suggested_names = llm_gen.generate(config.config()).await?; + let cfg = RenameTypes::new(suggested_names.iter()) + .transform(config.config().to_owned()) + .to_result()?; + + config = ConfigModule::from(cfg); + } } self.write(&config, &path).await?; diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs index 61ae8bb360..cfba78ff77 100644 --- a/src/cli/llm/infer_type_name.rs +++ b/src/cli/llm/infer_type_name.rs @@ -4,14 +4,12 @@ use genai::chat::{ChatMessage, ChatRequest, ChatResponse}; use serde::{Deserialize, Serialize}; use serde_json::json; -use super::model::groq; use super::{Error, Result, Wizard}; use crate::core::config::Config; use crate::core::Mustache; -#[derive(Default)] pub struct InferTypeName { - secret: Option, + wizard: Wizard, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -74,14 +72,11 @@ impl TryInto for Question { } impl InferTypeName { - pub fn new(secret: Option) -> InferTypeName { - Self { secret } + pub fn new(model: String, secret: Option) -> InferTypeName { + Self { wizard: Wizard::new(model, secret) } } - pub async fn generate(&mut self, config: &Config) -> Result> { - let secret = self.secret.as_ref().map(|s| s.to_owned()); - - let wizard: Wizard = Wizard::new(groq::LLAMA38192, secret); + pub async fn generate(&mut self, config: &Config) -> Result> { let mut new_name_mappings: HashMap = HashMap::new(); // removed root type from types. @@ -104,7 +99,7 @@ impl InferTypeName { let mut delay = 3; loop { - let answer = wizard.ask(question.clone()).await; + let answer = self.wizard.ask(question.clone()).await; match answer { Ok(answer) => { let name = &answer.suggestions.join(", "); diff --git a/src/cli/llm/mod.rs b/src/cli/llm/mod.rs index ef63fb9d4a..40c0dce610 100644 --- a/src/cli/llm/mod.rs +++ b/src/cli/llm/mod.rs @@ -3,7 +3,6 @@ pub mod infer_type_name; pub use error::Error; use error::Result; pub use infer_type_name::InferTypeName; -mod model; mod wizard; pub use wizard::Wizard; diff --git a/src/cli/llm/model.rs b/src/cli/llm/model.rs deleted file mode 100644 index a3da95d8eb..0000000000 --- a/src/cli/llm/model.rs +++ /dev/null @@ -1,73 +0,0 @@ -#![allow(unused)] - -use std::borrow::Cow; -use std::fmt::{Display, Formatter}; -use std::marker::PhantomData; - -use derive_setters::Setters; -use genai::adapter::AdapterKind; - -#[derive(Clone)] -pub struct Model(&'static str); - -pub mod open_ai { - use super::*; - pub const GPT3_5_TURBO: Model = Model("gp-3.5-turbo"); - pub const GPT4: Model = Model("gpt-4"); - pub const GPT4_TURBO: Model = Model("gpt-4-turbo"); - pub const GPT4O_MINI: Model = Model("gpt-4o-mini"); - pub const GPT4O: Model = Model("gpt-4o"); -} - -pub mod ollama { - use super::*; - pub const GEMMA2B: Model = Model("gemma:2b"); -} - -pub mod anthropic { - use super::*; - pub const CLAUDE3_HAIKU_20240307: Model = Model("claude-3-haiku-20240307"); - pub const CLAUDE3_SONNET_20240229: Model = Model("claude-3-sonnet-20240229"); - pub const CLAUDE3_OPUS_20240229: Model = Model("claude-3-opus-20240229"); - pub const CLAUDE35_SONNET_20240620: Model = Model("claude-3-5-sonnet-20240620"); -} - -pub mod cohere { - use super::*; - pub const COMMAND_LIGHT_NIGHTLY: Model = Model("command-light-nightly"); - pub const COMMAND_LIGHT: Model = Model("command-light"); - pub const COMMAND_NIGHTLY: Model = Model("command-nightly"); - pub const COMMAND: Model = Model("command"); - pub const COMMAND_R: Model = Model("command-r"); - pub const COMMAND_R_PLUS: Model = Model("command-r-plus"); -} - -pub mod gemini { - use super::*; - pub const GEMINI15_FLASH_LATEST: Model = Model("gemini-1.5-flash-latest"); - pub const GEMINI10_PRO: Model = Model("gemini-1.0-pro"); - pub const GEMINI15_FLASH: Model = Model("gemini-1.5-flash"); - pub const GEMINI15_PRO: Model = Model("gemini-1.5-pro"); -} - -pub mod groq { - use super::*; - pub const LLAMA708192: Model = Model("llama3-70b-8192"); - pub const LLAMA38192: Model = Model("llama3-8b-8192"); - pub const LLAMA_GROQ8B8192_TOOL_USE_PREVIEW: Model = - Model("llama3-groq-8b-8192-tool-use-preview"); - pub const LLAMA_GROQ70B8192_TOOL_USE_PREVIEW: Model = - Model("llama3-groq-70b-8192-tool-use-preview"); - pub const GEMMA29B_IT: Model = Model("gemma2-9b-it"); - pub const GEMMA7B_IT: Model = Model("gemma-7b-it"); - pub const MIXTRAL_8X7B32768: Model = Model("mixtral-8x7b-32768"); - pub const LLAMA8B_INSTANT: Model = Model("llama-3.1-8b-instant"); - pub const LLAMA70B_VERSATILE: Model = Model("llama-3.1-70b-versatile"); - pub const LLAMA405B_REASONING: Model = Model("llama-3.1-405b-reasoning"); -} - -impl Model { - pub fn as_str(&self) -> &'static str { - self.0 - } -} diff --git a/src/cli/llm/wizard.rs b/src/cli/llm/wizard.rs index 1604d7f15f..46d7a18624 100644 --- a/src/cli/llm/wizard.rs +++ b/src/cli/llm/wizard.rs @@ -5,18 +5,17 @@ use genai::resolver::AuthResolver; use genai::Client; use super::Result; -use crate::cli::llm::model::Model; #[derive(Setters, Clone)] pub struct Wizard { client: Client, - model: Model, + model: String, _q: std::marker::PhantomData, _a: std::marker::PhantomData, } impl Wizard { - pub fn new(model: Model, secret: Option) -> Self { + pub fn new(model: String, secret: Option) -> Self { let mut config = genai::adapter::AdapterConfig::default(); if let Some(key) = secret { config = config.with_auth_resolver(AuthResolver::from_key_value(key)); diff --git a/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_deezer.json.snap b/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_deezer.json.snap new file mode 100644 index 0000000000..3f1b48ca31 --- /dev/null +++ b/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_deezer.json.snap @@ -0,0 +1,419 @@ +--- +source: tests/cli/gen.rs +expression: config.to_sdl() +--- +schema @server @upstream(baseURL: "https://api.deezer.com") { + query: Query +} + +type Album { + cover: String + cover_big: String + cover_medium: String + cover_small: String + cover_xl: String + id: Int + md5_image: String + title: String + tracklist: String + type: String +} + +type Artist { + id: Int + link: String + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + radio: Boolean + tracklist: String + type: String +} + +type Chart { + albums: T167 + artists: T169 + playlists: T181 + podcasts: Podcast + tracks: T166 +} + +type Contributor { + id: Int + link: String + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + radio: Boolean + role: String + share: String + tracklist: String + type: String +} + +type Datum { + album: Album + artist: T42 + duration: Int + explicit_content_cover: Int + explicit_content_lyrics: Int + explicit_lyrics: Boolean + id: Int + link: String + md5_image: String + preview: String + rank: Int + readable: Boolean + time_add: Int + title: String + title_short: String + title_version: String + type: String +} + +type Editorial { + data: [T185] + total: Int +} + +type Genre { + data: [T5] +} + +type Playlist { + checksum: String + collaborative: Boolean + creation_date: String + creator: User + description: String + duration: Int + fans: Int + id: Int + is_loved_track: Boolean + link: String + md5_image: String + nb_tracks: Int + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_type: String + picture_xl: String + public: Boolean + share: String + title: String + tracklist: String + tracks: Track + type: String +} + +type Podcast { + data: [T182] + total: Int +} + +type Query { + album(p1: Int!): T39 @http(path: "/album/{{.args.p1}}") + artist(p1: Int!): T40 @http(path: "/artist/{{.args.p1}}") + chart: Chart @http(path: "/chart") + editorial: Editorial @http(path: "/editorial") + playlist(p1: Int!): Playlist @http(path: "/playlist/{{.args.p1}}") + search(q: String): Search @http(path: "/search", query: [{key: "q", value: "{{.args.q}}"}]) + track(p1: Int!): T4 @http(path: "/track/{{.args.p1}}") + user(p1: Int!): T187 @http(path: "/user/{{.args.p1}}") +} + +type Search { + data: [JSON] + next: String + total: Int +} + +type T165 { + album: Album + artist: Artist + duration: Int + explicit_content_cover: Int + explicit_content_lyrics: Int + explicit_lyrics: Boolean + id: Int + link: String + md5_image: String + position: Int + preview: String + rank: Int + title: String + title_short: String + title_version: String + type: String +} + +type T166 { + data: [T165] + total: Int +} + +type T167 { + data: [JSON] + total: Int +} + +type T168 { + id: Int + link: String + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + position: Int + radio: Boolean + tracklist: String + type: String +} + +type T169 { + data: [T168] + total: Int +} + +type T180 { + checksum: String + creation_date: String + id: Int + link: String + md5_image: String + nb_tracks: Int + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_type: String + picture_xl: String + public: Boolean + title: String + tracklist: String + type: String + user: User +} + +type T181 { + data: [T180] + total: Int +} + +type T182 { + available: Boolean + description: String + fans: Int + id: Int + link: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + share: String + title: String + type: String +} + +type T185 { + id: Int + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + type: String +} + +type T187 { + country: String + id: Int + link: String + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + tracklist: String + type: String +} + +type T2 { + id: Int + link: String + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + radio: Boolean + share: String + tracklist: String + type: String +} + +type T3 { + cover: String + cover_big: String + cover_medium: String + cover_small: String + cover_xl: String + id: Int + link: String + md5_image: String + release_date: String + title: String + tracklist: String + type: String +} + +type T37 { + album: Album + artist: User + duration: Int + explicit_content_cover: Int + explicit_content_lyrics: Int + explicit_lyrics: Boolean + id: Int + link: String + md5_image: String + preview: String + rank: Int + readable: Boolean + title: String + title_short: String + title_version: String + type: String +} + +type T38 { + data: [T37] +} + +type T39 { + artist: T8 + available: Boolean + contributors: [Contributor] + cover: String + cover_big: String + cover_medium: String + cover_small: String + cover_xl: String + duration: Int + explicit_content_cover: Int + explicit_content_lyrics: Int + explicit_lyrics: Boolean + fans: Int + genre_id: Int + genres: Genre + id: Int + label: String + link: String + md5_image: String + nb_tracks: Int + record_type: String + release_date: String + share: String + title: String + tracklist: String + tracks: T38 + type: String + upc: String +} + +type T4 { + album: T3 + artist: T2 + available_countries: [String] + bpm: Int + contributors: [Contributor] + disk_number: Int + duration: Int + explicit_content_cover: Int + explicit_content_lyrics: Int + explicit_lyrics: Boolean + gain: Int + id: Int + isrc: String + link: String + md5_image: String + preview: String + rank: Int + readable: Boolean + release_date: String + share: String + title: String + title_short: String + title_version: String + track_position: Int + type: String +} + +type T40 { + id: Int + link: String + name: String + nb_album: Int + nb_fan: Int + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + radio: Boolean + share: String + tracklist: String + type: String +} + +type T42 { + id: Int + link: String + name: String + tracklist: String + type: String +} + +type T5 { + id: Int + name: String + picture: String + type: String +} + +type T8 { + id: Int + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + tracklist: String + type: String +} + +type Track { + checksum: String + data: [Datum] +} + +type User { + id: Int + name: String + tracklist: String + type: String +} diff --git a/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_jsonplaceholder.json.snap b/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_jsonplaceholder.json.snap new file mode 100644 index 0000000000..47bdfb6296 --- /dev/null +++ b/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_jsonplaceholder.json.snap @@ -0,0 +1,81 @@ +--- +source: tests/cli/gen.rs +expression: config.to_sdl() +--- +schema @server @upstream(baseURL: "https://jsonplaceholder.typicode.com") { + query: Query +} + +type Address { + city: String + geo: Geo + street: String + suite: String + zipcode: String +} + +type Comment { + body: String + email: String + id: Int + name: String + postId: Int +} + +type Company { + bs: String + catchPhrase: String + name: String +} + +type Geo { + lat: String + lng: String +} + +type Photo { + albumId: Int + id: Int + thumbnailUrl: String + title: String + url: String +} + +type Post { + body: String + id: Int + title: String + userId: Int +} + +type Query { + comment(p1: Int!): Comment @http(path: "/comments/{{.args.p1}}") + comments: [Comment] @http(path: "/comments") + photo(p1: Int!): Photo @http(path: "/photos/{{.args.p1}}") + photos: [Photo] @http(path: "/photos") + post(p1: Int!): Post @http(path: "/posts/{{.args.p1}}") + postComments(postId: Int): [Comment] @http(path: "/comments", query: [{key: "postId", value: "{{.args.postId}}"}]) + posts: [Post] @http(path: "/posts") + todo(p1: Int!): Todo @http(path: "/todos/{{.args.p1}}") + todos: [Todo] @http(path: "/todos") + user(p1: Int!): User @http(path: "/users/{{.args.p1}}") + users: [User] @http(path: "/users") +} + +type Todo { + completed: Boolean + id: Int + title: String + userId: Int +} + +type User { + address: Address + company: Company + email: String + id: Int + name: String + phone: String + username: String + website: String +} From 452648e36cc9931ca6cb25c056547b6a8e6e44da Mon Sep 17 00:00:00 2001 From: Mehul Mathur Date: Mon, 26 Aug 2024 23:52:08 +0530 Subject: [PATCH 2/2] refactor: move cli error to core (#2708) Co-authored-by: Tushar Mathur --- src/cli/mod.rs | 2 - src/cli/runtime/file.rs | 7 +- src/cli/server/http_1.rs | 6 +- src/cli/server/http_2.rs | 4 +- src/cli/server/http_server.rs | 4 +- src/cli/tc/check.rs | 4 +- src/cli/telemetry.rs | 6 +- src/{cli/error.rs => core/errata.rs} | 113 ++++++++++++++------------- src/core/grpc/request.rs | 2 +- src/core/http/response.rs | 6 +- src/core/ir/error.rs | 69 +++++++++++----- src/core/ir/eval.rs | 7 +- src/core/ir/eval_http.rs | 2 +- src/core/mod.rs | 2 + src/main.rs | 6 +- 15 files changed, 131 insertions(+), 109 deletions(-) rename src/{cli/error.rs => core/errata.rs} (78%) diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 59798bc6b5..eebb64373a 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -1,5 +1,4 @@ pub mod command; -mod error; mod fmt; pub mod generator; #[cfg(feature = "js")] @@ -11,5 +10,4 @@ pub mod server; mod tc; pub mod telemetry; pub(crate) mod update_checker; -pub use error::CLIError; pub use tc::run::run; diff --git a/src/cli/runtime/file.rs b/src/cli/runtime/file.rs index 3d9be7ed77..72a55160c6 100644 --- a/src/cli/runtime/file.rs +++ b/src/cli/runtime/file.rs @@ -1,7 +1,6 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use crate::cli::CLIError; -use crate::core::FileIO; +use crate::core::{Errata, FileIO}; #[derive(Clone)] pub struct NativeFileIO {} @@ -29,7 +28,7 @@ async fn write<'a>(path: &'a str, content: &'a [u8]) -> anyhow::Result<()> { impl FileIO for NativeFileIO { async fn write<'a>(&'a self, path: &'a str, content: &'a [u8]) -> anyhow::Result<()> { write(path, content).await.map_err(|err| { - CLIError::new(format!("Failed to write file: {}", path).as_str()) + Errata::new(format!("Failed to write file: {}", path).as_str()) .description(err.to_string()) })?; tracing::info!("File write: {} ... ok", path); @@ -38,7 +37,7 @@ impl FileIO for NativeFileIO { async fn read<'a>(&'a self, path: &'a str) -> anyhow::Result { let content = read(path).await.map_err(|err| { - CLIError::new(format!("Failed to read file: {}", path).as_str()) + Errata::new(format!("Failed to read file: {}", path).as_str()) .description(err.to_string()) })?; tracing::info!("File read: {} ... ok", path); diff --git a/src/cli/server/http_1.rs b/src/cli/server/http_1.rs index 22d7d97c96..76360e860a 100644 --- a/src/cli/server/http_1.rs +++ b/src/cli/server/http_1.rs @@ -4,9 +4,9 @@ use hyper::service::{make_service_fn, service_fn}; use tokio::sync::oneshot; use super::server_config::ServerConfig; -use crate::cli::CLIError; use crate::core::async_graphql_hyper::{GraphQLBatchRequest, GraphQLRequest}; use crate::core::http::handle_request; +use crate::core::Errata; pub async fn start_http_1( sc: Arc, @@ -31,7 +31,7 @@ pub async fn start_http_1( } }); let builder = hyper::Server::try_bind(&addr) - .map_err(CLIError::from)? + .map_err(Errata::from)? .http1_pipeline_flush(sc.app_ctx.blueprint.server.pipeline_flush); super::log_launch(sc.as_ref()); @@ -48,7 +48,7 @@ pub async fn start_http_1( builder.serve(make_svc_single_req).await }; - let result = server.map_err(CLIError::from); + let result = server.map_err(Errata::from); Ok(result?) } diff --git a/src/cli/server/http_2.rs b/src/cli/server/http_2.rs index 30ee21b5d1..1895789603 100644 --- a/src/cli/server/http_2.rs +++ b/src/cli/server/http_2.rs @@ -9,9 +9,9 @@ use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use tokio::sync::oneshot; use super::server_config::ServerConfig; -use crate::cli::CLIError; use crate::core::async_graphql_hyper::{GraphQLBatchRequest, GraphQLRequest}; use crate::core::http::handle_request; +use crate::core::Errata; pub async fn start_http_2( sc: Arc, @@ -60,7 +60,7 @@ pub async fn start_http_2( builder.serve(make_svc_single_req).await }; - let result = server.map_err(CLIError::from); + let result = server.map_err(Errata::from); Ok(result?) } diff --git a/src/cli/server/http_server.rs b/src/cli/server/http_server.rs index 62928c492f..3661c9f5f7 100644 --- a/src/cli/server/http_server.rs +++ b/src/cli/server/http_server.rs @@ -8,9 +8,9 @@ use super::http_1::start_http_1; use super::http_2::start_http_2; use super::server_config::ServerConfig; use crate::cli::telemetry::init_opentelemetry; -use crate::cli::CLIError; use crate::core::blueprint::{Blueprint, Http}; use crate::core::config::ConfigModule; +use crate::core::Errata; pub struct Server { config_module: ConfigModule, @@ -32,7 +32,7 @@ impl Server { /// Starts the server in the current Runtime pub async fn start(self) -> Result<()> { - let blueprint = Blueprint::try_from(&self.config_module).map_err(CLIError::from)?; + let blueprint = Blueprint::try_from(&self.config_module).map_err(Errata::from)?; let endpoints = self.config_module.extensions().endpoint_set.clone(); let server_config = Arc::new(ServerConfig::new(blueprint.clone(), endpoints).await?); diff --git a/src/cli/tc/check.rs b/src/cli/tc/check.rs index 9e41cb7a9d..6816836092 100644 --- a/src/cli/tc/check.rs +++ b/src/cli/tc/check.rs @@ -2,11 +2,11 @@ use anyhow::Result; use super::helpers::{display_schema, log_endpoint_set}; use crate::cli::fmt::Fmt; -use crate::cli::CLIError; use crate::core::blueprint::Blueprint; use crate::core::config::reader::ConfigReader; use crate::core::config::Source; use crate::core::runtime::TargetRuntime; +use crate::core::Errata; pub(super) struct CheckParams { pub(super) file_paths: Vec, @@ -24,7 +24,7 @@ pub(super) async fn check_command(params: CheckParams, config_reader: &ConfigRea if let Some(format) = format { Fmt::display(format.encode(&config_module)?); } - let blueprint = Blueprint::try_from(&config_module).map_err(CLIError::from); + let blueprint = Blueprint::try_from(&config_module).map_err(Errata::from); match blueprint { Ok(blueprint) => { diff --git a/src/cli/telemetry.rs b/src/cli/telemetry.rs index 46a64cba6c..50ea364d41 100644 --- a/src/cli/telemetry.rs +++ b/src/cli/telemetry.rs @@ -24,12 +24,12 @@ use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::{Layer, Registry}; use super::metrics::init_metrics; -use crate::cli::CLIError; use crate::core::blueprint::telemetry::{OtlpExporter, Telemetry, TelemetryExporter}; use crate::core::runtime::TargetRuntime; use crate::core::tracing::{ default_tracing, default_tracing_tailcall, get_log_level, tailcall_filter_target, }; +use crate::core::Errata; static RESOURCE: Lazy = Lazy::new(|| { Resource::default().merge(&Resource::new(vec![ @@ -206,8 +206,8 @@ pub fn init_opentelemetry(config: Telemetry, runtime: &TargetRuntime) -> anyhow: | global::Error::Log(LogError::Other(_)), ) { tracing::subscriber::with_default(default_tracing_tailcall(), || { - let cli = crate::cli::CLIError::new("Open Telemetry Error") - .caused_by(vec![CLIError::new(error.to_string().as_str())]) + let cli = crate::core::Errata::new("Open Telemetry Error") + .caused_by(vec![Errata::new(error.to_string().as_str())]) .trace(vec!["schema".to_string(), "@telemetry".to_string()]); tracing::error!("{}", cli.color(true)); }); diff --git a/src/cli/error.rs b/src/core/errata.rs similarity index 78% rename from src/cli/error.rs rename to src/core/errata.rs index 50e06e5301..7cc493ef3b 100644 --- a/src/cli/error.rs +++ b/src/core/errata.rs @@ -2,12 +2,15 @@ use std::fmt::{Debug, Display}; use colored::Colorize; use derive_setters::Setters; -use thiserror::Error; +use crate::core::error::Error as CoreError; use crate::core::valid::ValidationError; -#[derive(Debug, Error, Setters, PartialEq, Clone)] -pub struct CLIError { +/// The moral equivalent of a serde_json::Value but for errors. +/// It's a data structure like Value that can hold any error in an untyped +/// manner. +#[derive(Debug, thiserror::Error, Setters, PartialEq, Clone)] +pub struct Errata { is_root: bool, #[setters(skip)] color: bool, @@ -17,12 +20,12 @@ pub struct CLIError { trace: Vec, #[setters(skip)] - caused_by: Vec, + caused_by: Vec, } -impl CLIError { +impl Errata { pub fn new(message: &str) -> Self { - CLIError { + Errata { is_root: true, color: false, message: message.to_string(), @@ -32,7 +35,7 @@ impl CLIError { } } - pub fn caused_by(mut self, error: Vec) -> Self { + pub fn caused_by(mut self, error: Vec) -> Self { self.caused_by = error; for error in self.caused_by.iter_mut() { @@ -82,7 +85,7 @@ fn bullet(str: &str) -> String { chars.into_iter().collect::() } -impl Display for CLIError { +impl Display for Errata { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let default_padding = 2; @@ -132,46 +135,37 @@ impl Display for CLIError { } } -impl From for CLIError { +impl From for Errata { fn from(error: hyper::Error) -> Self { - // TODO: add type-safety to CLIError conversion - let cli_error = CLIError::new("Server Failed"); + // TODO: add type-safety to Errata conversion + let cli_error = Errata::new("Server Failed"); let message = error.to_string(); if message.to_lowercase().contains("os error 48") { cli_error .description("The port is already in use".to_string()) - .caused_by(vec![CLIError::new(message.as_str())]) + .caused_by(vec![Errata::new(message.as_str())]) } else { cli_error.description(message) } } } -impl From for CLIError { - fn from(error: rustls::Error) -> Self { - let cli_error = CLIError::new("Failed to create TLS Acceptor"); - let message = error.to_string(); - - cli_error.description(message) - } -} - -impl From for CLIError { +impl From for Errata { fn from(error: anyhow::Error) -> Self { - // Convert other errors to CLIError - let cli_error = match error.downcast::() { + // Convert other errors to Errata + let cli_error = match error.downcast::() { Ok(cli_error) => cli_error, Err(error) => { - // Convert other errors to CLIError + // Convert other errors to Errata let cli_error = match error.downcast::>() { - Ok(validation_error) => CLIError::from(validation_error), + Ok(validation_error) => Errata::from(validation_error), Err(error) => { let sources = error .source() - .map(|error| vec![CLIError::new(error.to_string().as_str())]) + .map(|error| vec![Errata::new(error.to_string().as_str())]) .unwrap_or_default(); - CLIError::new(&error.to_string()).caused_by(sources) + Errata::new(&error.to_string()).caused_by(sources) } }; cli_error @@ -181,24 +175,32 @@ impl From for CLIError { } } -impl From for CLIError { +impl From for Errata { fn from(error: std::io::Error) -> Self { - let cli_error = CLIError::new("IO Error"); + let cli_error = Errata::new("IO Error"); let message = error.to_string(); cli_error.description(message) } } -impl<'a> From> for CLIError { +impl From for Errata { + fn from(error: CoreError) -> Self { + let cli_error = Errata::new("Core Error"); + let message = error.to_string(); + + cli_error.description(message) + } +} + +impl<'a> From> for Errata { fn from(error: ValidationError<&'a str>) -> Self { - CLIError::new("Invalid Configuration").caused_by( + Errata::new("Invalid Configuration").caused_by( error .as_vec() .iter() .map(|cause| { - let mut err = - CLIError::new(cause.message).trace(Vec::from(cause.trace.clone())); + let mut err = Errata::new(cause.message).trace(Vec::from(cause.trace.clone())); if let Some(description) = cause.description { err = err.description(description.to_owned()); } @@ -209,29 +211,28 @@ impl<'a> From> for CLIError { } } -impl From> for CLIError { +impl From> for Errata { fn from(error: ValidationError) -> Self { - CLIError::new("Invalid Configuration").caused_by( + Errata::new("Invalid Configuration").caused_by( error .as_vec() .iter() .map(|cause| { - CLIError::new(cause.message.as_str()).trace(Vec::from(cause.trace.clone())) + Errata::new(cause.message.as_str()).trace(Vec::from(cause.trace.clone())) }) .collect(), ) } } -impl From> for CLIError { +impl From> for Errata { fn from(value: Box) -> Self { - CLIError::new(value.to_string().as_str()) + Errata::new(value.to_string().as_str()) } } #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; use stripmargin::StripMargin; @@ -275,14 +276,14 @@ mod tests { #[test] fn test_title() { - let error = CLIError::new("Server could not be started"); + let error = Errata::new("Server could not be started"); let expected = r"Server could not be started".strip_margin(); assert_eq!(error.to_string(), expected); } #[test] fn test_title_description() { - let error = CLIError::new("Server could not be started") + let error = Errata::new("Server could not be started") .description("The port is already in use".to_string()); let expected = r"|Server could not be started: The port is already in use".strip_margin(); @@ -291,7 +292,7 @@ mod tests { #[test] fn test_title_description_trace() { - let error = CLIError::new("Server could not be started") + let error = Errata::new("Server could not be started") .description("The port is already in use".to_string()) .trace(vec!["@server".into(), "port".into()]); @@ -304,7 +305,7 @@ mod tests { #[test] fn test_title_trace_caused_by() { - let error = CLIError::new("Configuration Error").caused_by(vec![CLIError::new( + let error = Errata::new("Configuration Error").caused_by(vec![Errata::new( "Base URL needs to be specified", ) .trace(vec![ @@ -324,20 +325,20 @@ mod tests { #[test] fn test_title_trace_multiple_caused_by() { - let error = CLIError::new("Configuration Error").caused_by(vec![ - CLIError::new("Base URL needs to be specified").trace(vec![ + let error = Errata::new("Configuration Error").caused_by(vec![ + Errata::new("Base URL needs to be specified").trace(vec![ "User".into(), "posts".into(), "@http".into(), "baseURL".into(), ]), - CLIError::new("Base URL needs to be specified").trace(vec![ + Errata::new("Base URL needs to be specified").trace(vec![ "Post".into(), "users".into(), "@http".into(), "baseURL".into(), ]), - CLIError::new("Base URL needs to be specified") + Errata::new("Base URL needs to be specified") .description("Set `baseURL` in @http or @server directives".into()) .trace(vec![ "Query".into(), @@ -345,7 +346,7 @@ mod tests { "@http".into(), "baseURL".into(), ]), - CLIError::new("Base URL needs to be specified").trace(vec![ + Errata::new("Base URL needs to be specified").trace(vec![ "Query".into(), "posts".into(), "@http".into(), @@ -370,7 +371,7 @@ mod tests { .description("Set `baseURL` in @http or @server directives") .trace(vec!["Query", "users", "@http", "baseURL"]); let valid = ValidationError::from(cause); - let error = CLIError::from(valid); + let error = Errata::from(valid); let expected = r"|Invalid Configuration |Caused by: | • Base URL needs to be specified: Set `baseURL` in @http or @server directives [at Query.users.@http.baseURL]" @@ -381,12 +382,12 @@ mod tests { #[test] fn test_cli_error_identity() { - let cli_error = CLIError::new("Server could not be started") + let cli_error = Errata::new("Server could not be started") .description("The port is already in use".to_string()) .trace(vec!["@server".into(), "port".into()]); let anyhow_error: anyhow::Error = cli_error.clone().into(); - let actual = CLIError::from(anyhow_error); + let actual = Errata::from(anyhow_error); let expected = cli_error; assert_eq!(actual, expected); @@ -399,8 +400,8 @@ mod tests { ); let anyhow_error: anyhow::Error = validation_error.clone().into(); - let actual = CLIError::from(anyhow_error); - let expected = CLIError::from(validation_error); + let actual = Errata::from(anyhow_error); + let expected = Errata::from(validation_error); assert_eq!(actual, expected); } @@ -409,8 +410,8 @@ mod tests { fn test_generic_error() { let anyhow_error = anyhow::anyhow!("Some error msg"); - let actual: CLIError = CLIError::from(anyhow_error); - let expected = CLIError::new("Some error msg"); + let actual: Errata = Errata::from(anyhow_error); + let expected = Errata::new("Some error msg"); assert_eq!(actual, expected); } diff --git a/src/core/grpc/request.rs b/src/core/grpc/request.rs index c7e28b53e3..7bea4ccd68 100644 --- a/src/core/grpc/request.rs +++ b/src/core/grpc/request.rs @@ -160,7 +160,7 @@ mod tests { if let Err(err) = result { match err.downcast_ref::() { - Some(Error::GRPCError { + Some(Error::GRPC { grpc_code, grpc_description, grpc_status_message, diff --git a/src/core/http/response.rs b/src/core/http/response.rs index 2bb28e2b93..710ab0ac1a 100644 --- a/src/core/http/response.rs +++ b/src/core/http/response.rs @@ -102,9 +102,7 @@ impl Response { pub fn to_grpc_error(&self, operation: &ProtobufOperation) -> anyhow::Error { let grpc_status = match Status::from_header_map(&self.headers) { Some(status) => status, - None => { - return Error::IOException("Error while parsing upstream headers".to_owned()).into() - } + None => return Error::IO("Error while parsing upstream headers".to_owned()).into(), }; let mut obj: IndexMap = IndexMap::new(); @@ -136,7 +134,7 @@ impl Response { } obj.insert(Name::new("details"), ConstValue::List(status_details)); - let error = Error::GRPCError { + let error = Error::GRPC { grpc_code: grpc_status.code() as i32, grpc_description: grpc_status.code().description().to_owned(), grpc_status_message: grpc_status.message().to_owned(), diff --git a/src/core/ir/error.rs b/src/core/ir/error.rs index e08523bb71..2bcd513f83 100644 --- a/src/core/ir/error.rs +++ b/src/core/ir/error.rs @@ -1,48 +1,75 @@ +use std::fmt::Display; use std::sync::Arc; use async_graphql::{ErrorExtensions, Value as ConstValue}; use derive_more::From; use thiserror::Error; -use crate::core::{auth, cache, worker}; +use crate::core::{auth, cache, worker, Errata}; #[derive(From, Debug, Error, Clone)] pub enum Error { - #[error("IOException: {0}")] - IOException(String), + IO(String), - #[error("gRPC Error: status: {grpc_code}, description: `{grpc_description}`, message: `{grpc_status_message}`")] - GRPCError { + GRPC { grpc_code: i32, grpc_description: String, grpc_status_message: String, grpc_status_details: ConstValue, }, - #[error("APIValidationError: {0:?}")] - APIValidationError(Vec), + APIValidation(Vec), - #[error("ExprEvalError: {0}")] #[from(ignore)] - ExprEvalError(String), + ExprEval(String), - #[error("DeserializeError: {0}")] #[from(ignore)] - DeserializeError(String), + Deserialize(String), - #[error("Authentication Failure: {0}")] - AuthError(auth::error::Error), + Auth(auth::error::Error), - #[error("Worker Error: {0}")] - WorkerError(worker::Error), + Worker(worker::Error), - #[error("Cache Error: {0}")] - CacheError(cache::Error), + Cache(cache::Error), +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Errata::from(self.to_owned()).fmt(f) + } +} + +impl From for Errata { + fn from(value: Error) -> Self { + match value { + Error::IO(message) => Errata::new("IOException").description(message), + Error::GRPC { + grpc_code, + grpc_description, + grpc_status_message, + grpc_status_details: _, + } => Errata::new("gRPC Error") + .description(format!("status: {grpc_code}, description: `{grpc_description}`, message: `{grpc_status_message}`")), + Error::APIValidation(errors) => Errata::new("API Validation Error") + .caused_by(errors.iter().map(|e| Errata::new(e)).collect::>()), + Error::Deserialize(message) => { + Errata::new("Deserialization Error").description(message) + } + Error::ExprEval(message) => { + Errata::new("Expression Evaluation Error").description(message) + } + Error::Auth(err) => { + Errata::new("Authentication Failure").description(err.to_string()) + } + Error::Worker(err) => Errata::new("Worker Error").description(err.to_string()), + Error::Cache(err) => Errata::new("Cache Error").description(err.to_string()), + } + } } impl ErrorExtensions for Error { fn extend(&self) -> async_graphql::Error { async_graphql::Error::new(format!("{}", self)).extend_with(|_err, e| { - if let Error::GRPCError { + if let Error::GRPC { grpc_code, grpc_description, grpc_status_message, @@ -60,7 +87,7 @@ impl ErrorExtensions for Error { impl<'a> From> for Error { fn from(value: crate::core::valid::ValidationError<&'a str>) -> Self { - Error::APIValidationError( + Error::APIValidation( value .as_vec() .iter() @@ -74,7 +101,7 @@ impl From> for Error { fn from(error: Arc) -> Self { match error.downcast_ref::() { Some(err) => err.clone(), - None => Error::IOException(error.to_string()), + None => Error::IO(error.to_string()), } } } @@ -86,7 +113,7 @@ impl From for Error { fn from(value: anyhow::Error) -> Self { match value.downcast::() { Ok(err) => err, - Err(err) => Error::IOException(err.to_string()), + Err(err) => Error::IO(err.to_string()), } } } diff --git a/src/core/ir/eval.rs b/src/core/ir/eval.rs index 6bbf8871bd..c0d8c2356a 100644 --- a/src/core/ir/eval.rs +++ b/src/core/ir/eval.rs @@ -72,13 +72,10 @@ impl IR { if let Some(value) = map.get(&key) { Ok(ConstValue::String(value.to_owned())) } else { - Err(Error::ExprEvalError(format!( - "Can't find mapped key: {}.", - key - ))) + Err(Error::ExprEval(format!("Can't find mapped key: {}.", key))) } } else { - Err(Error::ExprEvalError( + Err(Error::ExprEval( "Mapped key must be string value.".to_owned(), )) } diff --git a/src/core/ir/eval_http.rs b/src/core/ir/eval_http.rs index bc99ef72a9..446eb9009a 100644 --- a/src/core/ir/eval_http.rs +++ b/src/core/ir/eval_http.rs @@ -236,7 +236,7 @@ pub fn parse_graphql_response( field_name: &str, ) -> Result { let res: async_graphql::Response = - from_value(res.body).map_err(|err| Error::DeserializeError(err.to_string()))?; + from_value(res.body).map_err(|err| Error::Deserialize(err.to_string()))?; for error in res.errors { ctx.add_error(error); diff --git a/src/core/mod.rs b/src/core/mod.rs index 5ca24f55e9..a885e5ab4f 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -12,6 +12,7 @@ pub mod data_loader; pub mod directive; pub mod document; pub mod endpoint; +mod errata; pub mod error; pub mod generator; pub mod graphql; @@ -47,6 +48,7 @@ use std::hash::Hash; use std::num::NonZeroU64; use async_graphql_value::ConstValue; +pub use errata::Errata; pub use error::{Error, Result}; use http::Response; use ir::model::IoId; diff --git a/src/main.rs b/src/main.rs index 3e912d28ad..b2dc98e001 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,8 +3,8 @@ use std::cell::Cell; -use tailcall::cli::CLIError; use tailcall::core::tracing::default_tracing_tailcall; +use tailcall::core::Errata; use tracing::subscriber::DefaultGuard; thread_local! { @@ -42,8 +42,8 @@ fn main() -> anyhow::Result<()> { match result { Ok(_) => {} Err(error) => { - // Ensure all errors are converted to CLIErrors before being printed. - let cli_error: CLIError = error.into(); + // Ensure all errors are converted to Errata before being printed. + let cli_error: Errata = error.into(); tracing::error!("{}", cli_error.color(true)); std::process::exit(exitcode::CONFIG); }