Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SECRET support #20

Merged
merged 9 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
project(${TARGET_NAME})
include_directories(src/include duckdb/third_party/httplib)

set(EXTENSION_SOURCES src/open_prompt_extension.cpp)
set(EXTENSION_SOURCES src/open_prompt_extension.cpp src/open_prompt_secret.cpp)

if(MINGW)
set(OPENSSL_USE_STATIC_LIBS TRUE)
Expand Down
20 changes: 20 additions & 0 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,29 @@ Setup the completions API configuration w/ optional auth token and model name
SET VARIABLE openprompt_api_url = 'http://localhost:11434/v1/chat/completions';
SET VARIABLE openprompt_api_token = 'your_api_key_here';
SET VARIABLE openprompt_model_name = 'qwen2.5:0.5b';
```

Alternatively the following ENV variables can be used at runtime
```
OPEN_PROMPT_API_URL='http://localhost:11434/v1/chat/completions'
OPEN_PROMPT_API_TOKEN='your_api_key_here'
OPEN_PROMPT_MODEL_NAME='qwen2.5:0.5b'
OPEN_PROMPT_API_TIMEOUT='30'
```

For persistent usage, configure parameters using DuckDB SECRETS
```sql
CREATE SECRET IF NOT EXISTS open_prompt (
TYPE open_prompt,
PROVIDER config,
api_token 'your-api-token',
api_url 'http://localhost:11434/v1/chat/completions',
model_name 'qwen2.5:0.5b',
api_timeout '30'
);
```


### Usage
```sql
D SELECT open_prompt('Write a one-line poem about ducks') AS response;
Expand Down
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 3121 files
13 changes: 13 additions & 0 deletions src/include/open_prompt_secret.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#pragma once

#include "duckdb/main/secret/secret.hpp"
#include "duckdb/main/extension_util.hpp"

namespace duckdb {

struct CreateOpenPromptSecretFunctions {
public:
static void Register(DatabaseInstance &instance);
};

} // namespace duckdb
97 changes: 85 additions & 12 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,22 @@
#include "duckdb/common/exception/http_exception.hpp"
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>

#include "duckdb/main/secret/secret_manager.hpp"
#include "duckdb/main/secret/secret.hpp"
#include "duckdb/main/secret/secret_storage.hpp"

#include "open_prompt_secret.hpp"

#ifdef USE_ZLIB
#define CPPHTTPLIB_ZLIB_SUPPORT
#endif

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

#include <cstdlib>
#include <algorithm>
#include <cctype>
#include <string>
#include <sstream>
#include <mutex>
Expand All @@ -29,14 +38,14 @@
idx_t model_idx;
idx_t json_schema_idx;
idx_t json_system_prompt_idx;
unique_ptr<FunctionData> Copy() const {
auto res = make_uniq<OpenPromptData>();
res->model_idx = model_idx;
res->json_schema_idx = json_schema_idx;
res->json_system_prompt_idx = json_system_prompt_idx;
return res;
};
unique_ptr<FunctionData> Copy() const override {
auto res = make_uniq<OpenPromptData>();
res->model_idx = model_idx;
res->json_schema_idx = json_schema_idx;
res->json_system_prompt_idx = json_system_prompt_idx;
return unique_ptr<FunctionData>(std::move(res));
};
bool Equals(const FunctionData &other) const {

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_eh, wasm32-emscripten)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_mvp, wasm32-emscripten)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / DuckDB-Wasm (wasm_threads, wasm32-emscripten)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 48 in src/open_prompt_extension.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'Equals' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]
return model_idx == other.Cast<OpenPromptData>().model_idx &&
json_schema_idx == other.Cast<OpenPromptData>().json_schema_idx &&
json_system_prompt_idx==other.Cast<OpenPromptData>().json_system_prompt_idx;
Expand Down Expand Up @@ -142,14 +151,75 @@

