diff --git a/mistralrs-core/src/tools/response.rs b/mistralrs-core/src/tools/response.rs index 4eecd5caf..1e082a900 100644 --- a/mistralrs-core/src/tools/response.rs +++ b/mistralrs-core/src/tools/response.rs @@ -6,6 +6,14 @@ pub enum ToolCallType { Function, } +impl std::fmt::Display for ToolCallType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ToolCallType::Function => write!(f, "function"), + } + } +} + #[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)] #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] diff --git a/mistralrs/examples/lower_level/tools/main.rs b/mistralrs/examples/lower_level/tools/main.rs index 64197e6d2..4fc8b1b08 100644 --- a/mistralrs/examples/lower_level/tools/main.rs +++ b/mistralrs/examples/lower_level/tools/main.rs @@ -125,21 +125,27 @@ fn main() -> anyhow::Result<()> { let result = get_weather(input); // Add tool call message from assistant so it knows what it called messages.push(IndexMap::from([ - ("role".to_string(), Either::Left("asistant".to_string())), + ("role".to_string(), Either::Left("assistant".to_string())), + ("content".to_string(), Either::Left("".to_string())), ( - "content".to_string(), - Either::Right( - json!({ - "name": called.function.name, - "parameters": called.function.arguments, - }) - .to_string(), - ), + "tool_calls".to_string(), + Either::Right(Vec![IndexMap::from([ + ("id".to_string(), tool_call.id), + ( + "function".to_string(), + json!({ + "name": called.function.name, + "arguments": called.function.arguments, + }) + ), + ("type".to_string(), "function".to_string()), + ])]), ), ])); // Add message from the tool messages.push(IndexMap::from([ ("role".to_string(), Either::Left("tool".to_string())), + ("tool_call_id".to_string(), Either::Left(tool_call.id)), ("content".to_string(), Either::Left(result)), ])); diff --git a/mistralrs/examples/tools/main.rs b/mistralrs/examples/tools/main.rs index 10b102047..02c002ca0 100644 --- a/mistralrs/examples/tools/main.rs +++ b/mistralrs/examples/tools/main.rs @@ -66,15 +66,12 @@ async fn main() -> Result<()> { // Add tool call message from assistant so it knows what it called // Then, add message from the tool messages = messages - .add_message( + .add_message_with_tool_call( TextMessageRole::Assistant, - json!({ - "name": called.function.name, - "parameters": called.function.arguments, - }) - .to_string(), + String::new(), + vec![called.clone()], ) - .add_message(TextMessageRole::Tool, result) + .add_tool_message(result, called.id.clone()) .set_tool_choice(ToolChoice::None); let response = model.send_chat_request(messages.clone()).await?; diff --git a/mistralrs/src/messages.rs b/mistralrs/src/messages.rs index c0433d56f..b31b9c4a1 100644 --- a/mistralrs/src/messages.rs +++ b/mistralrs/src/messages.rs @@ -4,7 +4,7 @@ use super::*; use either::Either; use image::DynamicImage; use indexmap::IndexMap; -use serde_json::Value; +use serde_json::{json, Value}; /// A type which can be used as a request. pub trait RequestLike { @@ -331,6 +331,10 @@ impl RequestBuilder { } } + /// Add a message to the request. + /// + /// For messages with tool calls, use [`Self::add_message_with_tool_call`]. + /// For messages with tool outputs, use [`Self::add_tool_message`]. pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self { self.messages.push(IndexMap::from([ ("role".to_string(), Either::Left(role.to_string())), @@ -339,6 +343,55 @@ impl RequestBuilder { self } + /// Add a message with the output of a tool call. + pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self { + self.messages.push(IndexMap::from([ + ( + "role".to_string(), + Either::Left(TextMessageRole::Tool.to_string()), + ), + ( + "content".to_string(), + Either::Left(tool_content.to_string()), + ), + ( + "tool_call_id".to_string(), + Either::Left(tool_id.to_string()), + ), + ])); + self + } + + pub fn add_message_with_tool_call( + mut self, + role: TextMessageRole, + text: impl ToString, + tool_calls: Vec, + ) -> Self { + let tool_messages = tool_calls + .iter() + .map(|t| { + IndexMap::from([ + ("id".to_string(), Value::String(t.id.clone())), + ("type".to_string(), Value::String(t.tp.to_string())), + ( + "function".to_string(), + json!({ + "name": t.function.name, + "arguments": t.function.arguments, + }), + ), + ]) + }) + .collect(); + self.messages.push(IndexMap::from([ + ("role".to_string(), Either::Left(role.to_string())), + ("content".to_string(), Either::Left(text.to_string())), + ("function".to_string(), Either::Right(tool_messages)), + ])); + self + } + pub fn add_image_message( mut self, role: TextMessageRole,