Skip to content

Commit

Permalink
Merge branch 'main' into https-support
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani authored Dec 22, 2024
2 parents 8fdca81 + 780e38f commit 57c0b65
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,8 @@ namespace duckdb {
client_url = domain;
}

// Create client and set a reasonable timeout (e.g., 10 seconds)
duckdb_httplib_openssl::Client client(client_url);
client.set_read_timeout(10, 0); // 10 seconds
client.set_read_timeout(20, 0); // 20 seconds
client.set_follow_location(true); // Follow redirects

return std::make_pair(std::move(client), path);
Expand Down Expand Up @@ -179,11 +178,15 @@ namespace duckdb {
SetConfigValue(args, state, result, "openprompt_api_url", "API URL");
}

static void SetApiTimeout(DataChunk &args, ExpressionState &state, Vector &result) {
SetConfigValue(args, state, result, "openprompt_api_timeout", "API timeout");
}

static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result) {
SetConfigValue(args, state, result, "openprompt_model_name", "Model name");
}

// Main Function
// Complete OpenPromptRequestFunction
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
D_ASSERT(args.data.size() >= 1); // At least prompt required

Expand All @@ -192,11 +195,12 @@ namespace duckdb {
auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
auto &info = func_expr.bind_info->Cast<OpenPromptData>();
auto &context = state.GetContext();
// Get configuration with defaults

std::string api_url = GetConfigValue(context, "openprompt_api_url",
"http://localhost:11434/v1/chat/completions");
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");
std::string api_timeout = GetConfigValue(context, "openprompt_api_timeout", "");
std::string json_schema;
std::string system_prompt;

Expand Down Expand Up @@ -269,12 +273,16 @@ namespace duckdb {
headers.emplace("Authorization", "Bearer " + api_token);
}

if (!api_timeout.empty()) {
client.set_read_timeout(stoi(api_timeout), 0);
}

auto res = client.Post(path.c_str(), headers, str_request_body, "application/json");

if (!res) {
HandleHttpError(res, "POST");
}

if (res->status != 200) {
throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason);
}
Expand All @@ -284,7 +292,7 @@ namespace duckdb {
duckdb_yyjson::yyjson_read(res->body.c_str(), res->body.length(), 0),
&duckdb_yyjson::yyjson_doc_free
);

if (!doc) {
throw std::runtime_error("Failed to parse JSON response");
}
Expand Down Expand Up @@ -324,17 +332,15 @@ namespace duckdb {
throw std::runtime_error("Failed to parse response: " + std::string(e.what()));
}
} catch (std::exception &e) {
// Log error and return error message
return StringVector::AddString(result, "Error: " + std::string(e.what()));
}
});
}

// LoadInternal function
// Complete LoadInternal function
static void LoadInternal(DatabaseInstance &instance) {
ScalarFunctionSet open_prompt("open_prompt");

// Register with both single and two-argument variants

open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));
Expand All @@ -349,16 +355,17 @@ namespace duckdb {
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));

ExtensionUtil::RegisterFunction(instance, open_prompt);

// Register setting functions
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiToken));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_timeout", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiTimeout));
}

void OpenPromptExtension::Load(DuckDB &db) {
Expand Down

0 comments on commit 57c0b65

Please sign in to comment.