diff --git a/docs/README.md b/docs/README.md
index 668abcd..34f8ffe 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -10,12 +10,13 @@ Simple extension to query OpenAI Completion API endpoints such as Ollama
- `set_api_token(auth_token)`
- `set_api_url(completions_url)`
- `set_model_name(model_name)`
+- `set_json_schema(json_schema)`
### Settings
Setup the completions API configuration w/ optional auth token and model name
```sql
SET VARIABLE openprompt_api_url = 'http://localhost:11434/v1/chat/completions';
-SET VARIABLE openprompt_api_token = 'your_api_key_here';
+SET VARIABLE openprompt_api_token = 'optional_api_key_here';
SET VARIABLE openprompt_model_name = 'qwen2.5:0.5b';
```
@@ -31,6 +32,46 @@ D SELECT open_prompt('Write a one-line poem about ducks') AS response;
└────────────────────────────────────────────────┘
```
+#### JSON Structured Output _(very experimental)_
+Define a `json_schema` to receive a structured response in JSON format... most of the time.
+
+```javascript
+struct:={
+  summary: 'VARCHAR',
+  favourite_animals:='VARCHAR[]',
+  favourite_activity:='VARCHAR[]',
+  star_rating:='INTEGER'
+},
+struct_descr:={star_rating: 'rating on a scale from 1 (bad) to 5 (very good)'}
+```
+
+This feature is prompt-based: the schema is injected into the system prompt, so output quality depends on the model's capabilities.
+
+```sql
+SET VARIABLE openprompt_model_name = 'qwen2.5:1.5b';
+D SET VARIABLE openprompt_json_schema = "struct:={summary: 'VARCHAR', favourite_animals:='VARCHAR[]', favourite_activity:='VARCHAR[]', star_rating:='INTEGER'}, struct_descr:={star_rating: 'visit rating on a scale from 1 (bad) to 5 (very good)'}";
+D SELECT open_prompt('My zoo visit was fun and I loved the bears and tigers. i also had icecream') AS response;
+┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
+│ response │
+│ varchar │
+├─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
+│ {"summary": "A short summary of your recent zoo visit activity.", "favourite_animals": ["bears", "tigers"], "favourite_activity": ["icecream"], "sta … │
+└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
+```
+
+The results can be validated and parsed with the DuckDB `JSON` extension and its functions.
+
+```sql
+D LOAD json;
+D WITH response AS (SELECT open_prompt('My zoo visit was fun and I loved the bears and tigers. i also had icecream') AS response) SELECT json_structure(response) FROM response;
+┌────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
+│ json_structure(response) │
+│ json │
+├────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
+│ {"summary":"VARCHAR","favourite_animals":"VARCHAR","favourite_activity":"VARCHAR","star_rating":"UBIGINT"} │
+└────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
+```
+
diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp
index 7335b05..fe713e1 100644
--- a/src/open_prompt_extension.cpp
+++ b/src/open_prompt_extension.cpp
@@ -167,6 +167,10 @@ static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result
SetConfigValue(args, state, result, "openprompt_model_name", "Model name");
}
+static void SetJsonSchema(DataChunk &args, ExpressionState &state, Vector &result) {
+ SetConfigValue(args, state, result, "openprompt_json_schema", "JSON Schema");
+}
+
// Main Function
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
D_ASSERT(args.data.size() >= 1); // At least prompt required
@@ -181,6 +185,8 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
"http://localhost:11434/v1/chat/completions");
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");
+ std::string prompt_json_schema = GetConfigValue(context, "openprompt_json_schema", "");
+
std::string json_schema;
if (info.model_idx != 0) {
@@ -198,7 +204,11 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
request_body += "},";
}
request_body += "\"messages\":[";
- request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
+ if (!prompt_json_schema.empty()) {
+ request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant. Summarize and Output JSON format (without any omissions): " + prompt_json_schema + "\"},";
+ } else {
+ request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
+ }
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
request_body += "]}";
printf("%s\n", request_body.c_str());
@@ -216,11 +226,11 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
}
auto res = client.Post(path.c_str(), headers, request_body, "application/json");
-
+
if (!res) {
HandleHttpError(res, "POST");
}
-
+
if (res->status != 200) {
throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason);
}
@@ -230,7 +240,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
duckdb_yyjson::yyjson_read(res->body.c_str(), res->body.length(), 0),
&duckdb_yyjson::yyjson_doc_free
);
-
+
if (!doc) {
throw std::runtime_error("Failed to parse JSON response");
}
@@ -301,6 +311,8 @@ static void LoadInternal(DatabaseInstance &instance) {
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName));
+ ExtensionUtil::RegisterFunction(instance, ScalarFunction(
+ "set_json_schema", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetJsonSchema));
}
void OpenPromptExtension::Load(DuckDB &db) {