Skip to content

Commit

Permalink
Merge branch 'main' into json_schema
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani authored Oct 26, 2024
2 parents 92aa770 + fbad951 commit 520874c
Showing 1 changed file with 65 additions and 9 deletions.
74 changes: 65 additions & 9 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,47 @@
#include <sstream>
#include <mutex>
#include <iostream>
#include <duckdb/planner/expression/bound_function_expression.hpp>

#include "yyjson.hpp"

#include<stdio.h>

namespace duckdb {
struct OpenPromptData: FunctionData {
idx_t model_idx;
idx_t json_schema_idx;
unique_ptr<FunctionData> Copy() const {
auto res = make_uniq<OpenPromptData>();
res->model_idx = model_idx;
res->json_schema_idx = json_schema_idx;
return res;
};
bool Equals(const FunctionData &other) const {
return model_idx == other.Cast<OpenPromptData>().model_idx &&
json_schema_idx == other.Cast<OpenPromptData>().json_schema_idx;
};
OpenPromptData() {
model_idx = 0;
json_schema_idx = 0;
}
};

unique_ptr<FunctionData> OpenPromptBind(ClientContext &context, ScalarFunction &bound_function,
vector<unique_ptr<Expression>> &arguments) {
auto res = make_uniq<OpenPromptData>();
for (idx_t i = 1; i < arguments.size(); ++i) {
const auto &argument = arguments[i];
if (i == 1 && argument->alias.empty()) {
res->model_idx = i;
} else if (argument->alias == "json_schema") {
res->json_schema_idx = i;
}
}
return std::move(res);
}



static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(const std::string &url) {
std::string scheme, domain, path;
Expand Down Expand Up @@ -139,30 +177,42 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V

UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
[&](string_t user_prompt) {
auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
auto &info = func_expr.bind_info->Cast<OpenPromptData>();
auto &context = state.GetContext();

// Get configuration with defaults
std::string api_url = GetConfigValue(context, "openprompt_api_url",
"http://localhost:11434/v1/chat/completions");
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");
std::string json_schema = GetConfigValue(context, "openprompt_json_schema", "");
std::string prompt_json_schema = GetConfigValue(context, "openprompt_json_schema", "");

std::string json_schema;

// Override model if provided as second argument
if (args.data.size() > 1 && !args.data[1].GetValue(0).IsNull()) {
model_name = args.data[1].GetValue(0).ToString();
if (info.model_idx != 0) {
model_name = args.data[info.model_idx].GetValue(0).ToString();
}
if (info.json_schema_idx != 0) {
json_schema = args.data[info.json_schema_idx].GetValue(0).ToString();
}

std::string request_body = "{";
request_body += "\"model\":\"" + model_name + "\",";
request_body += "\"messages\":[";
if (!json_schema.empty()) {
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant. Summarize and Output JSON format (without any omissions): " + json_schema + "\"},";
request_body += "\"response_format\":{\"type\":\"json_object\", \"schema\":";
request_body += json_schema;
request_body += "},";
}
request_body += "\"messages\":[";
if (!prompt_json_schema.empty()) {
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant. Summarize and Output JSON format (without any omissions): " + prompt_json_schema + "\"},";
} else {
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
}
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
request_body += "]}";
printf("%s\n", request_body.c_str());


try {
auto client_and_path = SetupHttpClient(api_url);
Expand Down Expand Up @@ -242,9 +292,15 @@ static void LoadInternal(DatabaseInstance &instance) {

// Register with both single and two-argument variants
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
{LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));

ExtensionUtil::RegisterFunction(instance, open_prompt);

Expand Down

0 comments on commit 520874c

Please sign in to comment.