Skip to content

Commit

Permalink
Merge branch 'main' into fix/jit-union-shared-field
Browse files Browse the repository at this point in the history
  • Loading branch information
meskill authored Aug 27, 2024
2 parents 9014e29 + 452648e commit a9110b4
Show file tree
Hide file tree
Showing 23 changed files with 677 additions and 220 deletions.
37 changes: 28 additions & 9 deletions src/cli/generator/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,18 @@ pub struct Config<Status = UnResolved> {
#[serde(skip_serializing_if = "Option::is_none")]
pub preset: Option<PresetConfig>,
pub schema: Schema,
#[serde(skip_serializing_if = "TemplateString::is_empty")]
pub secret: TemplateString,
#[serde(skip_serializing_if = "Option::is_none")]
pub llm: Option<LLMConfig>,
}

#[derive(Deserialize, Serialize, Debug, Default, PartialEq, Clone)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
pub struct LLMConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub secret: Option<TemplateString>,
}

#[derive(Clone, Deserialize, Serialize, Debug, Default)]
Expand Down Expand Up @@ -273,13 +283,17 @@ impl Config {
.collect::<anyhow::Result<Vec<Input<Resolved>>>>()?;

let output = self.output.resolve(parent_dir)?;
let llm = self.llm.map(|llm| {
let secret = llm.secret.map(|s| s.resolve(&reader_context));
LLMConfig { model: llm.model, secret }
});

Ok(Config {
inputs,
output,
schema: self.schema,
preset: self.preset,
secret: self.secret.resolve(&reader_context),
llm,
})
}
}
Expand Down Expand Up @@ -419,7 +433,7 @@ mod tests {
fn test_raise_error_unknown_field_at_root_level() {
let json = r#"{"input": "value"}"#;
let expected_error =
"unknown field `input`, expected one of `inputs`, `output`, `preset`, `schema`, `secret` at line 1 column 8";
"unknown field `input`, expected one of `inputs`, `output`, `preset`, `schema`, `llm` at line 1 column 8";
assert_deserialization_error(json, expected_error);
}

Expand Down Expand Up @@ -492,7 +506,7 @@ mod tests {
}

#[test]
fn test_secret() {
fn test_llm_config() {
let mut env_vars = HashMap::new();
let token = "eyJhbGciOiJIUzI1NiIsInR5";
env_vars.insert("TAILCALL_SECRET".to_owned(), token.to_owned());
Expand All @@ -506,12 +520,17 @@ mod tests {
headers: Default::default(),
};

let config =
Config::default().secret(TemplateString::parse("{{.env.TAILCALL_SECRET}}").unwrap());
let config = Config::default().llm(Some(LLMConfig {
model: Some("gpt-3.5-turbo".to_string()),
secret: Some(TemplateString::parse("{{.env.TAILCALL_SECRET}}").unwrap()),
}));
let resolved_config = config.into_resolved("", reader_ctx).unwrap();

let actual = resolved_config.secret;
let expected = TemplateString::from("eyJhbGciOiJIUzI1NiIsInR5");
let actual = resolved_config.llm;
let expected = Some(LLMConfig {
model: Some("gpt-3.5-turbo".to_string()),
secret: Some(TemplateString::from(token)),
});

assert_eq!(actual, expected);
}
Expand Down
26 changes: 11 additions & 15 deletions src/cli/generator/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use hyper::HeaderMap;
use inquire::Confirm;
use pathdiff::diff_paths;

use super::config::{Config, Resolved, Source};
use super::config::{Config, LLMConfig, Resolved, Source};
use super::source::ConfigSource;
use crate::cli::llm::InferTypeName;
use crate::core::config::transformer::{Preset, RenameTypes};
Expand Down Expand Up @@ -164,7 +164,7 @@ impl Generator {
let query_type = config.schema.query.clone();
let mutation_type_name = config.schema.mutation.clone();

let secret = config.secret.clone();
let llm = config.llm.clone();
let preset = config.preset.clone().unwrap_or_default();
let preset: Preset = preset.validate_into().to_result()?;
let input_samples = self.resolve_io(config).await?;
Expand All @@ -180,19 +180,15 @@ impl Generator {
let mut config = config_gen.mutation(mutation_type_name).generate(true)?;

if infer_type_names {
let key = if !secret.is_empty() {
Some(secret.to_string())
} else {
None
};

let mut llm_gen = InferTypeName::new(key);
let suggested_names = llm_gen.generate(config.config()).await?;
let cfg = RenameTypes::new(suggested_names.iter())
.transform(config.config().to_owned())
.to_result()?;

config = ConfigModule::from(cfg);
if let Some(LLMConfig { model: Some(model), secret }) = llm {
let mut llm_gen = InferTypeName::new(model, secret.map(|s| s.to_string()));
let suggested_names = llm_gen.generate(config.config()).await?;
let cfg = RenameTypes::new(suggested_names.iter())
.transform(config.config().to_owned())
.to_result()?;

config = ConfigModule::from(cfg);
}
}

self.write(&config, &path).await?;
Expand Down
15 changes: 5 additions & 10 deletions src/cli/llm/infer_type_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@ use genai::chat::{ChatMessage, ChatRequest, ChatResponse};
use serde::{Deserialize, Serialize};
use serde_json::json;

use super::model::groq;
use super::{Error, Result, Wizard};
use crate::core::config::Config;
use crate::core::Mustache;

#[derive(Default)]
pub struct InferTypeName {
secret: Option<String>,
wizard: Wizard<Question, Answer>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -74,14 +72,11 @@ impl TryInto<ChatRequest> for Question {
}

impl InferTypeName {
pub fn new(secret: Option<String>) -> InferTypeName {
Self { secret }
pub fn new(model: String, secret: Option<String>) -> InferTypeName {
Self { wizard: Wizard::new(model, secret) }
}
pub async fn generate(&mut self, config: &Config) -> Result<HashMap<String, String>> {
let secret = self.secret.as_ref().map(|s| s.to_owned());

let wizard: Wizard<Question, Answer> = Wizard::new(groq::LLAMA38192, secret);

pub async fn generate(&mut self, config: &Config) -> Result<HashMap<String, String>> {
let mut new_name_mappings: HashMap<String, String> = HashMap::new();

// removed root type from types.
Expand All @@ -104,7 +99,7 @@ impl InferTypeName {

let mut delay = 3;
loop {
let answer = wizard.ask(question.clone()).await;
let answer = self.wizard.ask(question.clone()).await;
match answer {
Ok(answer) => {
let name = &answer.suggestions.join(", ");
Expand Down
1 change: 0 additions & 1 deletion src/cli/llm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ pub mod infer_type_name;
pub use error::Error;
use error::Result;
pub use infer_type_name::InferTypeName;
mod model;
mod wizard;

pub use wizard::Wizard;
73 changes: 0 additions & 73 deletions src/cli/llm/model.rs

This file was deleted.

5 changes: 2 additions & 3 deletions src/cli/llm/wizard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@ use genai::resolver::AuthResolver;
use genai::Client;

use super::Result;
use crate::cli::llm::model::Model;

#[derive(Setters, Clone)]
pub struct Wizard<Q, A> {
client: Client,
model: Model,
model: String,
_q: std::marker::PhantomData<Q>,
_a: std::marker::PhantomData<A>,
}

impl<Q, A> Wizard<Q, A> {
pub fn new(model: Model, secret: Option<String>) -> Self {
pub fn new(model: String, secret: Option<String>) -> Self {
let mut config = genai::adapter::AdapterConfig::default();
if let Some(key) = secret {
config = config.with_auth_resolver(AuthResolver::from_key_value(key));
Expand Down
2 changes: 0 additions & 2 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
pub mod command;
mod error;
mod fmt;
pub mod generator;
#[cfg(feature = "js")]
Expand All @@ -11,5 +10,4 @@ pub mod server;
mod tc;
pub mod telemetry;
pub(crate) mod update_checker;
pub use error::CLIError;
pub use tc::run::run;
7 changes: 3 additions & 4 deletions src/cli/runtime/file.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt};

use crate::cli::CLIError;
use crate::core::FileIO;
use crate::core::{Errata, FileIO};

#[derive(Clone)]
pub struct NativeFileIO {}
Expand Down Expand Up @@ -29,7 +28,7 @@ async fn write<'a>(path: &'a str, content: &'a [u8]) -> anyhow::Result<()> {
impl FileIO for NativeFileIO {
async fn write<'a>(&'a self, path: &'a str, content: &'a [u8]) -> anyhow::Result<()> {
write(path, content).await.map_err(|err| {
CLIError::new(format!("Failed to write file: {}", path).as_str())
Errata::new(format!("Failed to write file: {}", path).as_str())
.description(err.to_string())
})?;
tracing::info!("File write: {} ... ok", path);
Expand All @@ -38,7 +37,7 @@ impl FileIO for NativeFileIO {

async fn read<'a>(&'a self, path: &'a str) -> anyhow::Result<String> {
let content = read(path).await.map_err(|err| {
CLIError::new(format!("Failed to read file: {}", path).as_str())
Errata::new(format!("Failed to read file: {}", path).as_str())
.description(err.to_string())
})?;
tracing::info!("File read: {} ... ok", path);
Expand Down
6 changes: 3 additions & 3 deletions src/cli/server/http_1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use hyper::service::{make_service_fn, service_fn};
use tokio::sync::oneshot;

use super::server_config::ServerConfig;
use crate::cli::CLIError;
use crate::core::async_graphql_hyper::{GraphQLBatchRequest, GraphQLRequest};
use crate::core::http::handle_request;
use crate::core::Errata;

pub async fn start_http_1(
sc: Arc<ServerConfig>,
Expand All @@ -31,7 +31,7 @@ pub async fn start_http_1(
}
});
let builder = hyper::Server::try_bind(&addr)
.map_err(CLIError::from)?
.map_err(Errata::from)?
.http1_pipeline_flush(sc.app_ctx.blueprint.server.pipeline_flush);
super::log_launch(sc.as_ref());

Expand All @@ -48,7 +48,7 @@ pub async fn start_http_1(
builder.serve(make_svc_single_req).await
};

let result = server.map_err(CLIError::from);
let result = server.map_err(Errata::from);

Ok(result?)
}
4 changes: 2 additions & 2 deletions src/cli/server/http_2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use tokio::sync::oneshot;

use super::server_config::ServerConfig;
use crate::cli::CLIError;
use crate::core::async_graphql_hyper::{GraphQLBatchRequest, GraphQLRequest};
use crate::core::http::handle_request;
use crate::core::Errata;

pub async fn start_http_2(
sc: Arc<ServerConfig>,
Expand Down Expand Up @@ -60,7 +60,7 @@ pub async fn start_http_2(
builder.serve(make_svc_single_req).await
};

let result = server.map_err(CLIError::from);
let result = server.map_err(Errata::from);

Ok(result?)
}
4 changes: 2 additions & 2 deletions src/cli/server/http_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use super::http_1::start_http_1;
use super::http_2::start_http_2;
use super::server_config::ServerConfig;
use crate::cli::telemetry::init_opentelemetry;
use crate::cli::CLIError;
use crate::core::blueprint::{Blueprint, Http};
use crate::core::config::ConfigModule;
use crate::core::Errata;

pub struct Server {
config_module: ConfigModule,
Expand All @@ -32,7 +32,7 @@ impl Server {

/// Starts the server in the current Runtime
pub async fn start(self) -> Result<()> {
let blueprint = Blueprint::try_from(&self.config_module).map_err(CLIError::from)?;
let blueprint = Blueprint::try_from(&self.config_module).map_err(Errata::from)?;
let endpoints = self.config_module.extensions().endpoint_set.clone();
let server_config = Arc::new(ServerConfig::new(blueprint.clone(), endpoints).await?);

Expand Down
Loading

0 comments on commit a9110b4

Please sign in to comment.