// Settings management
static std::string GetConfigValue(ClientContext &context, const string &var_name, const string &default_value) {
Value value;
auto &config = ClientConfig::GetConfig(context);
if (!config.GetUserVariable(var_name, value) || value.IsNull()) {
return default_value;
// Try environment variables
{
// Create uppercase ENV version: OPEN_PROMPT_SETTING
std::string stripped_name = var_name;
const std::string prefix = "openprompt_";
if (stripped_name.substr(0, prefix.length()) == prefix) {
stripped_name = stripped_name.substr(prefix.length());
}
std::string env_var_name = "OPEN_PROMPT_" + stripped_name;
std::transform(env_var_name.begin(), env_var_name.end(), env_var_name.begin(), ::toupper);
// std::cout << "SEARCH ENV FOR " << env_var_name << "\n";

const char* env_value = std::getenv(env_var_name.c_str());
if (env_value != nullptr && strlen(env_value) > 0) {
// std::cout << "USING ENV FOR " << var_name << "\n";
std::string result(env_value);
return result;
}
}

// Try to get from secrets
{
// Create lowercase secret version: open_prompt_setting
std::string secret_key = var_name;
const std::string prefix = "openprompt_";
if (secret_key.substr(0, prefix.length()) == prefix) {
secret_key = secret_key.substr(prefix.length());
}
// secret_key = "open_prompt_" + secret_key;
std::transform(secret_key.begin(), secret_key.end(), secret_key.begin(), ::tolower);

auto &secret_manager = SecretManager::Get(context);
try {
// std::cout << "SEARCH SECRET FOR " << secret_key << "\n";
auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context);
auto secret_match = secret_manager.LookupSecret(transaction, "open_prompt", "open_prompt");
if (secret_match.HasMatch()) {
auto &secret = secret_match.GetSecret();
if (secret.GetType() != "open_prompt") {
throw InvalidInputException("Invalid secret type. Expected 'open_prompt', got '%s'", secret.GetType());
}
const auto *kv_secret = dynamic_cast<const KeyValueSecret*>(&secret);
if (!kv_secret) {
throw InvalidInputException("Invalid secret format for 'open_prompt' secret");
}
Value secret_value;
if (kv_secret->TryGetValue(secret_key, secret_value)) {
// std::cout << "USING SECRET FOR " << var_name << "\n";
return secret_value.ToString();
}
}
} catch (...) {
// If secret lookup fails, fall back to user variables
}
return value.ToString();
}

// Fall back to user variables if secret not found (using original var_name)
Value value;
auto &config = ClientConfig::GetConfig(context);
if (!config.GetUserVariable(var_name, value) || value.IsNull()) {
// std::cout << "USING SET FOR " << var_name << "\n";
return default_value;
}

// std::cout << "USING DEFAULT FOR " << var_name << "\n";
return value.ToString();
}


static void SetConfigValue(DataChunk &args, ExpressionState &state, Vector &result,
const string &var_name, const string &value_type) {
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
Expand Down Expand Up @@ -356,6 +426,9 @@
LogicalType::VARCHAR, OpenPromptRequestFunction,
OpenPromptBind));

// Register Secret functions
CreateOpenPromptSecretFunctions::Register(instance);

ExtensionUtil::RegisterFunction(instance, open_prompt);

ExtensionUtil::RegisterFunction(instance, ScalarFunction(
Expand Down
60 changes: 60 additions & 0 deletions src/open_prompt_secret.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "open_prompt_secret.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/main/secret/secret.hpp"
#include "duckdb/main/extension_util.hpp"

namespace duckdb {

static void CopySecret(const std::string &key, const CreateSecretInput &input, KeyValueSecret &result) {
auto val = input.options.find(key);
if (val != input.options.end()) {
result.secret_map[key] = val->second;
}
}

static void RegisterCommonSecretParameters(CreateSecretFunction &function) {
// Register open_prompt common parameters
function.named_parameters["api_token"] = LogicalType::VARCHAR;
function.named_parameters["api_url"] = LogicalType::VARCHAR;
function.named_parameters["model_name"] = LogicalType::VARCHAR;
function.named_parameters["api_timeout"] = LogicalType::VARCHAR;
}

static void RedactCommonKeys(KeyValueSecret &result) {
// Redact sensitive information
result.redact_keys.insert("api_token");
}

static unique_ptr<BaseSecret> CreateOpenPromptSecretFromConfig(ClientContext &context, CreateSecretInput &input) {
auto scope = input.scope;
auto result = make_uniq<KeyValueSecret>(scope, input.type, input.provider, input.name);

// Copy all relevant secrets
CopySecret("api_token", input, *result);
CopySecret("api_url", input, *result);
CopySecret("model_name", input, *result);
CopySecret("api_timeout", input, *result);

// Redact sensitive keys
RedactCommonKeys(*result);

return std::move(result);
}

void CreateOpenPromptSecretFunctions::Register(DatabaseInstance &instance) {
string type = "open_prompt";

// Register the new type
SecretType secret_type;
secret_type.name = type;
secret_type.deserializer = KeyValueSecret::Deserialize<KeyValueSecret>;
secret_type.default_provider = "config";
ExtensionUtil::RegisterSecretType(instance, secret_type);

// Register the config secret provider
CreateSecretFunction config_function = {type, "config", CreateOpenPromptSecretFromConfig};
RegisterCommonSecretParameters(config_function);
ExtensionUtil::RegisterFunction(instance, config_function);
}

} // namespace duckdb
31 changes: 31 additions & 0 deletions test/sql/open_prompt.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# name: test/sql/rusty_quack.test
# description: test rusty_quack extension
# group: [quack]

# Before we load the extension, this will fail
statement error
SELECT open_prompt('error');
----
Catalog Error: Scalar Function with name open_prompt does not exist!

# Require statement will ensure the extension is loaded from now on
require open_prompt

# Confirm the extension works by setting a secret
query I
CREATE SECRET IF NOT EXISTS open_prompt (
TYPE open_prompt,
PROVIDER config,
api_token 'xxxxx',
api_url 'https://api.groq.com/openai/v1/chat/completions',
model_name 'llama-3.3-70b-versatile',
api_timeout '30'
);
----
true

# Confirm the secret exists
query I
SELECT name FROM duckdb_secrets() WHERE name = 'open_prompt' ;
----
open_prompt
Loading