diff --git a/src/constants.rs b/src/constants.rs index 4957c17..56580c7 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -6,9 +6,7 @@ pub const EMBEDDINGS_DIMENSION: usize = 384; pub const SSE_CHANNEL_BUFFER_SIZE: usize = 1; -pub const FINAL_EXPLANATION_TEMPERATURE: f64 = 0.7; - -pub const FUNCTIONS_CALLS_TEMPERATURE: f64 = 0.5; +pub const CHAT_COMPLETION_TEMPERATURE: f64 = 0.5; pub const ACTIX_WEB_SERVER_PORT: usize = 3000; diff --git a/src/conversation/data.rs b/src/conversation/data.rs new file mode 100644 index 0000000..550fa4a --- /dev/null +++ b/src/conversation/data.rs @@ -0,0 +1,63 @@ +use crate::prelude::*; +use crate::{github::Repository, utils::functions::Function}; +use openai_api_rs::v1::chat_completion::FunctionCall; +use serde::Deserialize; +use std::str::FromStr; + +#[derive(Deserialize)] +pub struct Query { + pub repository: Repository, + pub query: String, +} + +impl ToString for Query { + fn to_string(&self) -> String { + let Query { + repository: + Repository { + owner, + name, + branch, + }, + query, + } = self; + format!( + "##Repository Info##\nOwner:{}\nName:{}\nBranch:{}\n##User Query##\nQuery:{}", + owner, name, branch, query + ) + } +} + +#[derive(Debug)] +pub struct RelevantChunk { + pub path: String, + pub content: String, +} + +impl ToString for RelevantChunk { + fn to_string(&self) -> String { + format!( + "##Relevant file chunk##\nPath argument:{}\nRelevant content: {}", + self.path, + self.content.trim() + ) + } +} + +#[derive(Debug, Clone)] +pub struct ParsedFunctionCall { + pub name: Function, + pub args: serde_json::Value, +} + +impl TryFrom<&FunctionCall> for ParsedFunctionCall { + type Error = anyhow::Error; + + fn try_from(func: &FunctionCall) -> Result { + let func = func.clone(); + let name = Function::from_str(&func.name.unwrap_or("done".into()))?; + let args = func.arguments.unwrap_or("{}".to_string()); + let args = serde_json::from_str::(&args)?; + Ok(ParsedFunctionCall { name, args }) + } +} diff --git 
a/src/utils/conversation/mod.rs b/src/conversation/mod.rs similarity index 81% rename from src/utils/conversation/mod.rs rename to src/conversation/mod.rs index cdde807..6725282 100644 --- a/src/utils/conversation/mod.rs +++ b/src/conversation/mod.rs @@ -1,82 +1,35 @@ #![allow(unused_must_use)] +mod data; mod prompts; use crate::{ - prelude::*, constants::{RELEVANT_CHUNKS_LIMIT, RELEVANT_FILES_LIMIT}, db::RepositoryEmbeddingsDB, embeddings::EmbeddingsModel, - github::Repository, + prelude::*, routes::events::{emit, QueryEvent}, }; use actix_web_lab::sse::Sender; -use openai_api_rs::v1::chat_completion::{FinishReason, FunctionCall}; +pub use data::*; +use openai_api_rs::v1::chat_completion::FinishReason; use openai_api_rs::v1::{ api::Client, chat_completion::{ ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, MessageRole, }, }; -use serde::Deserialize; use std::env; -use std::str::FromStr; use std::sync::Arc; use prompts::{generate_completion_request, system_message}; use self::prompts::answer_generation_prompt; -use super::functions::{ +use crate::utils::functions::{ paths_to_completion_message, relevant_chunks_to_completion_message, search_codebase, search_file, search_path, Function, }; -#[derive(Deserialize)] -pub struct Query { - pub repository: Repository, - pub query: String, -} - -impl ToString for Query { - fn to_string(&self) -> String { - let Query { - repository: - Repository { - owner, - name, - branch, - }, - query, - } = self; - format!( - "##Repository Info##\nOwner:{}\nName:{}\nBranch:{}\n##User Query##\nQuery:{}", - owner, name, branch, query - ) - } -} - -#[derive(Debug)] -pub struct RelevantChunk { - pub path: String, - pub content: String, -} - -impl ToString for RelevantChunk { - fn to_string(&self) -> String { - format!( - "##Relevant file chunk##\nPath argument:{}\nRelevant content: {}", - self.path, - self.content.trim() - ) - } -} - -#[derive(Debug, Clone)] -struct ParsedFunctionCall { - name: Function, - args: 
serde_json::Value, -} - pub struct Conversation { query: Query, client: Client, @@ -133,7 +86,7 @@ impl Conversation { #[allow(unused_labels)] 'conversation: loop { //Generate a request with the message history and functions - let request = generate_completion_request(self.messages.clone(), true); + let request = generate_completion_request(self.messages.clone(), "auto"); match self.send_request(request).await { Ok(response) => { @@ -141,7 +94,8 @@ impl Conversation { if let Some(function_call) = response.choices[0].message.function_call.clone() { - let parsed_function_call = parse_function_call(&function_call)?; + let parsed_function_call = + ParsedFunctionCall::try_from(&function_call)?; let function_call_message = ChatCompletionMessage { name: None, function_call: Some(function_call), @@ -229,22 +183,17 @@ impl Conversation { ); self.append_message(completion_message); } - Function::None => { + Function::Done => { self.prepare_final_explanation_message(); //Generate a request with the message history and no functions let request = - generate_completion_request(self.messages.clone(), false); - emit( - &self.sender, - QueryEvent::GenerateResponse(Some( - parsed_function_call.args, - )), - ) - .await; + generate_completion_request(self.messages.clone(), "none"); + emit(&self.sender, QueryEvent::GenerateResponse(None)).await; let response = match self.send_request(request).await { Ok(response) => response, Err(e) => { + dbg!(e.to_string()); return Err(e); } }; @@ -268,14 +217,3 @@ impl Conversation { } } } - -fn parse_function_call(func: &FunctionCall) -> Result { - let func = func.clone(); - let function_name = Function::from_str(&func.name.unwrap_or("none".into()))?; - let function_args = func.arguments.unwrap_or("{}".to_string()); - let function_args = serde_json::from_str::(&function_args)?; - Ok(ParsedFunctionCall { - name: function_name, - args: function_args, - }) -} diff --git a/src/utils/conversation/prompts.rs b/src/conversation/prompts.rs similarity index 
76% rename from src/utils/conversation/prompts.rs rename to src/conversation/prompts.rs index ebdfb98..ac36f16 100644 --- a/src/utils/conversation/prompts.rs +++ b/src/conversation/prompts.rs @@ -1,57 +1,39 @@ use openai_api_rs::v1::chat_completion::{ - ChatCompletionMessage, ChatCompletionRequest, Function, FunctionParameters, JSONSchemaDefine, - JSONSchemaType, GPT3_5_TURBO, + ChatCompletionMessage, ChatCompletionRequest, Function as F, FunctionParameters, + JSONSchemaDefine, JSONSchemaType, GPT3_5_TURBO, }; use std::collections::HashMap; -use crate::constants::{FINAL_EXPLANATION_TEMPERATURE, FUNCTIONS_CALLS_TEMPERATURE}; +use crate::{constants::CHAT_COMPLETION_TEMPERATURE, utils::functions::Function}; pub fn generate_completion_request( messages: Vec, - with_functions: bool, + function_call: &str, ) -> ChatCompletionRequest { - //All the chat completion requests will have functions except for the final explanation request - if with_functions { - ChatCompletionRequest { - model: GPT3_5_TURBO.into(), - messages, - functions: Some(functions()), - function_call: None, - temperature: Some(FUNCTIONS_CALLS_TEMPERATURE), - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - } - } else { - ChatCompletionRequest { - model: GPT3_5_TURBO.into(), - messages, - functions: None, - function_call: None, - temperature: Some(FINAL_EXPLANATION_TEMPERATURE), - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - } + // https://platform.openai.com/docs/api-reference/chat/create + ChatCompletionRequest { + model: GPT3_5_TURBO.into(), + messages, + functions: Some(functions()), + function_call: Some(function_call.into()), + temperature: Some(CHAT_COMPLETION_TEMPERATURE), + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + presence_penalty: 
None, + frequency_penalty: None, + logit_bias: None, + user: None, } } -pub fn functions() -> Vec { +// https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions +pub fn functions() -> Vec { vec![ - Function { - name: "none".into(), + F { + name: Function::Done.to_string(), description: Some("This is the final step, and signals that you have enough information to respond to the user's query.".into()), parameters: Some(FunctionParameters { schema_type: JSONSchemaType::Object, @@ -59,8 +41,8 @@ pub fn functions() -> Vec { required: None, }), }, - Function { - name: "search_codebase".into(), + F { + name: Function::SearchCodebase.to_string(), description: Some("Search the contents of files in a repository semantically. Results will not necessarily match search terms exactly, but should be related.".into()), parameters: Some(FunctionParameters { schema_type: JSONSchemaType::Object, @@ -77,8 +59,8 @@ pub fn functions() -> Vec { required: Some(vec!["query".into()]), }) }, - Function { - name: "search_path".into(), + F { + name: Function::SearchPath.to_string(), description: Some("Search the pathnames in a repository. Results may not be exact matches, but will be similar by some edit-distance. Use when you want to find a specific file".into()), parameters: Some(FunctionParameters { schema_type: JSONSchemaType::Object, @@ -95,8 +77,8 @@ pub fn functions() -> Vec { required: Some(vec!["path".into()]), }) }, - Function { - name: "search_file".into(), + F { + name: Function::SearchFile.to_string(), description: Some("Search a file returned from functions.search_path. Results will not necessarily match search terms exactly, but should be related.".into()), parameters: Some(FunctionParameters { schema_type: JSONSchemaType::Object, @@ -130,13 +112,13 @@ pub fn system_message() -> String { Follow these rules at all times: - Respond with functions to find information related to the query, until all relevant information has been found. 
- If the output of a function is not relevant or sufficient, try the same function again with different arguments or try using a different function -- When you have enough information to answer the user's query respond with functions.none +- When you have enough information to answer the user's query respond with functions.done - Do not assume the structure of the codebase, or the existence of files or folders - Never respond with a function that you've used before with the same arguments - Do NOT respond with functions.search_file unless you have already called functions.search_path -- If after making a path search the query can be answered by the existance of the paths, use the functions.none function +- If after making a path search the query can be answered by the existence of the paths, use the functions.done function - Only refer to paths that are returned by the functions.search_path function when calling functions.search_file -- If after attempting to gather information you are still unsure how to answer the query, respond with the functions.none function +- If after attempting to gather information you are still unsure how to answer the query, respond with the functions.done function - Always respond with a function call. Do NOT answer the question directly"#, ) } diff --git a/src/github/mod.rs index 9f48440..2d824e1 100644 --- a/src/github/mod.rs +++ b/src/github/mod.rs @@ -128,7 +128,7 @@ pub async fn fetch_file_content(repository: &Repository, path: &str) -> Result>(sender: &Sender, event: T) -> Result<(), SendError> { sender.send(event.into()).await?; //Empty message to force send the above message to receiver @@ -9,28 +11,6 @@ pub async fn emit>(sender: &Sender, event: T) -> Result<(), SendEr Ok(()) } -//Custom implementation for SSE Events based on https://crates.io/crates/enum_str -macro_rules!
sse_events { - ($name:ident, $(($key:ident, $value:expr),)*) => { - #[derive(Debug, PartialEq)] - pub enum $name - { - $($key(Option)),* - } - - impl From<$name> for Data { - fn from(event: $name) -> Data { - match event { - $( - $name::$key(data) => Data::new(data.unwrap_or_default().to_string()).event($value) - ),* - } - } - } - - } -} - sse_events! { EmbedEvent, (FetchRepo, "FETCH_REPO"), diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 77d0fa9..e3da27b 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -2,8 +2,8 @@ pub mod events; use crate::constants::SSE_CHANNEL_BUFFER_SIZE; +use crate::conversation::{Conversation, Query}; use crate::github::fetch_repo_files; -use crate::utils::conversation::{Conversation, Query}; use crate::{db::RepositoryEmbeddingsDB, github::Repository}; use actix_web::{ post, diff --git a/src/utils/functions.rs b/src/utils/functions.rs index 2e13cbf..72dddcb 100644 --- a/src/utils/functions.rs +++ b/src/utils/functions.rs @@ -2,46 +2,23 @@ use std::str::FromStr; use crate::{ constants::FILE_CHUNKER_CAPACITY_RANGE, + conversation::RelevantChunk, db::RepositoryEmbeddingsDB, embeddings::{cosine_similarity, Embeddings, EmbeddingsModel}, + functions_enum, github::{fetch_file_content, Repository}, prelude::*, - utils::conversation::RelevantChunk, }; use ndarray::ArrayView1; use openai_api_rs::v1::chat_completion::{ChatCompletionMessage, MessageRole}; use rayon::prelude::*; -#[derive(Debug, Clone)] -pub enum Function { - SearchCodebase, - SearchFile, - SearchPath, - None, -} - -impl FromStr for Function { - type Err = anyhow::Error; - fn from_str(s: &str) -> std::result::Result { - match s { - "search_codebase" => Ok(Self::SearchCodebase), - "search_file" => Ok(Self::SearchFile), - "search_path" => Ok(Self::SearchPath), - "none" => Ok(Self::None), - _ => Err(anyhow::anyhow!("Invalid function")), - } - } -} - -impl ToString for Function { - fn to_string(&self) -> String { - match self { - Self::SearchCodebase => 
"search_codebase".to_string(), - Self::SearchFile => "search_file".to_string(), - Self::SearchPath => "search_path".to_string(), - Self::None => "none".to_string(), - } - } +functions_enum! { + Function, + (SearchCodebase, "search_codebase"), + (SearchFile, "search_file"), + (SearchPath, "search_path"), + (Done, "done"), } pub async fn search_codebase( @@ -73,32 +50,25 @@ pub async fn search_file( model: &M, chunks_limit: usize, ) -> Result> { - let file_content = fetch_file_content(repository, path).await?; + let file_content = fetch_file_content(repository, path).await.unwrap_or_default(); + let splitter = text_splitter::TextSplitter::default().with_trim_chunks(true); + let chunks: Vec<&str> = splitter .chunks(&file_content, FILE_CHUNKER_CAPACITY_RANGE) .collect(); - let cleaned_chunks: Vec = chunks - .iter() - .map(|s| s.split_whitespace().collect::>().join(" ")) - .collect(); + + let cleaned_chunks: Vec = clean_chunks(chunks); let chunks_embeddings: Vec = cleaned_chunks .iter() .map(|chunk| model.embed(chunk).unwrap()) .collect(); + let query_embeddings = model.embed(query)?; - let similarities: Vec = chunks_embeddings - .par_iter() - .map(|embedding| { - cosine_similarity( - ArrayView1::from(&query_embeddings), - ArrayView1::from(embedding), - ) - }) - .collect(); - let mut indexed_vec: Vec<(usize, &f32)> = similarities.par_iter().enumerate().collect(); - indexed_vec.par_sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); - let indices: Vec = indexed_vec.iter().map(|x| x.0).take(chunks_limit).collect(); + + let similarities: Vec = similarity_score(chunks_embeddings, query_embeddings); + + let indices = get_top_n_indices(similarities, chunks_limit); let relevant_chunks: Vec = indices .iter() @@ -158,3 +128,31 @@ pub fn relevant_chunks_to_completion_message( function_call: None, } } + +//Remove extra whitespaces from chunks +fn clean_chunks(chunks: Vec<&str>) -> Vec { + chunks + .iter() + .map(|s| s.split_whitespace().collect::>().join(" ")) + .collect() +} + 
+//Compute cosine similarity between query and file content chunks +fn similarity_score(files_embeddings: Vec, query_embeddings: Embeddings) -> Vec { + files_embeddings + .par_iter() + .map(|embedding| { + cosine_similarity( + ArrayView1::from(&query_embeddings), + ArrayView1::from(embedding), + ) + }) + .collect() +} + +//Get n indices with highest similarity scores +fn get_top_n_indices(similarity_scores: Vec, n: usize) -> Vec { + let mut indexed_vec: Vec<(usize, &f32)> = similarity_scores.par_iter().enumerate().collect(); + indexed_vec.par_sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + indexed_vec.iter().map(|x| x.0).take(n).collect() +} diff --git a/src/utils/macros.rs b/src/utils/macros.rs new file mode 100644 index 0000000..3471cc9 --- /dev/null +++ b/src/utils/macros.rs @@ -0,0 +1,84 @@ +//Custom implementation for SSE Events based on https://crates.io/crates/enum_str +///Example usage +// sse_events! { +// EmbedEvent, +// (FetchRepo, "FETCH_REPO"), +// (EmbedRepo, "EMBED_REPO"), +// (SaveEmbeddings, "SAVE_EMBEDDINGS"), +// (Done, "DONE"), +// } +// tx.send(EmbedEvent::EmbedRepo(Some(json!({ +// "files": files.len(), +// }))).into()) +/// +#[macro_export] +macro_rules! sse_events { + ($name:ident, $(($key:ident, $value:expr),)*) => { + #[derive(Debug, PartialEq)] + pub enum $name + { + $($key(Option)),* + } + + impl From<$name> for Data { + fn from(event: $name) -> Data { + match event { + $( + $name::$key(data) => Data::new(data.unwrap_or_default().to_string()).event($value) + ),* + } + } + } + + impl $name { + + } + + } +} + +///Example usage +// functions_enum!{ +// Function, +// (SearchCodebase, "search_codebase"), +// (SearchFile, "search_file"), +// (SearchPath, "search_path"), +// (Done, "done"), +// } +// Function::from_str("search_codebase").unwrap(); +// Function::SearchCodebase.to_string(); +/// +#[macro_export] +macro_rules! 
functions_enum { + ($name:ident, $(($key:ident, $value:expr),)*) => { + #[derive(Debug, PartialEq, Clone)] + pub enum $name + { + $($key),* + } + + impl ToString for $name { + fn to_string(&self) -> String { + match self { + $( + &$name::$key => $value.to_string() + ),* + } + } + } + + impl FromStr for $name { + type Err = anyhow::Error; + + fn from_str(val: &str) -> Result { + match val + { + $( + $value => Ok($name::$key) + ),*, + _ => Err(anyhow::anyhow!("Invalid function")) + } + } + } + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1601cab..05634b7 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,2 +1,2 @@ -pub mod conversation; pub mod functions; +pub mod macros;