Skip to content

Commit

Permalink
fixed prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
akorchyn committed Feb 3, 2024
1 parent 62079c8 commit 73ed89a
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions src/openai/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ pub struct OpenAIClient {

#[derive(Clone)]
pub struct Prompt {
messages: Vec<ChatCompletionMessage>,
system_message: ChatCompletionMessage,
user_message: ChatCompletionMessage,
gpt_length: GPTLenght,
}

Expand All @@ -84,9 +85,9 @@ impl OpenAIClient {
&self,
messages: &[Message],
gpt_length: GPTLenght,
) -> Option<Prompt> {
) -> Vec<Prompt> {
if messages.is_empty() {
return None;
return vec![];
}

let system_message = format!(
Expand All @@ -95,6 +96,7 @@ impl OpenAIClient {
gpt_length.to_prompt_text(),
PROMPT_HEADER_FINAL,
);
let system_message_len = system_message.len();
let user_message = |message| ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: chat_completion::Content::Text(message),
Expand All @@ -106,7 +108,7 @@ impl OpenAIClient {
content: chat_completion::Content::Text(system_message),
name: None,
};
let mut prompt = vec![system_message];
let mut prompts: Vec<_> = vec![];
let mut msg = String::new();
for (i, message) in messages.iter().rev().enumerate() {
let new_line = format!(
Expand All @@ -118,29 +120,37 @@ impl OpenAIClient {
.unwrap_or("Unknown".to_string()),
message.text()
);
if msg.len() + new_line.len() > consts::TOKEN_LIMITS_PER_MESSAGE {
if system_message_len + msg.len() + new_line.len() > consts::TOKEN_LIMITS_PER_MESSAGE {
msg.push_str("```");
prompt.push(user_message(msg));
prompts.push(Prompt {
system_message: system_message.clone(),
user_message: user_message(msg),
gpt_length,
});
msg = new_line;
} else {
msg.push_str(&new_line);
}
}
msg.push_str("```");
prompt.push(user_message(msg));
Some(Prompt {
messages: prompt,
prompts.push(Prompt {
system_message,
user_message: user_message(msg),
gpt_length,
})
});
prompts
}

pub fn send_prompt(&self, prompt: Prompt) -> anyhow::Result<ChatCompletionResponse> {
let client: Client = Client::new(self.api_key.clone());

let req = ChatCompletionRequest::new(GPT3_5_TURBO.to_string(), prompt.messages)
.max_tokens(prompt.gpt_length.to_max_tokens())
.temperature(0.5)
.top_p(0.5);
let req = ChatCompletionRequest::new(
GPT3_5_TURBO.to_string(),
vec![prompt.system_message, prompt.user_message],
)
.max_tokens(prompt.gpt_length.to_max_tokens())
.temperature(0.5)
.top_p(0.5);

let result = client.chat_completion(req)?;
if result.choices.is_empty() || result.choices[0].message.content.is_none() {
Expand Down

0 comments on commit 73ed89a

Please sign in to comment.