From 29a1137409e8cae5d6be942c18dd131362a6083b Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 26 Jun 2024 23:21:39 +0000 Subject: [PATCH] feat: use model name as adapter id in chat endpoints --- router/src/lib.rs | 4 ++-- router/src/server.rs | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 126726c6a58..b7d77f31492 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -370,7 +370,7 @@ pub struct CompletionRequest { /// UNUSED #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option<String>, /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] @@ -706,7 +706,7 @@ impl ChatCompletionChunk { pub(crate) struct ChatRequest { #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option<String>, /// A list of messages comprising the conversation so far. 
#[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")] diff --git a/router/src/server.rs b/router/src/server.rs index 7f15bfdd6a6..0738c56fa88 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -606,6 +606,7 @@ async fn completions( metrics::increment_counter!("tgi_request_count"); let CompletionRequest { + model, max_tokens, seed, stop, @@ -673,7 +674,7 @@ async fn completions( seed, top_n_tokens: None, grammar: None, - ..Default::default() + adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from), }, }) .collect(); @@ -1011,6 +1012,7 @@ async fn chat_completions( let span = tracing::Span::current(); metrics::increment_counter!("tgi_request_count"); let ChatRequest { + model, logprobs, max_tokens, messages, @@ -1116,7 +1118,7 @@ async fn chat_completions( seed, top_n_tokens: req.top_logprobs, grammar, - ..Default::default() + adapter_id: model.filter(|m| *m != "tgi").map(String::from), }, };