Skip to content

Commit

Permalink
Merge pull request #5 from spiceai/jack/tools
Browse files Browse the repository at this point in the history
Handle assistant messages with 'tool_calls' when used in chat_template
  • Loading branch information
Jeadie authored Oct 4, 2024
2 parents 3e79d85 + 60e9375 commit 470e0c1
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 31 deletions.
17 changes: 9 additions & 8 deletions mistralrs-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ mod tests {
use crate::MessageContent;
use either::Either;
use indexmap::IndexMap;
use serde_json::Value;

macro_rules! hashmap {
(@single $($x:tt)*) => (());
Expand All @@ -550,7 +551,7 @@ mod tests {
let _cap = hashmap!(@count $($key),*);
let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
$(
let _ = _map.insert($key, $value);
let _ = _map.insert($key, Value::String($value));
)*
_map
}
Expand Down Expand Up @@ -655,7 +656,7 @@ mod tests {
];
let mut inputs = Vec::new();
for [role, content] in messages {
let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, String>>>> =
let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
IndexMap::new();
message.insert("role".to_string(), Either::Left(role.to_string()));
message.insert("content".to_string(), Either::Left(content.to_string()));
Expand Down Expand Up @@ -689,7 +690,7 @@ mod tests {

let mut inputs = Vec::new();

let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, String>>>> =
let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
IndexMap::new();
message.insert("role".to_string(), Either::Left("system".to_string()));
message.insert(
Expand All @@ -701,7 +702,7 @@ mod tests {
);
inputs.push(message);

let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, String>>>> =
let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
IndexMap::new();
message.insert("role".to_string(), Either::Left("user".to_string()));
message.insert(
Expand All @@ -718,7 +719,7 @@ mod tests {
);
inputs.push(message);

let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, String>>>> =
let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
IndexMap::new();
message.insert("role".to_string(), Either::Left("assistant".to_string()));
message.insert(
Expand All @@ -730,7 +731,7 @@ mod tests {
);
inputs.push(message);

let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, String>>>> =
let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
IndexMap::new();
message.insert("role".to_string(), Either::Left("user".to_string()));
message.insert(
Expand All @@ -747,7 +748,7 @@ mod tests {
);
inputs.push(message);

let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, String>>>> =
let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
IndexMap::new();
message.insert("role".to_string(), Either::Left("assistant".to_string()));
message.insert(
Expand All @@ -759,7 +760,7 @@ mod tests {
);
inputs.push(message);

let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, String>>>> =
let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
IndexMap::new();
message.insert("role".to_string(), Either::Left("user".to_string()));
message.insert(
Expand Down
11 changes: 8 additions & 3 deletions mistralrs-core/src/pipeline/processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@ pub(crate) fn apply_chat_template(
'outer: for content_row in rv {
for (content_k, content_v) in content_row {
if content_k == "text" {
new_message.insert(k, Either::Left(content_v));
break 'outer;
if let Some(content_str) = content_v.as_str() {
new_message.insert(
k,
Either::Left(content_str.to_string()),
);
break 'outer;
}
}
}
}
Expand Down Expand Up @@ -149,6 +154,6 @@ impl Processor for BasicProcessor {
&[]
}
fn template_action(&self) -> MessagesAction {
MessagesAction::FlattenOnlyText
MessagesAction::Keep
}
}
3 changes: 2 additions & 1 deletion mistralrs-core/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use either::Either;
use indexmap::IndexMap;
use mistralrs_quant::IsqType;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::{
response::Response,
Expand All @@ -28,7 +29,7 @@ pub enum ImageGenerationResponseFormat {
B64Json,
}

pub type MessageContent = Either<String, Vec<IndexMap<String, String>>>;
pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;

#[derive(Clone, Debug)]
/// Message or messages for a [`Request`].
Expand Down
15 changes: 9 additions & 6 deletions mistralrs-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use anymoe::{AnyMoeConfig, AnyMoeExpertType};
use either::Either;
use indexmap::IndexMap;
use requests::{ChatCompletionRequest, CompletionRequest, ToolChoice};
use serde_json::Value;
use std::{
cell::RefCell,
collections::HashMap,
Expand Down Expand Up @@ -681,7 +682,7 @@ impl Runner {
Either::Left(content) => {
let mut message_map: IndexMap<
String,
Either<String, Vec<IndexMap<String, String>>>,
Either<String, Vec<IndexMap<String, Value>>>,
> = IndexMap::new();
message_map.insert(
"role".to_string(),
Expand Down Expand Up @@ -758,7 +759,7 @@ impl Runner {
}
let mut message_map: IndexMap<
String,
Either<String, Vec<IndexMap<String, String>>>,
Either<String, Vec<IndexMap<String, Value>>>,
> = IndexMap::new();
message_map.insert(
"role".to_string(),
Expand All @@ -772,11 +773,13 @@ impl Runner {

let mut content_map = Vec::new();
let mut content_image_map = IndexMap::new();
content_image_map.insert("type".to_string(), "image".to_string());
content_image_map
.insert("type".to_string(), Value::String("image".to_string()));
content_map.push(content_image_map);
let mut content_text_map = IndexMap::new();
content_text_map.insert("type".to_string(), "text".to_string());
content_text_map.insert("text".to_string(), content);
content_text_map
.insert("type".to_string(), Value::String("text".to_string()));
content_text_map.insert("text".to_string(), Value::String(content));
content_map.push(content_text_map);

message_map
Expand Down Expand Up @@ -806,7 +809,7 @@ impl Runner {
let mut messages = Vec::new();
let mut message_map: IndexMap<
String,
Either<String, Vec<IndexMap<String, String>>>,
Either<String, Vec<IndexMap<String, Value>>>,
> = IndexMap::new();
message_map.insert("role".to_string(), Either::Left("user".to_string()));
message_map.insert("content".to_string(), Either::Left(prompt.to_string()));
Expand Down
17 changes: 10 additions & 7 deletions mistralrs-server/src/chat_completion.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use serde_json::Value;
use std::{
collections::HashMap,
env,
Expand Down Expand Up @@ -173,7 +174,7 @@ async fn parse_request(
Either::Left(content) => {
let mut message_map: IndexMap<
String,
Either<String, Vec<IndexMap<String, String>>>,
Either<String, Vec<IndexMap<String, Value>>>,
> = IndexMap::new();
message_map.insert("role".to_string(), Either::Left(message.role));
message_map
Expand Down Expand Up @@ -234,7 +235,7 @@ async fn parse_request(
}
let mut message_map: IndexMap<
String,
Either<String, Vec<IndexMap<String, String>>>,
Either<String, Vec<IndexMap<String, Value>>>,
> = IndexMap::new();
message_map.insert("role".to_string(), Either::Left(message.role));
let (content, url) = if items[0] == "text" {
Expand All @@ -243,13 +244,15 @@ async fn parse_request(
get_content_and_url(1, 0, image_messages)?
};

let mut content_map = Vec::new();
let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
let mut content_image_map = IndexMap::new();
content_image_map.insert("type".to_string(), "image".to_string());
content_image_map
.insert("type".to_string(), Value::String("image".to_string()));
content_map.push(content_image_map);
let mut content_text_map = IndexMap::new();
content_text_map.insert("type".to_string(), "text".to_string());
content_text_map.insert("text".to_string(), content);
content_text_map
.insert("type".to_string(), Value::String("text".to_string()));
content_text_map.insert("text".to_string(), Value::String(content));
content_map.push(content_text_map);

message_map.insert("content".to_string(), Either::Right(content_map));
Expand All @@ -276,7 +279,7 @@ async fn parse_request(
}
Either::Right(prompt) => {
let mut messages = Vec::new();
let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, String>>>> =
let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
IndexMap::new();
message_map.insert("role".to_string(), Either::Left("user".to_string()));
message_map.insert("content".to_string(), Either::Left(prompt));
Expand Down
5 changes: 3 additions & 2 deletions mistralrs-server/src/interactive_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use mistralrs_core::{
ResponseOk, SamplingParams, TERMINATE_ALL_NEXT_STEP,
};
use once_cell::sync::Lazy;
use serde_json::Value;
use std::{
io::{self, Write},
sync::{atomic::Ordering, Arc, Mutex},
Expand Down Expand Up @@ -220,7 +221,7 @@ async fn text_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool) {
println!();
info!("Average T/s: {}", toks as f64 / time);
}
let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, String>>>> =
let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
IndexMap::new();
assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
assistant_message.insert("content".to_string(), Either::Left(assistant_output));
Expand Down Expand Up @@ -402,7 +403,7 @@ async fn vision_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool) {
println!();
info!("Average T/s: {}", toks as f64 / time);
}
let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, String>>>> =
let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
IndexMap::new();
assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
assistant_message.insert("content".to_string(), Either::Left(assistant_output));
Expand Down
2 changes: 1 addition & 1 deletion mistralrs/examples/lower_level/tools/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn main() -> anyhow::Result<()> {
("role".to_string(), Either::Left("asistant".to_string())),
(
"content".to_string(),
Either::Left(
Either::Right(
json!({
"name": called.function.name,
"parameters": called.function.arguments,
Expand Down
7 changes: 4 additions & 3 deletions mistralrs/src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::*;
use either::Either;
use image::DynamicImage;
use indexmap::IndexMap;
use serde_json::Value;

/// A type which can be used as a request.
pub trait RequestLike {
Expand Down Expand Up @@ -203,10 +204,10 @@ impl VisionMessages {
(
"content".to_string(),
Either::Right(vec![
IndexMap::from([("type".to_string(), "image".to_string())]),
IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
IndexMap::from([
("type".to_string(), "text".to_string()),
("content".to_string(), text.to_string()),
("type".to_string(), Value::String("text".to_string())),
("content".to_string(), Value::String(text.to_string())),
]),
]),
),
Expand Down

0 comments on commit 470e0c1

Please sign in to comment.