Skip to content

Commit

Permalink
Rudimental parser for completions
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani authored Oct 21, 2024
1 parent 8902140 commit 7ad5673
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
std::string api_token = GetApiToken(); // Retrieve the API Token from settings
std::string model_name = GetModelName(); // Retrieve the model name from settings

// Prepare the JSON body
// Manually construct the JSON body as a string
std::string request_body = "{";
request_body += "\"model\":\"" + model_name + "\",";
request_body += "\"messages\":[";
Expand All @@ -156,7 +156,22 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
// Send the request
auto res = client.Post(path.c_str(), header_map, request_body, "application/json");
if (res && res->status == 200) {
return StringVector::AddString(result, res->body);
// Extract the first choice's message content from the response
std::string response_body = res->body;
size_t choices_pos = response_body.find("\"choices\":");
if (choices_pos != std::string::npos) {
size_t message_pos = response_body.find("\"message\":", choices_pos);
size_t content_pos = response_body.find("\"content\":\"", message_pos);
if (content_pos != std::string::npos) {
content_pos += 11; // Move to the start of the content value
size_t content_end = response_body.find("\"", content_pos);
if (content_end != std::string::npos) {
std::string first_message_content = response_body.substr(content_pos, content_end - content_pos);
return StringVector::AddString(result, first_message_content);
}
}
}
throw std::runtime_error("Failed to parse the first message content in the API response.");
} else {
throw std::runtime_error("HTTP POST error: " + std::to_string(res->status) + " - " + res->reason);
}
Expand All @@ -167,7 +182,6 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
});
}


static void LoadInternal(DatabaseInstance &instance) {
// Register open_prompt function
ScalarFunctionSet open_prompt("open_prompt");
Expand Down

0 comments on commit 7ad5673

Please sign in to comment.