Skip to content

Commit

Permalink
Merge pull request #7 from spiceai/jeadie/24-10-16/updates
Browse files Browse the repository at this point in the history
Add better methods for using tools in  and update examples
  • Loading branch information
Jeadie authored Oct 16, 2024
2 parents 470e0c1 + 23d7df0 commit cae50ae
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 17 deletions.
8 changes: 8 additions & 0 deletions mistralrs-core/src/tools/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ pub enum ToolCallType {
Function,
}

impl std::fmt::Display for ToolCallType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ToolCallType::Function => write!(f, "function"),
}
}
}

#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
Expand Down
24 changes: 15 additions & 9 deletions mistralrs/examples/lower_level/tools/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,27 @@ fn main() -> anyhow::Result<()> {
let result = get_weather(input);
// Add tool call message from assistant so it knows what it called
messages.push(IndexMap::from([
("role".to_string(), Either::Left("asistant".to_string())),
("role".to_string(), Either::Left("assistant".to_string())),
("content".to_string(), Either::Left("".to_string())),
(
"content".to_string(),
Either::Right(
json!({
"name": called.function.name,
"parameters": called.function.arguments,
})
.to_string(),
),
"tool_calls".to_string(),
Either::Right(Vec![IndexMap::from([
("id".to_string(), tool_call.id),
(
"function".to_string(),
json!({
"name": called.function.name,
"arguments": called.function.arguments,
})
),
("type".to_string(), "function".to_string()),
])]),
),
]));
// Add message from the tool
messages.push(IndexMap::from([
("role".to_string(), Either::Left("tool".to_string())),
("tool_call_id".to_string(), Either::Left(tool_call.id)),
("content".to_string(), Either::Left(result)),
]));

Expand Down
11 changes: 4 additions & 7 deletions mistralrs/examples/tools/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,12 @@ async fn main() -> Result<()> {
// Add tool call message from assistant so it knows what it called
// Then, add message from the tool
messages = messages
.add_message(
.add_message_with_tool_call(
TextMessageRole::Assistant,
json!({
"name": called.function.name,
"parameters": called.function.arguments,
})
.to_string(),
String::new(),
vec![called.clone()],
)
.add_message(TextMessageRole::Tool, result)
.add_tool_message(result, called.id.clone())
.set_tool_choice(ToolChoice::None);

let response = model.send_chat_request(messages.clone()).await?;
Expand Down
55 changes: 54 additions & 1 deletion mistralrs/src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::*;
use either::Either;
use image::DynamicImage;
use indexmap::IndexMap;
use serde_json::Value;
use serde_json::{json, Value};

/// A type which can be used as a request.
pub trait RequestLike {
Expand Down Expand Up @@ -331,6 +331,10 @@ impl RequestBuilder {
}
}

/// Add a message to the request.
///
/// For messages with tool calls, use [`Self::add_message_with_tool_call`].
/// For messages with tool outputs, use [`Self::add_tool_message`].
pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
self.messages.push(IndexMap::from([
("role".to_string(), Either::Left(role.to_string())),
Expand All @@ -339,6 +343,55 @@ impl RequestBuilder {
self
}

/// Add a message with the output of a tool call.
pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
self.messages.push(IndexMap::from([
(
"role".to_string(),
Either::Left(TextMessageRole::Tool.to_string()),
),
(
"content".to_string(),
Either::Left(tool_content.to_string()),
),
(
"tool_call_id".to_string(),
Either::Left(tool_id.to_string()),
),
]));
self
}

pub fn add_message_with_tool_call(
mut self,
role: TextMessageRole,
text: impl ToString,
tool_calls: Vec<ToolCallResponse>,
) -> Self {
let tool_messages = tool_calls
.iter()
.map(|t| {
IndexMap::from([
("id".to_string(), Value::String(t.id.clone())),
("type".to_string(), Value::String(t.tp.to_string())),
(
"function".to_string(),
json!({
"name": t.function.name,
"arguments": t.function.arguments,
}),
),
])
})
.collect();
self.messages.push(IndexMap::from([
("role".to_string(), Either::Left(role.to_string())),
("content".to_string(), Either::Left(text.to_string())),
("function".to_string(), Either::Right(tool_messages)),
]));
self
}

pub fn add_image_message(
mut self,
role: TextMessageRole,
Expand Down

0 comments on commit cae50ae

Please sign in to comment.