diff --git a/src/cli/generator/config.rs b/src/cli/generator/config.rs index 08d9fd40da..dab568cb42 100644 --- a/src/cli/generator/config.rs +++ b/src/cli/generator/config.rs @@ -25,8 +25,18 @@ pub struct Config { #[serde(skip_serializing_if = "Option::is_none")] pub preset: Option, pub schema: Schema, - #[serde(skip_serializing_if = "TemplateString::is_empty")] - pub secret: TemplateString, + #[serde(skip_serializing_if = "Option::is_none")] + pub llm: Option, +} + +#[derive(Deserialize, Serialize, Debug, Default, PartialEq, Clone)] +#[serde(rename_all = "camelCase")] +#[serde(deny_unknown_fields)] +pub struct LLMConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub secret: Option, } #[derive(Clone, Deserialize, Serialize, Debug, Default)] @@ -273,13 +283,17 @@ impl Config { .collect::>>>()?; let output = self.output.resolve(parent_dir)?; + let llm = self.llm.map(|llm| { + let secret = llm.secret.map(|s| s.resolve(&reader_context)); + LLMConfig { model: llm.model, secret } + }); Ok(Config { inputs, output, schema: self.schema, preset: self.preset, - secret: self.secret.resolve(&reader_context), + llm, }) } } @@ -419,7 +433,7 @@ mod tests { fn test_raise_error_unknown_field_at_root_level() { let json = r#"{"input": "value"}"#; let expected_error = - "unknown field `input`, expected one of `inputs`, `output`, `preset`, `schema`, `secret` at line 1 column 8"; + "unknown field `input`, expected one of `inputs`, `output`, `preset`, `schema`, `llm` at line 1 column 8"; assert_deserialization_error(json, expected_error); } @@ -492,7 +506,7 @@ mod tests { } #[test] - fn test_secret() { + fn test_llm_config() { let mut env_vars = HashMap::new(); let token = "eyJhbGciOiJIUzI1NiIsInR5"; env_vars.insert("TAILCALL_SECRET".to_owned(), token.to_owned()); @@ -506,12 +520,17 @@ mod tests { headers: Default::default(), }; - let config = - Config::default().secret(TemplateString::parse("{{.env.TAILCALL_SECRET}}").unwrap()); + let config = Config::default().llm(Some(LLMConfig { + model: Some("gpt-3.5-turbo".to_string()), + secret: Some(TemplateString::parse("{{.env.TAILCALL_SECRET}}").unwrap()), + })); let resolved_config = config.into_resolved("", reader_ctx).unwrap(); - let actual = resolved_config.secret; - let expected = TemplateString::from("eyJhbGciOiJIUzI1NiIsInR5"); + let actual = resolved_config.llm; + let expected = Some(LLMConfig { + model: Some("gpt-3.5-turbo".to_string()), + secret: Some(TemplateString::from(token)), + }); assert_eq!(actual, expected); } diff --git a/src/cli/generator/generator.rs b/src/cli/generator/generator.rs index b54ce98224..f34ce02616 100644 --- a/src/cli/generator/generator.rs +++ b/src/cli/generator/generator.rs @@ -6,7 +6,7 @@ use hyper::HeaderMap; use inquire::Confirm; use pathdiff::diff_paths; -use super::config::{Config, Resolved, Source}; +use super::config::{Config, LLMConfig, Resolved, Source}; use super::source::ConfigSource; use crate::cli::llm::InferTypeName; use crate::core::config::transformer::{Preset, RenameTypes}; @@ -164,7 +164,7 @@ impl Generator { let query_type = config.schema.query.clone(); let mutation_type_name = config.schema.mutation.clone(); - let secret = config.secret.clone(); + let llm = config.llm.clone(); let preset = config.preset.clone().unwrap_or_default(); let preset: Preset = preset.validate_into().to_result()?; let input_samples = self.resolve_io(config).await?; @@ -180,19 +180,15 @@ impl Generator { let mut config = config_gen.mutation(mutation_type_name).generate(true)?; if infer_type_names { - let key = if !secret.is_empty() { - Some(secret.to_string()) - } else { - None - }; - - let mut llm_gen = InferTypeName::new(key); - let suggested_names = llm_gen.generate(config.config()).await?; - let cfg = RenameTypes::new(suggested_names.iter()) - .transform(config.config().to_owned()) - .to_result()?; - - config = ConfigModule::from(cfg); + if let Some(LLMConfig { model: Some(model), secret }) = llm { + let mut llm_gen = InferTypeName::new(model, secret.map(|s| s.to_string())); + let suggested_names = llm_gen.generate(config.config()).await?; + let cfg = RenameTypes::new(suggested_names.iter()) + .transform(config.config().to_owned()) + .to_result()?; + + config = ConfigModule::from(cfg); + } } self.write(&config, &path).await?; diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs index 61ae8bb360..cfba78ff77 100644 --- a/src/cli/llm/infer_type_name.rs +++ b/src/cli/llm/infer_type_name.rs @@ -4,14 +4,12 @@ use genai::chat::{ChatMessage, ChatRequest, ChatResponse}; use serde::{Deserialize, Serialize}; use serde_json::json; -use super::model::groq; use super::{Error, Result, Wizard}; use crate::core::config::Config; use crate::core::Mustache; -#[derive(Default)] pub struct InferTypeName { - secret: Option, + wizard: Wizard, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -74,14 +72,11 @@ impl TryInto for Question { } impl InferTypeName { - pub fn new(secret: Option) -> InferTypeName { - Self { secret } + pub fn new(model: String, secret: Option) -> InferTypeName { + Self { wizard: Wizard::new(model, secret) } } - pub async fn generate(&mut self, config: &Config) -> Result> { - let secret = self.secret.as_ref().map(|s| s.to_owned()); - - let wizard: Wizard = Wizard::new(groq::LLAMA38192, secret); + pub async fn generate(&mut self, config: &Config) -> Result> { let mut new_name_mappings: HashMap = HashMap::new(); // removed root type from types. @@ -104,7 +99,7 @@ impl InferTypeName { let mut delay = 3; loop { - let answer = wizard.ask(question.clone()).await; + let answer = self.wizard.ask(question.clone()).await; match answer { Ok(answer) => { let name = &answer.suggestions.join(", "); diff --git a/src/cli/llm/mod.rs b/src/cli/llm/mod.rs index ef63fb9d4a..40c0dce610 100644 --- a/src/cli/llm/mod.rs +++ b/src/cli/llm/mod.rs @@ -3,7 +3,6 @@ pub mod infer_type_name; pub use error::Error; use error::Result; pub use infer_type_name::InferTypeName; -mod model; mod wizard; pub use wizard::Wizard; diff --git a/src/cli/llm/model.rs b/src/cli/llm/model.rs deleted file mode 100644 index a3da95d8eb..0000000000 --- a/src/cli/llm/model.rs +++ /dev/null @@ -1,73 +0,0 @@ -#![allow(unused)] - -use std::borrow::Cow; -use std::fmt::{Display, Formatter}; -use std::marker::PhantomData; - -use derive_setters::Setters; -use genai::adapter::AdapterKind; - -#[derive(Clone)] -pub struct Model(&'static str); - -pub mod open_ai { - use super::*; - pub const GPT3_5_TURBO: Model = Model("gp-3.5-turbo"); - pub const GPT4: Model = Model("gpt-4"); - pub const GPT4_TURBO: Model = Model("gpt-4-turbo"); - pub const GPT4O_MINI: Model = Model("gpt-4o-mini"); - pub const GPT4O: Model = Model("gpt-4o"); -} - -pub mod ollama { - use super::*; - pub const GEMMA2B: Model = Model("gemma:2b"); -} - -pub mod anthropic { - use super::*; - pub const CLAUDE3_HAIKU_20240307: Model = Model("claude-3-haiku-20240307"); - pub const CLAUDE3_SONNET_20240229: Model = Model("claude-3-sonnet-20240229"); - pub const CLAUDE3_OPUS_20240229: Model = Model("claude-3-opus-20240229"); - pub const CLAUDE35_SONNET_20240620: Model = Model("claude-3-5-sonnet-20240620"); -} - -pub mod cohere { - use super::*; - pub const COMMAND_LIGHT_NIGHTLY: Model = Model("command-light-nightly"); - pub const COMMAND_LIGHT: Model = Model("command-light"); - pub const COMMAND_NIGHTLY: Model = Model("command-nightly"); - pub const COMMAND: Model = Model("command"); - pub const COMMAND_R: Model = Model("command-r"); - pub const COMMAND_R_PLUS: Model = Model("command-r-plus"); -} - -pub mod gemini { - use super::*; - pub const GEMINI15_FLASH_LATEST: Model = Model("gemini-1.5-flash-latest"); - pub const GEMINI10_PRO: Model = Model("gemini-1.0-pro"); - pub const GEMINI15_FLASH: Model = Model("gemini-1.5-flash"); - pub const GEMINI15_PRO: Model = Model("gemini-1.5-pro"); -} - -pub mod groq { - use super::*; - pub const LLAMA708192: Model = Model("llama3-70b-8192"); - pub const LLAMA38192: Model = Model("llama3-8b-8192"); - pub const LLAMA_GROQ8B8192_TOOL_USE_PREVIEW: Model = - Model("llama3-groq-8b-8192-tool-use-preview"); - pub const LLAMA_GROQ70B8192_TOOL_USE_PREVIEW: Model = - Model("llama3-groq-70b-8192-tool-use-preview"); - pub const GEMMA29B_IT: Model = Model("gemma2-9b-it"); - pub const GEMMA7B_IT: Model = Model("gemma-7b-it"); - pub const MIXTRAL_8X7B32768: Model = Model("mixtral-8x7b-32768"); - pub const LLAMA8B_INSTANT: Model = Model("llama-3.1-8b-instant"); - pub const LLAMA70B_VERSATILE: Model = Model("llama-3.1-70b-versatile"); - pub const LLAMA405B_REASONING: Model = Model("llama-3.1-405b-reasoning"); -} - -impl Model { - pub fn as_str(&self) -> &'static str { - self.0 - } -} diff --git a/src/cli/llm/wizard.rs b/src/cli/llm/wizard.rs index 1604d7f15f..46d7a18624 100644 --- a/src/cli/llm/wizard.rs +++ b/src/cli/llm/wizard.rs @@ -5,18 +5,17 @@ use genai::resolver::AuthResolver; use genai::Client; use super::Result; -use crate::cli::llm::model::Model; #[derive(Setters, Clone)] pub struct Wizard { client: Client, - model: Model, + model: String, _q: std::marker::PhantomData, _a: std::marker::PhantomData, } impl Wizard { - pub fn new(model: Model, secret: Option) -> Self { + pub fn new(model: String, secret: Option) -> Self { let mut config = genai::adapter::AdapterConfig::default(); if let Some(key) = secret { config = config.with_auth_resolver(AuthResolver::from_key_value(key)); diff --git a/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_deezer.json.snap b/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_deezer.json.snap new file mode 100644 index 0000000000..3f1b48ca31 --- /dev/null +++ b/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_deezer.json.snap @@ -0,0 +1,419 @@ +--- +source: tests/cli/gen.rs +expression: config.to_sdl() +--- +schema @server @upstream(baseURL: "https://api.deezer.com") { + query: Query +} + +type Album { + cover: String + cover_big: String + cover_medium: String + cover_small: String + cover_xl: String + id: Int + md5_image: String + title: String + tracklist: String + type: String +} + +type Artist { + id: Int + link: String + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + radio: Boolean + tracklist: String + type: String +} + +type Chart { + albums: T167 + artists: T169 + playlists: T181 + podcasts: Podcast + tracks: T166 +} + +type Contributor { + id: Int + link: String + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + radio: Boolean + role: String + share: String + tracklist: String + type: String +} + +type Datum { + album: Album + artist: T42 + duration: Int + explicit_content_cover: Int + explicit_content_lyrics: Int + explicit_lyrics: Boolean + id: Int + link: String + md5_image: String + preview: String + rank: Int + readable: Boolean + time_add: Int + title: String + title_short: String + title_version: String + type: String +} + +type Editorial { + data: [T185] + total: Int +} + +type Genre { + data: [T5] +} + +type Playlist { + checksum: String + collaborative: Boolean + creation_date: String + creator: User + description: String + duration: Int + fans: Int + id: Int + is_loved_track: Boolean + link: String + md5_image: String + nb_tracks: Int + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_type: String + picture_xl: String + public: Boolean + share: String + title: String + tracklist: String + tracks: Track + type: String +} + +type Podcast { + data: [T182] + total: Int +} + +type Query { + album(p1: Int!): T39 @http(path: "/album/{{.args.p1}}") + artist(p1: Int!): T40 @http(path: "/artist/{{.args.p1}}") + chart: Chart @http(path: "/chart") + editorial: Editorial @http(path: "/editorial") + playlist(p1: Int!): Playlist @http(path: "/playlist/{{.args.p1}}") + search(q: String): Search @http(path: "/search", query: [{key: "q", value: "{{.args.q}}"}]) + track(p1: Int!): T4 @http(path: "/track/{{.args.p1}}") + user(p1: Int!): T187 @http(path: "/user/{{.args.p1}}") +} + +type Search { + data: [JSON] + next: String + total: Int +} + +type T165 { + album: Album + artist: Artist + duration: Int + explicit_content_cover: Int + explicit_content_lyrics: Int + explicit_lyrics: Boolean + id: Int + link: String + md5_image: String + position: Int + preview: String + rank: Int + title: String + title_short: String + title_version: String + type: String +} + +type T166 { + data: [T165] + total: Int +} + +type T167 { + data: [JSON] + total: Int +} + +type T168 { + id: Int + link: String + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + position: Int + radio: Boolean + tracklist: String + type: String +} + +type T169 { + data: [T168] + total: Int +} + +type T180 { + checksum: String + creation_date: String + id: Int + link: String + md5_image: String + nb_tracks: Int + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_type: String + picture_xl: String + public: Boolean + title: String + tracklist: String + type: String + user: User +} + +type T181 { + data: [T180] + total: Int +} + +type T182 { + available: Boolean + description: String + fans: Int + id: Int + link: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + share: String + title: String + type: String +} + +type T185 { + id: Int + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + type: String +} + +type T187 { + country: String + id: Int + link: String + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + tracklist: String + type: String +} + +type T2 { + id: Int + link: String + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + radio: Boolean + share: String + tracklist: String + type: String +} + +type T3 { + cover: String + cover_big: String + cover_medium: String + cover_small: String + cover_xl: String + id: Int + link: String + md5_image: String + release_date: String + title: String + tracklist: String + type: String +} + +type T37 { + album: Album + artist: User + duration: Int + explicit_content_cover: Int + explicit_content_lyrics: Int + explicit_lyrics: Boolean + id: Int + link: String + md5_image: String + preview: String + rank: Int + readable: Boolean + title: String + title_short: String + title_version: String + type: String +} + +type T38 { + data: [T37] +} + +type T39 { + artist: T8 + available: Boolean + contributors: [Contributor] + cover: String + cover_big: String + cover_medium: String + cover_small: String + cover_xl: String + duration: Int + explicit_content_cover: Int + explicit_content_lyrics: Int + explicit_lyrics: Boolean + fans: Int + genre_id: Int + genres: Genre + id: Int + label: String + link: String + md5_image: String + nb_tracks: Int + record_type: String + release_date: String + share: String + title: String + tracklist: String + tracks: T38 + type: String + upc: String +} + +type T4 { + album: T3 + artist: T2 + available_countries: [String] + bpm: Int + contributors: [Contributor] + disk_number: Int + duration: Int + explicit_content_cover: Int + explicit_content_lyrics: Int + explicit_lyrics: Boolean + gain: Int + id: Int + isrc: String + link: String + md5_image: String + preview: String + rank: Int + readable: Boolean + release_date: String + share: String + title: String + title_short: String + title_version: String + track_position: Int + type: String +} + +type T40 { + id: Int + link: String + name: String + nb_album: Int + nb_fan: Int + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + radio: Boolean + share: String + tracklist: String + type: String +} + +type T42 { + id: Int + link: String + name: String + tracklist: String + type: String +} + +type T5 { + id: Int + name: String + picture: String + type: String +} + +type T8 { + id: Int + name: String + picture: String + picture_big: String + picture_medium: String + picture_small: String + picture_xl: String + tracklist: String + type: String +} + +type Track { + checksum: String + data: [Datum] +} + +type User { + id: Int + name: String + tracklist: String + type: String +} diff --git a/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_jsonplaceholder.json.snap b/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_jsonplaceholder.json.snap new file mode 100644 index 0000000000..47bdfb6296 --- /dev/null +++ b/tests/cli/snapshots/cli_spec__test__generator_spec__tests__cli__fixtures__generator__gen_jsonplaceholder.json.snap @@ -0,0 +1,81 @@ +--- +source: tests/cli/gen.rs +expression: config.to_sdl() +--- +schema @server @upstream(baseURL: "https://jsonplaceholder.typicode.com") { + query: Query +} + +type Address { + city: String + geo: Geo + street: String + suite: String + zipcode: String +} + +type Comment { + body: String + email: String + id: Int + name: String + postId: Int +} + +type Company { + bs: String + catchPhrase: String + name: String +} + +type Geo { + lat: String + lng: String +} + +type Photo { + albumId: Int + id: Int + thumbnailUrl: String + title: String + url: String +} + +type Post { + body: String + id: Int + title: String + userId: Int +} + +type Query { + comment(p1: Int!): Comment @http(path: "/comments/{{.args.p1}}") + comments: [Comment] @http(path: "/comments") + photo(p1: Int!): Photo @http(path: "/photos/{{.args.p1}}") + photos: [Photo] @http(path: "/photos") + post(p1: Int!): Post @http(path: "/posts/{{.args.p1}}") + postComments(postId: Int): [Comment] @http(path: "/comments", query: [{key: "postId", value: "{{.args.postId}}"}]) + posts: [Post] @http(path: "/posts") + todo(p1: Int!): Todo @http(path: "/todos/{{.args.p1}}") + todos: [Todo] @http(path: "/todos") + user(p1: Int!): User @http(path: "/users/{{.args.p1}}") + users: [User] @http(path: "/users") +} + +type Todo { + completed: Boolean + id: Int + title: String + userId: Int +} + +type User { + address: Address + company: Company + email: String + id: Int + name: String + phone: String + username: String + website: String +}