From e62593115663a6f01976ea2d611f205863731db1 Mon Sep 17 00:00:00 2001 From: jeadie Date: Fri, 4 Oct 2024 15:46:45 +1000 Subject: [PATCH 1/2] handle assistant messages with 'tool_calls' when used in chat_template --- mistralrs-core/src/pipeline/mod.rs | 17 +++++++++-------- mistralrs-core/src/pipeline/processing.rs | 9 ++++++--- mistralrs-core/src/request.rs | 3 ++- mistralrs-pyo3/src/lib.rs | 13 +++++++------ mistralrs-server/src/chat_completion.rs | 15 ++++++++------- mistralrs-server/src/interactive_mode.rs | 5 +++-- mistralrs/examples/lower_level/tools/main.rs | 2 +- mistralrs/src/messages.rs | 7 ++++--- 8 files changed, 40 insertions(+), 31 deletions(-) diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index ad6927bec..e93a71e51 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -539,6 +539,7 @@ mod tests { use crate::MessageContent; use either::Either; use indexmap::IndexMap; + use serde_json::Value; macro_rules! hashmap { (@single $($x:tt)*) => (()); @@ -550,7 +551,7 @@ mod tests { let _cap = hashmap!(@count $($key),*); let mut _map = ::indexmap::IndexMap::with_capacity(_cap); $( - let _ = _map.insert($key, $value); + let _ = _map.insert($key, Value::String($value)); )* _map } @@ -655,7 +656,7 @@ mod tests { ]; let mut inputs = Vec::new(); for [role, content] in messages { - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left(role.to_string())); message.insert("content".to_string(), Either::Left(content.to_string())); @@ -689,7 +690,7 @@ mod tests { let mut inputs = Vec::new(); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("system".to_string())); message.insert( @@ -701,7 +702,7 @@ mod tests { ); inputs.push(message); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("user".to_string())); message.insert( @@ -718,7 +719,7 @@ mod tests { ); inputs.push(message); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("assistant".to_string())); message.insert( @@ -730,7 +731,7 @@ mod tests { ); inputs.push(message); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("user".to_string())); message.insert( @@ -747,7 +748,7 @@ mod tests { ); inputs.push(message); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("assistant".to_string())); message.insert( @@ -759,7 +760,7 @@ mod tests { ); inputs.push(message); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("user".to_string())); message.insert( diff --git a/mistralrs-core/src/pipeline/processing.rs b/mistralrs-core/src/pipeline/processing.rs index e1034b84a..91d5d2e8f 100644 --- a/mistralrs-core/src/pipeline/processing.rs +++ b/mistralrs-core/src/pipeline/processing.rs @@ -84,8 +84,11 @@ pub(crate) fn apply_chat_template( 'outer: for content_row in rv { for (content_k, content_v) in content_row { if content_k == "text" { - new_message.insert(k, Either::Left(content_v)); - break 'outer; + if let Some(content_str) = content_v.as_str() { + new_message.insert(k, Either::Left(content_str.to_string())); + break 'outer; + } + } } } @@ -149,6 +152,6 @@ impl Processor for BasicProcessor { &[] } fn template_action(&self) -> MessagesAction { - MessagesAction::FlattenOnlyText + MessagesAction::Keep } } diff --git a/mistralrs-core/src/request.rs b/mistralrs-core/src/request.rs index 4371afb1b..65cde9e56 100644 --- a/mistralrs-core/src/request.rs +++ b/mistralrs-core/src/request.rs @@ -2,6 +2,7 @@ use either::Either; use indexmap::IndexMap; use mistralrs_quant::IsqType; use serde::{Deserialize, Serialize}; +use serde_json::Value; use crate::{ response::Response, @@ -28,7 +29,7 @@ pub enum ImageGenerationResponseFormat { B64Json, } -pub type MessageContent = Either>>; +pub type MessageContent = Either>>; #[derive(Clone, Debug)] /// Message or messages for a [`Request`]. diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 0eae3a3c6..62a4c3033 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -5,6 +5,7 @@ use anymoe::{AnyMoeConfig, AnyMoeExpertType}; use either::Either; use indexmap::IndexMap; use requests::{ChatCompletionRequest, CompletionRequest, ToolChoice}; +use serde_json::Value; use std::{ cell::RefCell, collections::HashMap, @@ -681,7 +682,7 @@ impl Runner { Either::Left(content) => { let mut message_map: IndexMap< String, - Either>>, + Either>>, > = IndexMap::new(); message_map.insert( "role".to_string(), @@ -758,7 +759,7 @@ impl Runner { } let mut message_map: IndexMap< String, - Either>>, + Either>>, > = IndexMap::new(); message_map.insert( "role".to_string(), @@ -772,11 +773,11 @@ impl Runner { let mut content_map = Vec::new(); let mut content_image_map = IndexMap::new(); - content_image_map.insert("type".to_string(), "image".to_string()); + content_image_map.insert("type".to_string(), Value::String("image".to_string())); content_map.push(content_image_map); let mut content_text_map = IndexMap::new(); - content_text_map.insert("type".to_string(), "text".to_string()); - content_text_map.insert("text".to_string(), content); + content_text_map.insert("type".to_string(), Value::String("text".to_string())); + content_text_map.insert("text".to_string(), Value::String(content)); content_map.push(content_text_map); message_map @@ -806,7 +807,7 @@ impl Runner { let mut messages = Vec::new(); let mut message_map: IndexMap< String, - Either>>, + Either>>, > = IndexMap::new(); message_map.insert("role".to_string(), Either::Left("user".to_string())); message_map.insert("content".to_string(), Either::Left(prompt.to_string())); diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server/src/chat_completion.rs index 5ba1745de..499377737 100644 --- a/mistralrs-server/src/chat_completion.rs +++ b/mistralrs-server/src/chat_completion.rs @@ -8,6 +8,7 @@ use std::{ task::{Context, Poll}, time::Duration, }; +use serde_json::Value; use tokio::sync::mpsc::{channel, Receiver, Sender}; use crate::{ @@ -173,7 +174,7 @@ async fn parse_request( Either::Left(content) => { let mut message_map: IndexMap< String, - Either>>, + Either>>, > = IndexMap::new(); message_map.insert("role".to_string(), Either::Left(message.role)); message_map @@ -234,7 +235,7 @@ async fn parse_request( } let mut message_map: IndexMap< String, - Either>>, + Either>>, > = IndexMap::new(); message_map.insert("role".to_string(), Either::Left(message.role)); let (content, url) = if items[0] == "text" { @@ -243,13 +244,13 @@ async fn parse_request( get_content_and_url(1, 0, image_messages)? }; - let mut content_map = Vec::new(); + let mut content_map: Vec> = Vec::new(); let mut content_image_map = IndexMap::new(); - content_image_map.insert("type".to_string(), "image".to_string()); + content_image_map.insert("type".to_string(), Value::String("image".to_string())); content_map.push(content_image_map); let mut content_text_map = IndexMap::new(); - content_text_map.insert("type".to_string(), "text".to_string()); - content_text_map.insert("text".to_string(), content); + content_text_map.insert("type".to_string(), Value::String("text".to_string())); + content_text_map.insert("text".to_string(), Value::String(content)); content_map.push(content_text_map); message_map.insert("content".to_string(), Either::Right(content_map)); @@ -276,7 +277,7 @@ async fn parse_request( } Either::Right(prompt) => { let mut messages = Vec::new(); - let mut message_map: IndexMap>>> = + let mut message_map: IndexMap>>> = IndexMap::new(); message_map.insert("role".to_string(), Either::Left("user".to_string())); message_map.insert("content".to_string(), Either::Left(prompt)); diff --git a/mistralrs-server/src/interactive_mode.rs b/mistralrs-server/src/interactive_mode.rs index 3575f3fea..f9d881590 100644 --- a/mistralrs-server/src/interactive_mode.rs +++ b/mistralrs-server/src/interactive_mode.rs @@ -6,6 +6,7 @@ use mistralrs_core::{ ResponseOk, SamplingParams, TERMINATE_ALL_NEXT_STEP, }; use once_cell::sync::Lazy; +use serde_json::Value; use std::{ io::{self, Write}, sync::{atomic::Ordering, Arc, Mutex}, @@ -220,7 +221,7 @@ async fn text_interactive_mode(mistralrs: Arc, throughput: bool) { println!(); info!("Average T/s: {}", toks as f64 / time); } - let mut assistant_message: IndexMap>>> = + let mut assistant_message: IndexMap>>> = IndexMap::new(); assistant_message.insert("role".to_string(), Either::Left("assistant".to_string())); assistant_message.insert("content".to_string(), Either::Left(assistant_output)); @@ -402,7 +403,7 @@ async fn vision_interactive_mode(mistralrs: Arc, throughput: bool) { println!(); info!("Average T/s: {}", toks as f64 / time); } - let mut assistant_message: IndexMap>>> = + let mut assistant_message: IndexMap>>> = IndexMap::new(); assistant_message.insert("role".to_string(), Either::Left("assistant".to_string())); assistant_message.insert("content".to_string(), Either::Left(assistant_output)); diff --git a/mistralrs/examples/lower_level/tools/main.rs b/mistralrs/examples/lower_level/tools/main.rs index 4ccf7a5b9..64197e6d2 100644 --- a/mistralrs/examples/lower_level/tools/main.rs +++ b/mistralrs/examples/lower_level/tools/main.rs @@ -128,7 +128,7 @@ fn main() -> anyhow::Result<()> { ("role".to_string(), Either::Left("asistant".to_string())), ( "content".to_string(), - Either::Left( + Either::Right( json!({ "name": called.function.name, "parameters": called.function.arguments, diff --git a/mistralrs/src/messages.rs b/mistralrs/src/messages.rs index d525119ee..c0433d56f 100644 --- a/mistralrs/src/messages.rs +++ b/mistralrs/src/messages.rs @@ -4,6 +4,7 @@ use super::*; use either::Either; use image::DynamicImage; use indexmap::IndexMap; +use serde_json::Value; /// A type which can be used as a request. pub trait RequestLike { @@ -203,10 +204,10 @@ impl VisionMessages { ( "content".to_string(), Either::Right(vec![ - IndexMap::from([("type".to_string(), "image".to_string())]), + IndexMap::from([("type".to_string(), Value::String("image".to_string()))]), IndexMap::from([ - ("type".to_string(), "text".to_string()), - ("content".to_string(), text.to_string()), + ("type".to_string(), Value::String("text".to_string())), + ("content".to_string(), Value::String(text.to_string())), ]), ]), ), From 60e9375b6049eead77cbda7cb3d3d83187f501a2 Mon Sep 17 00:00:00 2001 From: jeadie Date: Fri, 4 Oct 2024 16:34:12 +1000 Subject: [PATCH 2/2] linting --- mistralrs-core/src/pipeline/processing.rs | 6 ++++-- mistralrs-pyo3/src/lib.rs | 6 ++++-- mistralrs-server/src/chat_completion.rs | 8 +++++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mistralrs-core/src/pipeline/processing.rs b/mistralrs-core/src/pipeline/processing.rs index 91d5d2e8f..bb19e3f46 100644 --- a/mistralrs-core/src/pipeline/processing.rs +++ b/mistralrs-core/src/pipeline/processing.rs @@ -85,10 +85,12 @@ pub(crate) fn apply_chat_template( for (content_k, content_v) in content_row { if content_k == "text" { if let Some(content_str) = content_v.as_str() { - new_message.insert(k, Either::Left(content_str.to_string())); + new_message.insert( + k, + Either::Left(content_str.to_string()), + ); break 'outer; } - } } } diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 62a4c3033..dc54c1a1c 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -773,10 +773,12 @@ impl Runner { let mut content_map = Vec::new(); let mut content_image_map = IndexMap::new(); - content_image_map.insert("type".to_string(), Value::String("image".to_string())); + content_image_map + .insert("type".to_string(), Value::String("image".to_string())); content_map.push(content_image_map); let mut content_text_map = IndexMap::new(); - content_text_map.insert("type".to_string(), Value::String("text".to_string())); + content_text_map + .insert("type".to_string(), Value::String("text".to_string())); content_text_map.insert("text".to_string(), Value::String(content)); content_map.push(content_text_map); diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server/src/chat_completion.rs index 499377737..e10092e68 100644 --- a/mistralrs-server/src/chat_completion.rs +++ b/mistralrs-server/src/chat_completion.rs @@ -1,3 +1,4 @@ +use serde_json::Value; use std::{ collections::HashMap, env, @@ -8,7 +9,6 @@ use std::{ task::{Context, Poll}, time::Duration, }; -use serde_json::Value; use tokio::sync::mpsc::{channel, Receiver, Sender}; use crate::{ @@ -246,10 +246,12 @@ async fn parse_request( let mut content_map: Vec> = Vec::new(); let mut content_image_map = IndexMap::new(); - content_image_map.insert("type".to_string(), Value::String("image".to_string())); + content_image_map + .insert("type".to_string(), Value::String("image".to_string())); content_map.push(content_image_map); let mut content_text_map = IndexMap::new(); - content_text_map.insert("type".to_string(), Value::String("text".to_string())); + content_text_map + .insert("type".to_string(), Value::String("text".to_string())); content_text_map.insert("text".to_string(), Value::String(content)); content_map.push(content_text_map);