diff --git a/docs/README.md b/docs/README.md index 668abcd..34f8ffe 100644 --- a/docs/README.md +++ b/docs/README.md @@ -10,12 +10,13 @@ Simple extension to query OpenAI Completion API endpoints such as Ollama - `set_api_token(auth_token)` - `set_api_url(completions_url)` - `set_model_name(model_name)` +- `set_json_schema(json_schema)` ### Settings Setup the completions API configuration w/ optional auth token and model name ```sql SET VARIABLE openprompt_api_url = 'http://localhost:11434/v1/chat/completions'; -SET VARIABLE openprompt_api_token = 'your_api_key_here'; +SET VARIABLE openprompt_api_token = 'optional_api_key_here'; SET VARIABLE openprompt_model_name = 'qwen2.5:0.5b'; ``` @@ -31,6 +32,46 @@ D SELECT open_prompt('Write a one-line poem about ducks') AS response; └────────────────────────────────────────────────┘ ``` +#### JSON Structured Output _(very experimental)_ +Define a `json_schema` to receive a structured response in JSON format... most of the time. + +```javascript +{ + summary: 'VARCHAR', + favourite_animals:='VARCHAR[]', + favourite_activity:='VARCHAR[]', + star_rating:='INTEGER'}, + struct_descr:={star_rating: 'rating on a scale from 1 (bad) to 5 (very good)' +} +``` + +Prompt based. Output depends on model skills. + +```sql +SET VARIABLE openprompt_model_name = 'qwen2.5:1.5b'; +D SET VARIABLE openprompt_json_schema = "struct:={summary: 'VARCHAR', favourite_animals:='VARCHAR[]', favourite_activity:='VARCHAR[]', star_rating:='INTEGER'}, struct_descr:={star_rating: 'visit rating on a scale from 1 (bad) to 5 (very good)'}"; +D SELECT open_prompt('My zoo visit was fun and I loved the bears and tigers. 
i also had icecream') AS response; +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│                                                                        response                                                                         │ +│                                                                         varchar                                                                         │ +├─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ {"summary": "A short summary of your recent zoo visit activity.", "favourite_animals": ["bears", "tigers"], "favourite_activity": ["icecream"], "sta …  │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +``` + +The results can be validated and parsed with the DuckDB `JSON` extension and its functions + +```sql +D LOAD json; +D WITH response AS (SELECT open_prompt('My zoo visit was fun and I loved the bears and tigers. i also had icecream') AS response) SELECT json_structure(response) FROM response; +┌────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│                                          json_structure(response)                                          │ +│                                                    json                                                    │ +├────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ {"summary":"VARCHAR","favourite_animals":"VARCHAR","favourite_activity":"VARCHAR","star_rating":"UBIGINT"} │ +└────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +``` +
diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index 7335b05..fe713e1 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -167,6 +167,10 @@ static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result SetConfigValue(args, state, result, "openprompt_model_name", "Model name"); } +static void SetJsonSchema(DataChunk &args, ExpressionState &state, Vector &result) { + SetConfigValue(args, state, result, "openprompt_json_schema", "JSON Schema"); +} + // Main Function static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.data.size() >= 1); // At least prompt required @@ -181,6 +185,8 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V "http://localhost:11434/v1/chat/completions"); std::string api_token = GetConfigValue(context, "openprompt_api_token", ""); std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b"); + std::string prompt_json_schema = GetConfigValue(context, "openprompt_json_schema", ""); + std::string json_schema; if (info.model_idx != 0) { @@ -198,7 +204,11 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V request_body += "},"; } request_body += "\"messages\":["; - request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},"; + if (!prompt_json_schema.empty()) { + request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant. 
Summarize and Output JSON format (without any omissions): " + prompt_json_schema + "\"},"; + } else { + request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},"; + } request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}"; request_body += "]}"; printf("%s\n", request_body.c_str()); @@ -216,11 +226,11 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V } auto res = client.Post(path.c_str(), headers, request_body, "application/json"); - + if (!res) { HandleHttpError(res, "POST"); } - + if (res->status != 200) { throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason); } @@ -230,7 +240,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V duckdb_yyjson::yyjson_read(res->body.c_str(), res->body.length(), 0), &duckdb_yyjson::yyjson_doc_free ); - + if (!doc) { throw std::runtime_error("Failed to parse JSON response"); } @@ -301,6 +311,8 @@ static void LoadInternal(DatabaseInstance &instance) { "set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl)); ExtensionUtil::RegisterFunction(instance, ScalarFunction( "set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName)); + ExtensionUtil::RegisterFunction(instance, ScalarFunction( + "set_json_schema", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetJsonSchema)); } void OpenPromptExtension::Load(DuckDB &db) {