Skip to content

Commit

Permalink
Fix Variables
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani authored Oct 26, 2024
1 parent ef858ca commit c63a848
Showing 1 changed file with 126 additions and 91 deletions.
217 changes: 126 additions & 91 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
#include <string>
#include <sstream>
#include <mutex>
#include <iostream>

#include <iostream>
#include "yyjson.hpp"

namespace duckdb {

// Helper function to parse URL and setup client
static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(const std::string &url) {
std::string scheme, domain, path;
size_t pos = url.find("://");
Expand All @@ -37,7 +36,6 @@ static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(co
path = "/";
}

// Create client and set a reasonable timeout (e.g., 10 seconds)
duckdb_httplib_openssl::Client client(domain.c_str());
client.set_read_timeout(10, 0); // 10 seconds
client.set_follow_location(true); // Follow redirects
Expand Down Expand Up @@ -89,127 +87,166 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std
throw std::runtime_error(err_message);
}

// Settings management
static std::string GetConfigValue(ClientContext &context, const string &var_name, const string &default_value) {
Value value;
auto &config = ClientConfig::GetConfig(context);
if (!config.GetUserVariable(var_name, value) || value.IsNull()) {
return default_value;
}
return value.ToString();
}

// Global settings
static constexpr const char* TOKEN_VAR = "";
static constexpr const char* URL_VAR = "http://localhost:11434/v1/chat/completions";
static constexpr const char* MODEL_VAR = "qwen2.5:0.5b";

// Open Prompt Function
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
auto &context = state.GetContext();
D_ASSERT(args.data.size() == 2);

static void SetConfigValue(DataChunk &args, ExpressionState &state, Vector &result,
const string &var_name, const string &value_type) {
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
[&](string_t user_prompt) {
[&](string_t value) {
try {
// Get configuration from DuckDB variables with explicit variable names
Value url_value;
std::string api_url;
if (context.TryGetCurrentSetting("open_prompt_url", url_value)) {
api_url = url_value.ToString();
} else {
api_url = URL_VAR;
if (value == "" || value.GetSize() == 0) {
throw std::invalid_argument(value_type + " cannot be empty.");
}

ClientConfig::GetConfig(state.GetContext()).SetUserVariable(
var_name,
Value::CreateValue(value.GetString())
);
return StringVector::AddString(result, value_type + " set to: " + value.GetString());
} catch (std::exception &e) {
return StringVector::AddString(result, "Failed to set " + value_type + ": " + e.what());
}
});
}

Value token_value;
std::string api_token;
if (context.TryGetCurrentSetting("open_prompt_token", token_value)) {
api_token = token_value.ToString();
}
static void SetApiToken(DataChunk &args, ExpressionState &state, Vector &result) {
SetConfigValue(args, state, result, "openprompt_api_token", "API token");
}

// Get model name with priority: function argument > variable > default
std::string model_name;
if (!args.data[1].GetValue(0).IsNull()) {
model_name = args.data[1].GetValue(0).ToString();
} else {
Value model_value;
if (context.TryGetCurrentSetting("open_prompt_model", model_value)) {
model_name = model_value.ToString();
} else {
model_name = MODEL_VAR;
}
}
static void SetApiUrl(DataChunk &args, ExpressionState &state, Vector &result) {
SetConfigValue(args, state, result, "openprompt_api_url", "API URL");
}

static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result) {
SetConfigValue(args, state, result, "openprompt_model_name", "Model name");
}

// Debug logging
std::cerr << "Using API URL: " << api_url << std::endl;
std::cerr << "Using model: " << model_name << std::endl;
// Main Function
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
D_ASSERT(args.data.size() >= 1); // At least prompt required

UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
[&](string_t user_prompt) {
auto &context = state.GetContext();

// Get configuration with defaults
std::string api_url = GetConfigValue(context, "openprompt_api_url",
"http://localhost:11434/v1/chat/completions");
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");

// Override model if provided as second argument
if (args.data.size() > 1 && !args.data[1].GetValue(0).IsNull()) {
model_name = args.data[1].GetValue(0).ToString();
}

// Construct request body
std::string request_body = "{";
request_body += "\"model\":\"" + model_name + "\",";
request_body += "\"messages\":[";
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
request_body += "]}";
std::string request_body = "{";
request_body += "\"model\":\"" + model_name + "\",";
request_body += "\"messages\":[";
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
request_body += "]}";

// Setup HTTP client with the configured URL
try {
auto client_and_path = SetupHttpClient(api_url);
auto &client = client_and_path.first;
auto &path = client_and_path.second;

// Setup headers
duckdb_httplib_openssl::Headers headers;
headers.emplace("Content-Type", "application/json");
if (!api_token.empty()) {
headers.emplace("Authorization", "Bearer " + api_token);
}

// Debug logging before making request
std::cerr << "Making request to path: " << path << std::endl;
std::cerr << "Request body: " << request_body << std::endl;

// Send request
auto res = client.Post(path.c_str(), headers, request_body, "application/json");
if (res && res->status == 200) {
// Extract the first choice's message content from the response
std::string response_body = res->body;
// Debug logging
std::cerr << "Response body: " << response_body << std::endl;

if (!res) {
HandleHttpError(res, "POST");
}

if (res->status != 200) {
throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason);
}

try {
unique_ptr<duckdb_yyjson::yyjson_doc, void(*)(duckdb_yyjson::yyjson_doc *)> doc(
duckdb_yyjson::yyjson_read(res->body.c_str(), res->body.length(), 0),
&duckdb_yyjson::yyjson_doc_free
);

size_t choices_pos = response_body.find("\"choices\":");
if (choices_pos != std::string::npos) {
size_t message_pos = response_body.find("\"message\":", choices_pos);
size_t content_pos = response_body.find("\"content\":\"", message_pos);
if (content_pos != std::string::npos) {
content_pos += 11; // Move past "content":"
size_t content_end = response_body.find("\"", content_pos);
if (content_end != std::string::npos) {
std::string first_message_content = response_body.substr(content_pos, content_end - content_pos);
return StringVector::AddString(result, first_message_content);
}
}
if (!doc) {
throw std::runtime_error("Failed to parse JSON response");
}
throw std::runtime_error("Failed to parse API response");
} else if (res) {
std::string error_msg = "HTTP error: " + std::to_string(res->status);
if (!res->reason.empty()) {
error_msg += " - " + res->reason;

auto root = duckdb_yyjson::yyjson_doc_get_root(doc.get());
if (!root) {
throw std::runtime_error("Invalid JSON response: no root object");
}
if (!res->body.empty()) {
error_msg += "\nResponse body: " + res->body;

auto choices = duckdb_yyjson::yyjson_obj_get(root, "choices");
if (!choices || !duckdb_yyjson::yyjson_is_arr(choices)) {
throw std::runtime_error("Invalid response format: missing choices array");
}
throw std::runtime_error(error_msg);
} else {
HandleHttpError(res, "POST");

auto first_choice = duckdb_yyjson::yyjson_arr_get_first(choices);
if (!first_choice) {
throw std::runtime_error("Empty choices array in response");
}

auto message = duckdb_yyjson::yyjson_obj_get(first_choice, "message");
if (!message) {
throw std::runtime_error("Missing message in response");
}

auto content = duckdb_yyjson::yyjson_obj_get(message, "content");
if (!content) {
throw std::runtime_error("Missing content in response");
}

auto content_str = duckdb_yyjson::yyjson_get_str(content);
if (!content_str) {
throw std::runtime_error("Invalid content in response");
}

return StringVector::AddString(result, content_str);
} catch (std::exception &e) {
throw std::runtime_error("Failed to parse response: " + std::string(e.what()));
}
} catch (std::exception &e) {
// Log error and return original prompt
std::cerr << "Error in OpenPromptRequestFunction: " << e.what() << std::endl;
return StringVector::AddString(result, user_prompt);
// Log error and return error message
return StringVector::AddString(result, "Error: " + std::string(e.what()));
}
return StringVector::AddString(result, user_prompt);
});
}

// LoadInternal function
static void LoadInternal(DatabaseInstance &instance) {
// Register open_prompt function with two arguments: prompt and model
ScalarFunctionSet open_prompt("open_prompt");

// Register with both single and two-argument variants
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));

ExtensionUtil::RegisterFunction(instance, open_prompt);
}

// Register setting functions
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiToken));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName));
}

void OpenPromptExtension::Load(DuckDB &db) {
LoadInternal(*db.instance);
Expand All @@ -227,7 +264,6 @@ std::string OpenPromptExtension::Version() const {
#endif
}


} // namespace duckdb

extern "C" {
Expand All @@ -244,4 +280,3 @@ DUCKDB_EXTENSION_API const char *open_prompt_version() {
#ifndef DUCKDB_EXTENSION_MAIN
#error DUCKDB_EXTENSION_MAIN not defined
#endif

0 comments on commit c63a848

Please sign in to comment.