Skip to content

Commit

Permalink
yyjson fix
Browse files Browse the repository at this point in the history
  • Loading branch information
akvlad committed Oct 26, 2024
1 parent 9cedcf8 commit f846019
Showing 1 changed file with 35 additions and 15 deletions.
50 changes: 35 additions & 15 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,18 @@ namespace duckdb {
struct OpenPromptData: FunctionData {
idx_t model_idx;
idx_t json_schema_idx;
idx_t json_system_prompt_idx;
unique_ptr<FunctionData> Copy() const {
auto res = make_uniq<OpenPromptData>();
res->model_idx = model_idx;
res->json_schema_idx = json_schema_idx;
res->json_system_prompt_idx = json_system_prompt_idx;
return res;
};
bool Equals(const FunctionData &other) const {
return model_idx == other.Cast<OpenPromptData>().model_idx &&
json_schema_idx == other.Cast<OpenPromptData>().json_schema_idx;
json_schema_idx == other.Cast<OpenPromptData>().json_schema_idx &&
json_system_prompt_idx==other.Cast<OpenPromptData>().json_system_prompt_idx;
};
OpenPromptData() {
model_idx = 0;
Expand All @@ -49,6 +52,8 @@ namespace duckdb {
res->model_idx = i;
} else if (argument->alias == "json_schema") {
res->json_schema_idx = i;
} else if (argument->alias == "system_prompt") {
res->json_system_prompt_idx = i;
}
}
return std::move(res);
Expand Down Expand Up @@ -182,18 +187,22 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");
std::string json_schema;
std::string system_prompt;

if (info.model_idx != 0) {
model_name = args.data[info.model_idx].GetValue(0).ToString();
}
if (info.json_schema_idx != 0) {
json_schema = args.data[info.json_schema_idx].GetValue(0).ToString();
}
if (info.json_system_prompt_idx != 0) {
system_prompt = args.data[info.json_system_prompt_idx].GetValue(0).ToString();
}

unique_ptr<duckdb_yyjson::yyjson_mut_doc,
void (*)(duckdb_yyjson::yyjson_mut_doc*)> doc(
new duckdb_yyjson::yyjson_mut_doc(), &duckdb_yyjson::yyjson_mut_doc_free);
unique_ptr<duckdb_yyjson::yyjson_mut_doc, void (*)(duckdb_yyjson::yyjson_mut_doc*)> doc(
duckdb_yyjson::yyjson_mut_doc_new(nullptr), &duckdb_yyjson::yyjson_mut_doc_free);
auto obj = duckdb_yyjson::yyjson_mut_obj(doc.get());
duckdb_yyjson::yyjson_mut_doc_set_root(doc.get(), obj);
duckdb_yyjson::yyjson_mut_obj_add(obj,
duckdb_yyjson::yyjson_mut_str(doc.get(), "model"),
duckdb_yyjson::yyjson_mut_str(doc.get(), model_name.c_str())
Expand All @@ -213,23 +222,30 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
}
auto messages = duckdb_yyjson::yyjson_mut_arr(doc.get());
string str_messages[2][2] = {
{"system", "You are a helpful assistant."},
{"system", system_prompt},
{"user", user_prompt.GetString()}
};
for (auto message : str_messages) {
if (message[1].empty()) {
continue;
}
auto yymessage = duckdb_yyjson::yyjson_mut_obj(doc.get());

auto yymessage = duckdb_yyjson::yyjson_mut_arr_add_obj(doc.get(),messages);
duckdb_yyjson::yyjson_mut_obj_add(yymessage,
duckdb_yyjson::yyjson_mut_str(doc.get(), "role"),
duckdb_yyjson::yyjson_mut_str(doc.get(), message[0].c_str()));
duckdb_yyjson::yyjson_mut_obj_add(yymessage,
duckdb_yyjson::yyjson_mut_str(doc.get(), "content"),
duckdb_yyjson::yyjson_mut_str(doc.get(), message[1].c_str()));
}
duckdb_yyjson::yyjson_mut_obj_add(obj, duckdb_yyjson::yyjson_mut_str(doc.get(), "messages"),
)
request_body += "\"messages\":[";
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
request_body += "]}";

messages);
duckdb_yyjson::yyjson_write_err err;
auto request_body = duckdb_yyjson::yyjson_mut_write_opts(doc.get(), 0, nullptr, nullptr, &err);
if (request_body == nullptr) {
throw std::runtime_error(err.msg);
}
string str_request_body(request_body);
free(request_body);

try {
auto client_and_path = SetupHttpClient(api_url);
Expand All @@ -242,7 +258,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
headers.emplace("Authorization", "Bearer " + api_token);
}

auto res = client.Post(path.c_str(), headers, request_body, "application/json");
auto res = client.Post(path.c_str(), headers, str_request_body, "application/json");

if (!res) {
HandleHttpError(res, "POST");
Expand Down Expand Up @@ -314,10 +330,14 @@ static void LoadInternal(DatabaseInstance &instance) {
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));
open_prompt.AddFunction(ScalarFunction(
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));

ExtensionUtil::RegisterFunction(instance, open_prompt);

Expand Down

0 comments on commit f846019

Please sign in to comment.