diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index 93684e0..7335b05 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -14,9 +14,47 @@ #include #include #include +#include + #include "yyjson.hpp" +#include + namespace duckdb { + struct OpenPromptData: FunctionData { + idx_t model_idx; + idx_t json_schema_idx; + unique_ptr Copy() const { + auto res = make_uniq(); + res->model_idx = model_idx; + res->json_schema_idx = json_schema_idx; + return res; + }; + bool Equals(const FunctionData &other) const { + return model_idx == other.Cast().model_idx && + json_schema_idx == other.Cast().json_schema_idx; + }; + OpenPromptData() { + model_idx = 0; + json_schema_idx = 0; + } + }; + + unique_ptr OpenPromptBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + auto res = make_uniq(); + for (idx_t i = 1; i < arguments.size(); ++i) { + const auto &argument = arguments[i]; + if (i == 1 && argument->alias.empty()) { + res->model_idx = i; + } else if (argument->alias == "json_schema") { + res->json_schema_idx = i; + } + } + return std::move(res); + } + + static std::pair SetupHttpClient(const std::string &url) { std::string scheme, domain, path; @@ -135,25 +173,36 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t user_prompt) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); auto &context = state.GetContext(); - // Get configuration with defaults std::string api_url = GetConfigValue(context, "openprompt_api_url", "http://localhost:11434/v1/chat/completions"); std::string api_token = GetConfigValue(context, "openprompt_api_token", ""); std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b"); + std::string json_schema; - // Override model if provided as second argument - if (args.data.size() > 1 && !args.data[1].GetValue(0).IsNull()) { - model_name = args.data[1].GetValue(0).ToString(); + if (info.model_idx != 0) { + model_name = args.data[info.model_idx].GetValue(0).ToString(); + } + if (info.json_schema_idx != 0) { + json_schema = args.data[info.json_schema_idx].GetValue(0).ToString(); } std::string request_body = "{"; request_body += "\"model\":\"" + model_name + "\","; + if (!json_schema.empty()) { + request_body += "\"response_format\":{\"type\":\"json_object\", \"schema\":"; + request_body += json_schema; + request_body += "},"; + } request_body += "\"messages\":["; request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},"; request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}"; request_body += "]}"; + printf("%s\n", request_body.c_str()); + try { auto client_and_path = SetupHttpClient(api_url); @@ -233,9 +282,15 @@ static void LoadInternal(DatabaseInstance &instance) { // Register with both single and two-argument variants open_prompt.AddFunction(ScalarFunction( - {LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction)); + {LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction, + OpenPromptBind)); open_prompt.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction)); + {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction, + OpenPromptBind)); + open_prompt.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::VARCHAR, OpenPromptRequestFunction, + OpenPromptBind)); ExtensionUtil::RegisterFunction(instance, open_prompt);