chore: functions macro, restructure

Anush008 committed Jul 21, 2023
1 parent 9b5a658 commit 800e983

Showing 11 changed files with 247 additions and 205 deletions.
4 changes: 1 addition & 3 deletions src/constants.rs
@@ -6,9 +6,7 @@ pub const EMBEDDINGS_DIMENSION: usize = 384;
 
 pub const SSE_CHANNEL_BUFFER_SIZE: usize = 1;
 
-pub const FINAL_EXPLANATION_TEMPERATURE: f64 = 0.7;
-
-pub const FUNCTIONS_CALLS_TEMPERATURE: f64 = 0.5;
+pub const CHAT_COMPLETION_TEMPERATURE: f64 = 0.5;
 
 pub const ACTIX_WEB_SERVER_PORT: usize = 3000;
63 changes: 63 additions & 0 deletions src/conversation/data.rs
@@ -0,0 +1,63 @@
+use crate::prelude::*;
+use crate::{github::Repository, utils::functions::Function};
+use openai_api_rs::v1::chat_completion::FunctionCall;
+use serde::Deserialize;
+use std::str::FromStr;
+
+#[derive(Deserialize)]
+pub struct Query {
+    pub repository: Repository,
+    pub query: String,
+}
+
+impl ToString for Query {
+    fn to_string(&self) -> String {
+        let Query {
+            repository:
+                Repository {
+                    owner,
+                    name,
+                    branch,
+                },
+            query,
+        } = self;
+        format!(
+            "##Repository Info##\nOwner:{}\nName:{}\nBranch:{}\n##User Query##\nQuery:{}",
+            owner, name, branch, query
+        )
+    }
+}
+
+#[derive(Debug)]
+pub struct RelevantChunk {
+    pub path: String,
+    pub content: String,
+}
+
+impl ToString for RelevantChunk {
+    fn to_string(&self) -> String {
+        format!(
+            "##Relevant file chunk##\nPath argument:{}\nRelevant content: {}",
+            self.path,
+            self.content.trim()
+        )
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct ParsedFunctionCall {
+    pub name: Function,
+    pub args: serde_json::Value,
+}
+
+impl TryFrom<&FunctionCall> for ParsedFunctionCall {
+    type Error = anyhow::Error;
+
+    fn try_from(func: &FunctionCall) -> Result<Self> {
+        let func = func.clone();
+        let name = Function::from_str(&func.name.unwrap_or("done".into()))?;
+        let args = func.arguments.unwrap_or("{}".to_string());
+        let args = serde_json::from_str::<serde_json::Value>(&args)?;
+        Ok(ParsedFunctionCall { name, args })
+    }
+}
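
A minimal sketch of how the new TryFrom impl is consumed (the FunctionCall literal below is hypothetical; its field shape matches the openai_api_rs v1 struct used above):

    use openai_api_rs::v1::chat_completion::FunctionCall;

    // A function call as it might come back from the model.
    let call = FunctionCall {
        name: Some("search_codebase".to_string()),
        arguments: Some(r#"{"query": "embeddings model"}"#.to_string()),
    };
    // A missing name falls back to "done", missing arguments to "{}".
    let parsed = ParsedFunctionCall::try_from(&call)?;
    assert_eq!(parsed.args["query"], "embeddings model");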
86 changes: 12 additions & 74 deletions src/utils/conversation/mod.rs → src/conversation/mod.rs
@@ -1,82 +1,35 @@
 #![allow(unused_must_use)]
+mod data;
 mod prompts;
 
 use crate::{
-    prelude::*,
     constants::{RELEVANT_CHUNKS_LIMIT, RELEVANT_FILES_LIMIT},
     db::RepositoryEmbeddingsDB,
     embeddings::EmbeddingsModel,
-    github::Repository,
+    prelude::*,
     routes::events::{emit, QueryEvent},
 };
 use actix_web_lab::sse::Sender;
-use openai_api_rs::v1::chat_completion::{FinishReason, FunctionCall};
+pub use data::*;
+use openai_api_rs::v1::chat_completion::FinishReason;
 use openai_api_rs::v1::{
     api::Client,
     chat_completion::{
         ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, MessageRole,
     },
 };
-use serde::Deserialize;
 use std::env;
-use std::str::FromStr;
 use std::sync::Arc;
 
 use prompts::{generate_completion_request, system_message};
 
 use self::prompts::answer_generation_prompt;
 
-use super::functions::{
+use crate::utils::functions::{
     paths_to_completion_message, relevant_chunks_to_completion_message, search_codebase,
     search_file, search_path, Function,
 };
 
-#[derive(Deserialize)]
-pub struct Query {
-    pub repository: Repository,
-    pub query: String,
-}
-
-impl ToString for Query {
-    fn to_string(&self) -> String {
-        let Query {
-            repository:
-                Repository {
-                    owner,
-                    name,
-                    branch,
-                },
-            query,
-        } = self;
-        format!(
-            "##Repository Info##\nOwner:{}\nName:{}\nBranch:{}\n##User Query##\nQuery:{}",
-            owner, name, branch, query
-        )
-    }
-}
-
-#[derive(Debug)]
-pub struct RelevantChunk {
-    pub path: String,
-    pub content: String,
-}
-
-impl ToString for RelevantChunk {
-    fn to_string(&self) -> String {
-        format!(
-            "##Relevant file chunk##\nPath argument:{}\nRelevant content: {}",
-            self.path,
-            self.content.trim()
-        )
-    }
-}
-
-#[derive(Debug, Clone)]
-struct ParsedFunctionCall {
-    name: Function,
-    args: serde_json::Value,
-}
-
 pub struct Conversation<D: RepositoryEmbeddingsDB, M: EmbeddingsModel> {
     query: Query,
     client: Client,
@@ -133,15 +86,16 @@ impl<D: RepositoryEmbeddingsDB, M: EmbeddingsModel> Conversation<D, M> {
         #[allow(unused_labels)]
         'conversation: loop {
             //Generate a request with the message history and functions
-            let request = generate_completion_request(self.messages.clone(), true);
+            let request = generate_completion_request(self.messages.clone(), "auto");
 
             match self.send_request(request).await {
                 Ok(response) => {
                     if let FinishReason::function_call = response.choices[0].finish_reason {
                         if let Some(function_call) =
                             response.choices[0].message.function_call.clone()
                         {
-                            let parsed_function_call = parse_function_call(&function_call)?;
+                            let parsed_function_call =
+                                ParsedFunctionCall::try_from(&function_call)?;
                             let function_call_message = ChatCompletionMessage {
                                 name: None,
                                 function_call: Some(function_call),
@@ -229,22 +183,17 @@ impl<D: RepositoryEmbeddingsDB, M: EmbeddingsModel> Conversation<D, M> {
                             );
                             self.append_message(completion_message);
                         }
-                        Function::None => {
+                        Function::Done => {
                             self.prepare_final_explanation_message();
 
                             //Generate a request with the message history and no functions
                             let request =
-                                generate_completion_request(self.messages.clone(), false);
-                            emit(
-                                &self.sender,
-                                QueryEvent::GenerateResponse(Some(
-                                    parsed_function_call.args,
-                                )),
-                            )
-                            .await;
+                                generate_completion_request(self.messages.clone(), "none");
+                            emit(&self.sender, QueryEvent::GenerateResponse(None)).await;
                             let response = match self.send_request(request).await {
                                 Ok(response) => response,
                                 Err(e) => {
-                                    dbg!(e.to_string());
                                     return Err(e);
                                 }
                             };
@@ -268,14 +217,3 @@ impl<D: RepositoryEmbeddingsDB, M: EmbeddingsModel> Conversation<D, M> {
         }
     }
 }
-
-fn parse_function_call(func: &FunctionCall) -> Result<ParsedFunctionCall> {
-    let func = func.clone();
-    let function_name = Function::from_str(&func.name.unwrap_or("none".into()))?;
-    let function_args = func.arguments.unwrap_or("{}".to_string());
-    let function_args = serde_json::from_str::<serde_json::Value>(&function_args)?;
-    Ok(ParsedFunctionCall {
-        name: function_name,
-        args: function_args,
-    })
-}
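
The Function enum referenced here lives in src/utils/functions.rs, which is not rendered in this view. Judging by the Function::from_str and Function::Done.to_string() call sites and the "functions macro" in the commit title, a plausible sketch is a strum-derived enum; this is an assumption, not the actual file contents:

    use strum_macros::{Display, EnumString};

    // Hypothetical: each variant renders as the snake_case name the
    // OpenAI API sees ("done", "search_codebase", "search_path", "search_file").
    #[derive(Debug, Clone, Copy, Display, EnumString)]
    #[strum(serialize_all = "snake_case")]
    pub enum Function {
        Done,
        SearchCodebase,
        SearchPath,
        SearchFile,
    }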
84 changes: 33 additions & 51 deletions src/utils/conversation/prompts.rs → src/conversation/prompts.rs
@@ -1,66 +1,48 @@
 use openai_api_rs::v1::chat_completion::{
-    ChatCompletionMessage, ChatCompletionRequest, Function, FunctionParameters, JSONSchemaDefine,
-    JSONSchemaType, GPT3_5_TURBO,
+    ChatCompletionMessage, ChatCompletionRequest, Function as F, FunctionParameters,
+    JSONSchemaDefine, JSONSchemaType, GPT3_5_TURBO,
 };
 use std::collections::HashMap;
 
-use crate::constants::{FINAL_EXPLANATION_TEMPERATURE, FUNCTIONS_CALLS_TEMPERATURE};
+use crate::{constants::CHAT_COMPLETION_TEMPERATURE, utils::functions::Function};
 
 pub fn generate_completion_request(
     messages: Vec<ChatCompletionMessage>,
-    with_functions: bool,
+    function_call: &str,
 ) -> ChatCompletionRequest {
-    //All the chat completion requests will have functions except for the final explanation request
-    if with_functions {
-        ChatCompletionRequest {
-            model: GPT3_5_TURBO.into(),
-            messages,
-            functions: Some(functions()),
-            function_call: None,
-            temperature: Some(FUNCTIONS_CALLS_TEMPERATURE),
-            top_p: None,
-            n: None,
-            stream: None,
-            stop: None,
-            max_tokens: None,
-            presence_penalty: None,
-            frequency_penalty: None,
-            logit_bias: None,
-            user: None,
-        }
-    } else {
-        ChatCompletionRequest {
-            model: GPT3_5_TURBO.into(),
-            messages,
-            functions: None,
-            function_call: None,
-            temperature: Some(FINAL_EXPLANATION_TEMPERATURE),
-            top_p: None,
-            n: None,
-            stream: None,
-            stop: None,
-            max_tokens: None,
-            presence_penalty: None,
-            frequency_penalty: None,
-            logit_bias: None,
-            user: None,
-        }
-    }
+    // https://platform.openai.com/docs/api-reference/chat/create
+    ChatCompletionRequest {
+        model: GPT3_5_TURBO.into(),
+        messages,
+        functions: Some(functions()),
+        function_call: Some(function_call.into()),
+        temperature: Some(CHAT_COMPLETION_TEMPERATURE),
+        top_p: None,
+        n: None,
+        stream: None,
+        stop: None,
+        max_tokens: None,
+        presence_penalty: None,
+        frequency_penalty: None,
+        logit_bias: None,
+        user: None,
+    }
 }
 
-pub fn functions() -> Vec<Function> {
+// https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions
+pub fn functions() -> Vec<F> {
     vec![
-        Function {
-            name: "none".into(),
+        F {
+            name: Function::Done.to_string(),
             description: Some("This is the final step, and signals that you have enough information to respond to the user's query.".into()),
             parameters: Some(FunctionParameters {
                 schema_type: JSONSchemaType::Object,
                 properties: Some(HashMap::new()),
                 required: None,
             }),
         },
-        Function {
-            name: "search_codebase".into(),
+        F {
+            name: Function::SearchCodebase.to_string(),
             description: Some("Search the contents of files in a repository semantically. Results will not necessarily match search terms exactly, but should be related.".into()),
             parameters: Some(FunctionParameters {
                 schema_type: JSONSchemaType::Object,
@@ -77,8 +59,8 @@ pub fn functions() -> Vec<Function> {
required: Some(vec!["query".into()]),
})
},
Function {
name: "search_path".into(),
F {
name: Function::SearchPath.to_string(),
description: Some("Search the pathnames in a repository. Results may not be exact matches, but will be similar by some edit-distance. Use when you want to find a specific file".into()),
parameters: Some(FunctionParameters {
schema_type: JSONSchemaType::Object,
@@ -95,8 +77,8 @@ pub fn functions() -> Vec<Function> {
required: Some(vec!["path".into()]),
})
},
Function {
name: "search_file".into(),
F {
name: Function::SearchFile.to_string(),
description: Some("Search a file returned from functions.search_path. Results will not necessarily match search terms exactly, but should be related.".into()),
parameters: Some(FunctionParameters {
schema_type: JSONSchemaType::Object,
@@ -130,13 +112,13 @@ pub fn system_message() -> String {
 Follow these rules at all times:
 - Respond with functions to find information related to the query, until all relevant information has been found.
 - If the output of a function is not relevant or sufficient, try the same function again with different arguments or try using a different function
-- When you have enough information to answer the user's query respond with functions.none
+- When you have enough information to answer the user's query respond with functions.done
 - Do not assume the structure of the codebase, or the existence of files or folders
 - Never respond with a function that you've used before with the same arguments
 - Do NOT respond with functions.search_file unless you have already called functions.search_path
-- If after making a path search the query can be answered by the existence of the paths, use the functions.none function
+- If after making a path search the query can be answered by the existence of the paths, use the functions.done function
 - Only refer to paths that are returned by the functions.search_path function when calling functions.search_file
-- If after attempting to gather information you are still unsure how to answer the query, respond with the functions.none function
+- If after attempting to gather information you are still unsure how to answer the query, respond with the functions.done function
 - Always respond with a function call. Do NOT answer the question directly"#,
     )
 }
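
A short usage sketch of the unified request builder (messages elided; per the OpenAI chat API, "auto" lets the model decide whether to call a function, while "none" forces a plain-text reply):

    // Inside the search loop: the model may pick any of functions().
    let request = generate_completion_request(messages.clone(), "auto");

    // For the final explanation: functions stay attached, but the model
    // is instructed not to call them.
    let final_request = generate_completion_request(messages, "none");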
8 changes: 3 additions & 5 deletions src/github/mod.rs
@@ -128,7 +128,7 @@ pub async fn fetch_file_content(repository: &Repository, path: &str) -> Result<String> {
         let content = response.text().await?;
         Ok(content)
     } else {
-        Ok(String::new())
+        Err(anyhow::anyhow!("Unable to fetch file content"))
     }
 }
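
Callers now get an explicit Err for an unreadable path instead of an empty string; a hedged sketch of the resulting call-site pattern (the match arms are illustrative, not part of this diff):

    match fetch_file_content(&repository, path).await {
        Ok(content) => println!("fetched {} bytes", content.len()),
        // An invalid path is now a hard error rather than "".
        Err(e) => eprintln!("fetch failed: {e}"),
    }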

@@ -195,10 +195,8 @@

         let result = fetch_file_content(&repository, path).await;
 
-        //Assert that the function returns Result containing an empty string for invalid file path
-        assert!(result.is_ok());
-        let content = result.unwrap();
-        assert!(content.len() == 0);
+        //Assert that the function returns Err for an invalid file path
+        assert!(result.is_err());
     }
 
     #[test]
1 change: 1 addition & 0 deletions src/main.rs
@@ -1,4 +1,5 @@
 mod constants;
+mod conversation;
 mod db;
 mod embeddings;
 mod github;