Skip to content

Commit

Permalink
Moved tool call logic from SmartMessage to SmartThread
Browse files Browse the repository at this point in the history
  • Loading branch information
Brian Joseph Petro committed Dec 9, 2024
1 parent d8f8319 commit eddd3be
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 67 deletions.
62 changes: 2 additions & 60 deletions smart-chats/smart_message.js
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,15 @@ export class SmartMessage extends SmartBlock {
* @async
*/
async init() {
while (!this.thread) await new Promise(resolve => setTimeout(resolve, 100)); // this shouldn't be necessary (why is it not working without this?)
while (!this.thread) await new Promise(resolve => setTimeout(resolve, 100));
if(!this.thread.data.messages[this.id]){
this.thread.data.messages[this.id] = this.msg_i;
await new Promise(resolve => setTimeout(resolve, 30));
}
await this.render();
if(this.role === 'user') {
await this.thread.complete();
}else if(this.tool_calls?.length > 0){
await this.handle_tool_calls();
}else if(this.role === 'tool'){
} else if(this.role === 'tool'){
if(!this.settings.review_context){
this.thread.complete();
}
Expand Down Expand Up @@ -169,62 +167,6 @@ export class SmartMessage extends SmartBlock {
}
}

async handle_tool_calls(){
for(const tool_call of this.tool_calls){
if(tool_call.function.name === 'lookup'){
await this.handle_lookup_tool_call(tool_call);
}
}
}

build_lookup_params(args){
const params = {};
args = typeof args === 'string' ? JSON.parse(args) : args;
if(Array.isArray(args.hypotheticals)){
params.hypotheticals = args.hypotheticals;
}else if(typeof args.hypotheticals === 'object' && args.hypotheticals !== null){
params.hypotheticals = Object.values(args.hypotheticals);
}else if(typeof args.hypotheticals === 'string'){
params.hypotheticals = [args.hypotheticals];
}else{
console.warn('Invalid hypotheticals provided for lookup tool call, using user message as lookup context, args:' + JSON.stringify(args));
params.hypotheticals = [this.content];
}
params.hypotheticals = params.hypotheticals.map(h => {
if(typeof h === 'string') return h;
else return JSON.stringify(h);
})
if(this.previous_message.context.folder_refs) params.filter = {
key_starts_with_any: this.previous_message.context.folder_refs
};
params.filter = {
...(params.filter || {}),
limit: this.settings.lookup_limit || 10,
};
return params;
}
async handle_lookup_tool_call(tool_call){
const params = this.build_lookup_params(tool_call.function.arguments);
const lookup_collection = this.env.smart_blocks.settings.embed_blocks ? this.env.smart_blocks : this.env.smart_sources;
const lookup_results = (await lookup_collection.lookup(params))
.map(result => ({
key: result.item.key,
score: result.score,
}))
;
const msg_i = Object.keys(this.thread.data.messages || {}).length + 1;
const branch_i = (this.thread.data.branches?.[msg_i] || []).length + 1;
await this.env.smart_messages.create_or_update({
thread_key: this.thread.key,
tool_call_id: tool_call.id,
tool_name: tool_call.function.name,
tool_call_output: lookup_results,
role: 'tool',
response_id: tool_call.id,
id: `tool-${msg_i}-${branch_i}`,
});
}

/**
* Converts the message to a request payload
* @returns {Array<Object>} Request payload
Expand Down
108 changes: 101 additions & 7 deletions smart-chats/smart_thread.js
Original file line number Diff line number Diff line change
Expand Up @@ -215,21 +215,115 @@ export class SmartThread extends SmartSource {
const choices = response.choices;
const response_id = response.id;
if(!response_id) return [];
const msg_items = await Promise.all(choices.map(async (choice, index) => {

const msg_items = [];
for (const choice of choices) {
const msg_data = {
...(choice?.message || choice), // fallback on full choice to handle non-message choices
...(choice?.message || choice),
thread_key: this.key,
response_id,
};
const msg = this.messages.find(msg => msg.data.response_id === response_id);
if(msg){
msg_data.key = msg.key;

const msg = this.messages.find(m => m.data.response_id === response_id);
if(msg) msg_data.key = msg.key;
const new_msg = await this.env.smart_messages.create_or_update(msg_data);
msg_items.push(new_msg);

// Handle tool calls
if (msg_data.tool_calls?.length > 0) {
await this.handle_tool_calls(msg_data.tool_calls, msg_data);
}
return this.env.smart_messages.create_or_update(msg_data);
}));
}
return msg_items;
}

/**
* Handle any tool calls detected in a message.
* This was previously in SmartMessage, now moved to SmartThread.
* @param {Array<Object>} tool_calls
* @param {Object} msg_data
*/
async handle_tool_calls(tool_calls, msg_data) {
for (const tool_call of tool_calls) {
if (tool_call.function.name === 'lookup') {
await this.handle_lookup_tool_call(tool_call, msg_data);
}
}
}

/**
* Builds lookup parameters for the lookup tool call
* @param {Object|string} args - tool call arguments
* @param {Object} previous_message - the previous SmartMessage instance
* @returns {Object} params
*/
build_lookup_params(args, previous_message) {
const params = {};
args = typeof args === 'string' ? JSON.parse(args) : args;
if (Array.isArray(args.hypotheticals)) {
params.hypotheticals = args.hypotheticals;
} else if (typeof args.hypotheticals === 'object' && args.hypotheticals !== null) {
params.hypotheticals = Object.values(args.hypotheticals);
} else if (typeof args.hypotheticals === 'string') {
params.hypotheticals = [args.hypotheticals];
} else {
console.warn('Invalid hypotheticals provided for lookup tool call, using user message as lookup context, args:' + JSON.stringify(args));
// Fall back to previous message content or empty
const fallback_content = previous_message?.content || 'No context';
params.hypotheticals = [fallback_content];
}

// Ensure all hypotheticals are strings
params.hypotheticals = params.hypotheticals.map(h => {
if (typeof h === 'string') return h;
else return JSON.stringify(h);
});

// If previous_message has folder refs, use them as filters
if (previous_message?.context?.folder_refs) {
params.filter = {
key_starts_with_any: previous_message.context.folder_refs
};
}

params.filter = {
...(params.filter || {}),
limit: this.settings.lookup_limit || 10,
};

return params;
}

/**
* Handle lookup tool call logic
* @param {Object} tool_call
* @param {Object} msg_data - The data for the message that triggered the tool call
*/
async handle_lookup_tool_call(tool_call, msg_data) {
const previous_message = this.messages[this.messages.length - 1];
const params = this.build_lookup_params(tool_call.function.arguments, previous_message);

// Determine which collection to use (based on embed settings)
const lookup_collection = this.env.smart_blocks.settings.embed_blocks ? this.env.smart_blocks : this.env.smart_sources;
const lookup_results = (await lookup_collection.lookup(params)).map(result => ({
key: result.item.key,
score: result.score,
}));

const msg_i = Object.keys(this.data.messages || {}).length + 1;
const branch_i = (this.data.branches?.[msg_i] || []).length + 1;

await this.env.smart_messages.create_or_update({
thread_key: this.key,
tool_call_id: tool_call.id,
tool_name: tool_call.function.name,
tool_call_output: lookup_results,
role: 'tool',
response_id: tool_call.id,
id: `tool-${msg_i}-${branch_i}`,
});
}

/**
* Prepares the request payload for the AI model
* @async
Expand Down

0 comments on commit eddd3be

Please sign in to comment.