Skip to content

Commit

Permalink
VARIBLE attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani committed Oct 24, 2024
1 parent 0023b1f commit bcf58e8
Showing 1 changed file with 69 additions and 66 deletions.
135 changes: 69 additions & 66 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,114 +147,117 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std

// Open Prompt Function
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
D_ASSERT(args.data.size() == 2); // Expecting the prompt and model name
auto &context = state.GetContext();
D_ASSERT(args.data.size() == 2);

UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
[&](string_t user_prompt) {
std::string api_url = GetApiUrl(); // Retrieve the API URL from settings
std::string api_token = GetApiToken(); // Retrieve the API Token from settings
std::string model_name;

if (!args.data[1].GetValue(0).IsNull()) {
model_name = args.data[1].GetValue(0).ToString(); // Use passed model name
} else {
model_name = GetModelName(); // Use the default model if none is provided
}
try {
// Get configuration from DuckDB variables with explicit variable names
Value url_value;
std::string api_url;
if (context.TryGetCurrentSetting("open_prompt_url", url_value)) {
api_url = url_value.ToString();
} else {
api_url = "http://localhost:11434/v1/chat/completions";
}

// Manually construct the JSON body as a string. TODO use json parser from extension.
std::string request_body = "{";
request_body += "\"model\":\"" + model_name + "\",";
request_body += "\"messages\":[";
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
request_body += "]}";
Value token_value;
std::string api_token;
if (context.TryGetCurrentSetting("open_prompt_token", token_value)) {
api_token = token_value.ToString();
}

try {
// Make the POST request
// Get model name with priority: function argument > variable > default
std::string model_name;
if (!args.data[1].GetValue(0).IsNull()) {
model_name = args.data[1].GetValue(0).ToString();
} else {
Value model_value;
if (context.TryGetCurrentSetting("open_prompt_model", model_value)) {
model_name = model_value.ToString();
} else {
model_name = "qwen2.5:0.5b";
}
}

// Debug logging
std::cerr << "Using API URL: " << api_url << std::endl;
std::cerr << "Using model: " << model_name << std::endl;

// Construct request body
std::string request_body = "{";
request_body += "\"model\":\"" + model_name + "\",";
request_body += "\"messages\":[";
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
request_body += "]}";

// Setup HTTP client with the configured URL
auto client_and_path = SetupHttpClient(api_url);
auto &client = client_and_path.first;
auto &path = client_and_path.second;

// Setup headers
duckdb_httplib_openssl::Headers header_map;
header_map.emplace("Content-Type", "application/json");
duckdb_httplib_openssl::Headers headers;
headers.emplace("Content-Type", "application/json");
if (!api_token.empty()) {
header_map.emplace("Authorization", "Bearer " + api_token);
headers.emplace("Authorization", "Bearer " + api_token);
}

// Send the request
auto res = client.Post(path.c_str(), header_map, request_body, "application/json");
// Debug logging before making request
std::cerr << "Making request to path: " << path << std::endl;
std::cerr << "Request body: " << request_body << std::endl;

// Send request
auto res = client.Post(path.c_str(), headers, request_body, "application/json");
if (res && res->status == 200) {
// Extract the first choice's message content from the response
std::string response_body = res->body;
// Debug logging
std::cerr << "Response body: " << response_body << std::endl;

size_t choices_pos = response_body.find("\"choices\":");
if (choices_pos != std::string::npos) {
size_t message_pos = response_body.find("\"message\":", choices_pos);
size_t content_pos = response_body.find("\"content\":\"", message_pos);
if (content_pos != std::string::npos) {
content_pos += 11; // Move to the start of the content value
content_pos += 11; // Move past "content":"
size_t content_end = response_body.find("\"", content_pos);
if (content_end != std::string::npos) {
std::string first_message_content = response_body.substr(content_pos, content_end - content_pos);
return StringVector::AddString(result, first_message_content);
}
}
}
throw std::runtime_error("Failed to parse the first message content in the API response.");
throw std::runtime_error("Failed to parse API response");
} else if (res) {
std::string error_msg = "HTTP error: " + std::to_string(res->status);
if (!res->reason.empty()) {
error_msg += " - " + res->reason;
}
if (!res->body.empty()) {
error_msg += "\nResponse body: " + res->body;
}
throw std::runtime_error(error_msg);
} else {
throw std::runtime_error("HTTP POST error: " + std::to_string(res->status) + " - " + res->reason);
HandleHttpError(res, "POST");
}
} catch (std::exception &e) {
// In case of any error, return the original input text to avoid disruption
// Log error and return original prompt
std::cerr << "Error in OpenPromptRequestFunction: " << e.what() << std::endl;
return StringVector::AddString(result, user_prompt);
}
return StringVector::AddString(result, user_prompt);
});
}


static void LoadInternal(DatabaseInstance &instance) {
// Register open_prompt function with two arguments: prompt and model
ScalarFunctionSet open_prompt("open_prompt");
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
ExtensionUtil::RegisterFunction(instance, open_prompt);

// Other set_* functions remain the same as before
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
[](DataChunk &args, ExpressionState &state, Vector &result) {
try {
auto token = args.data[0].GetValue(0).ToString();
SetApiToken(token);
return StringVector::AddString(result, "API token set successfully.");
} catch (std::exception &e) {
return StringVector::AddString(result, "Failed to set API token: " + std::string(e.what()));
}
}));

ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
[](DataChunk &args, ExpressionState &state, Vector &result) {
try {
auto new_url = args.data[0].GetValue(0).ToString();
SetApiUrl(new_url);
return StringVector::AddString(result, "API URL set successfully.");
} catch (std::exception &e) {
return StringVector::AddString(result, "Failed to set API URL: " + std::string(e.what()));
}
}));

ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
[](DataChunk &args, ExpressionState &state, Vector &result) {
try {
auto model = args.data[0].GetValue(0).ToString();
SetModelName(model);
return StringVector::AddString(result, "Model name set successfully.");
} catch (std::exception &e) {
return StringVector::AddString(result, "Failed to set model name: " + std::string(e.what()));
}
}));
}


Expand Down

0 comments on commit bcf58e8

Please sign in to comment.