diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index 767df49..b4c8155 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -24,15 +24,18 @@ namespace duckdb { struct OpenPromptData: FunctionData { idx_t model_idx; idx_t json_schema_idx; + idx_t json_system_prompt_idx; unique_ptr Copy() const { auto res = make_uniq(); res->model_idx = model_idx; res->json_schema_idx = json_schema_idx; + res->json_system_prompt_idx = json_system_prompt_idx; return res; }; bool Equals(const FunctionData &other) const { return model_idx == other.Cast().model_idx && - json_schema_idx == other.Cast().json_schema_idx; + json_schema_idx == other.Cast().json_schema_idx && + json_system_prompt_idx==other.Cast().json_system_prompt_idx; }; OpenPromptData() { model_idx = 0; @@ -49,6 +52,8 @@ namespace duckdb { res->model_idx = i; } else if (argument->alias == "json_schema") { res->json_schema_idx = i; + } else if (argument->alias == "system_prompt") { + res->json_system_prompt_idx = i; } } return std::move(res); @@ -182,6 +187,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V std::string api_token = GetConfigValue(context, "openprompt_api_token", ""); std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b"); std::string json_schema; + std::string system_prompt; if (info.model_idx != 0) { model_name = args.data[info.model_idx].GetValue(0).ToString(); @@ -189,11 +195,14 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V if (info.json_schema_idx != 0) { json_schema = args.data[info.json_schema_idx].GetValue(0).ToString(); } + if (info.json_system_prompt_idx != 0) { + system_prompt = args.data[info.json_system_prompt_idx].GetValue(0).ToString(); + } - unique_ptr doc( - new duckdb_yyjson::yyjson_mut_doc(), &duckdb_yyjson::yyjson_mut_doc_free); + unique_ptr doc( + duckdb_yyjson::yyjson_mut_doc_new(nullptr), &duckdb_yyjson::yyjson_mut_doc_free); auto obj = duckdb_yyjson::yyjson_mut_obj(doc.get()); + duckdb_yyjson::yyjson_mut_doc_set_root(doc.get(), obj); duckdb_yyjson::yyjson_mut_obj_add(obj, duckdb_yyjson::yyjson_mut_str(doc.get(), "model"), duckdb_yyjson::yyjson_mut_str(doc.get(), model_name.c_str()) @@ -213,23 +222,30 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V } auto messages = duckdb_yyjson::yyjson_mut_arr(doc.get()); string str_messages[2][2] = { - {"system", "You are a helpful assistant."}, + {"system", system_prompt}, {"user", user_prompt.GetString()} }; for (auto message : str_messages) { if (message[1].empty()) { continue; } - auto yymessage = duckdb_yyjson::yyjson_mut_obj(doc.get()); - + auto yymessage = duckdb_yyjson::yyjson_mut_arr_add_obj(doc.get(),messages); + duckdb_yyjson::yyjson_mut_obj_add(yymessage, + duckdb_yyjson::yyjson_mut_str(doc.get(), "role"), + duckdb_yyjson::yyjson_mut_str(doc.get(), message[0].c_str())); + duckdb_yyjson::yyjson_mut_obj_add(yymessage, + duckdb_yyjson::yyjson_mut_str(doc.get(), "content"), + duckdb_yyjson::yyjson_mut_str(doc.get(), message[1].c_str())); } duckdb_yyjson::yyjson_mut_obj_add(obj, duckdb_yyjson::yyjson_mut_str(doc.get(), "messages"), - ) - request_body += "\"messages\":["; - request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},"; - request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}"; - request_body += "]}"; - + messages); + duckdb_yyjson::yyjson_write_err err; + auto request_body = duckdb_yyjson::yyjson_mut_write_opts(doc.get(), 0, nullptr, nullptr, &err); + if (request_body == nullptr) { + throw std::runtime_error(err.msg); + } + string str_request_body(request_body); + free(request_body); try { auto client_and_path = SetupHttpClient(api_url); @@ -242,7 +258,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V headers.emplace("Authorization", "Bearer " + api_token); } - auto res = client.Post(path.c_str(), headers, request_body, "application/json"); + auto res = client.Post(path.c_str(), headers, str_request_body, "application/json"); if (!res) { HandleHttpError(res, "POST"); @@ -314,10 +330,14 @@ static void LoadInternal(DatabaseInstance &instance) { open_prompt.AddFunction(ScalarFunction( {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction, OpenPromptBind)); - open_prompt.AddFunction(ScalarFunction( + open_prompt.AddFunction(ScalarFunction( {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction, OpenPromptBind)); + open_prompt.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::VARCHAR, OpenPromptRequestFunction, + OpenPromptBind)); ExtensionUtil::RegisterFunction(instance, open_prompt);