From c0d634a723cbd30f78d9b7f50431198b4f74c5da Mon Sep 17 00:00:00 2001 From: Lorenzo Mangani Date: Tue, 7 Jan 2025 20:00:19 +0100 Subject: [PATCH] SECRET support (#20) * secret manager * ENV support * cast unique_ptr to the correct type * cast unique_ptr to the correct type * resync * add tests * fix env, secrets handling * Update README.md --- CMakeLists.txt | 2 +- docs/README.md | 20 ++++++ duckdb | 2 +- extension-ci-tools | 2 +- src/include/open_prompt_secret.hpp | 13 ++++ src/open_prompt_extension.cpp | 97 ++++++++++++++++++++++++++---- src/open_prompt_secret.cpp | 60 ++++++++++++++++++ test/sql/open_prompt.test | 31 ++++++++++ 8 files changed, 212 insertions(+), 15 deletions(-) create mode 100644 src/include/open_prompt_secret.hpp create mode 100644 src/open_prompt_secret.cpp create mode 100644 test/sql/open_prompt.test diff --git a/CMakeLists.txt b/CMakeLists.txt index 8551c08..3014c0d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,7 +14,7 @@ set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension) project(${TARGET_NAME}) include_directories(src/include duckdb/third_party/httplib) -set(EXTENSION_SOURCES src/open_prompt_extension.cpp) +set(EXTENSION_SOURCES src/open_prompt_extension.cpp src/open_prompt_secret.cpp) if(MINGW) set(OPENSSL_USE_STATIC_LIBS TRUE) diff --git a/docs/README.md b/docs/README.md index 07dea1f..5d1da08 100644 --- a/docs/README.md +++ b/docs/README.md @@ -28,9 +28,29 @@ Setup the completions API configuration w/ optional auth token and model name SET VARIABLE openprompt_api_url = 'http://localhost:11434/v1/chat/completions'; SET VARIABLE openprompt_api_token = 'your_api_key_here'; SET VARIABLE openprompt_model_name = 'qwen2.5:0.5b'; +``` + +Alternatively the following ENV variables can be used at runtime +``` + OPEN_PROMPT_API_URL='http://localhost:11434/v1/chat/completions' + OPEN_PROMPT_API_TOKEN='your_api_key_here' + OPEN_PROMPT_MODEL_NAME='qwen2.5:0.5b' + OPEN_PROMPT_API_TIMEOUT='30' +``` +For persistent usage, configure 
parameters using DuckDB SECRETS +```sql +CREATE SECRET IF NOT EXISTS open_prompt ( + TYPE open_prompt, + PROVIDER config, + api_token 'your-api-token', + api_url 'http://localhost:11434/v1/chat/completions', + model_name 'qwen2.5:0.5b', + api_timeout '30' + ); ``` + ### Usage ```sql D SELECT open_prompt('Write a one-line poem about ducks') AS response; diff --git a/duckdb b/duckdb index af39bd0..b9e368e 160000 --- a/duckdb +++ b/duckdb @@ -1 +1 @@ -Subproject commit af39bd0dcf66876e09ac2a7c3baa28fe1b301151 +Subproject commit b9e368e888fed036f598acf4994bd30fbfe97472 diff --git a/extension-ci-tools b/extension-ci-tools index 00831df..f473553 160000 --- a/extension-ci-tools +++ b/extension-ci-tools @@ -1 +1 @@ -Subproject commit 00831df06713072df217d3fb2f6b5e0fae78742f +Subproject commit f473553168fd1db490aaa9f440b8f812af0568da diff --git a/src/include/open_prompt_secret.hpp b/src/include/open_prompt_secret.hpp new file mode 100644 index 0000000..d1f902f --- /dev/null +++ b/src/include/open_prompt_secret.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "duckdb/main/secret/secret.hpp" +#include "duckdb/main/extension_util.hpp" + +namespace duckdb { + +struct CreateOpenPromptSecretFunctions { +public: + static void Register(DatabaseInstance &instance); +}; + +} // namespace duckdb diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index d4c07e1..4daf812 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -7,6 +7,12 @@ #include "duckdb/common/exception/http_exception.hpp" #include +#include "duckdb/main/secret/secret_manager.hpp" +#include "duckdb/main/secret/secret.hpp" +#include "duckdb/main/secret/secret_storage.hpp" + +#include "open_prompt_secret.hpp" + #ifdef USE_ZLIB #define CPPHTTPLIB_ZLIB_SUPPORT #endif @@ -14,6 +20,9 @@ #define CPPHTTPLIB_OPENSSL_SUPPORT #include "httplib.hpp" +#include +#include +#include #include #include #include @@ -29,13 +38,13 @@ namespace duckdb { idx_t model_idx; idx_t json_schema_idx; 
idx_t json_system_prompt_idx; - unique_ptr Copy() const { - auto res = make_uniq(); - res->model_idx = model_idx; - res->json_schema_idx = json_schema_idx; - res->json_system_prompt_idx = json_system_prompt_idx; - return res; - }; + unique_ptr Copy() const override { + auto res = make_uniq(); + res->model_idx = model_idx; + res->json_schema_idx = json_schema_idx; + res->json_system_prompt_idx = json_system_prompt_idx; + return unique_ptr(std::move(res)); + }; bool Equals(const FunctionData &other) const { return model_idx == other.Cast().model_idx && json_schema_idx == other.Cast().json_schema_idx && @@ -142,14 +151,75 @@ namespace duckdb { // Settings management static std::string GetConfigValue(ClientContext &context, const string &var_name, const string &default_value) { - Value value; - auto &config = ClientConfig::GetConfig(context); - if (!config.GetUserVariable(var_name, value) || value.IsNull()) { - return default_value; + // Try environment variables + { + // Create uppercase ENV version: OPEN_PROMPT_SETTING + std::string stripped_name = var_name; + const std::string prefix = "openprompt_"; + if (stripped_name.substr(0, prefix.length()) == prefix) { + stripped_name = stripped_name.substr(prefix.length()); + } + std::string env_var_name = "OPEN_PROMPT_" + stripped_name; + std::transform(env_var_name.begin(), env_var_name.end(), env_var_name.begin(), ::toupper); + // std::cout << "SEARCH ENV FOR " << env_var_name << "\n"; + + const char* env_value = std::getenv(env_var_name.c_str()); + if (env_value != nullptr && strlen(env_value) > 0) { + // std::cout << "USING ENV FOR " << var_name << "\n"; + std::string result(env_value); + return result; + } + } + + // Try to get from secrets + { + // Create lowercase secret version: open_prompt_setting + std::string secret_key = var_name; + const std::string prefix = "openprompt_"; + if (secret_key.substr(0, prefix.length()) == prefix) { + secret_key = secret_key.substr(prefix.length()); + } + // secret_key = 
"open_prompt_" + secret_key; + std::transform(secret_key.begin(), secret_key.end(), secret_key.begin(), ::tolower); + + auto &secret_manager = SecretManager::Get(context); + try { + // std::cout << "SEARCH SECRET FOR " << secret_key << "\n"; + auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context); + auto secret_match = secret_manager.LookupSecret(transaction, "open_prompt", "open_prompt"); + if (secret_match.HasMatch()) { + auto &secret = secret_match.GetSecret(); + if (secret.GetType() != "open_prompt") { + throw InvalidInputException("Invalid secret type. Expected 'open_prompt', got '%s'", secret.GetType()); + } + const auto *kv_secret = dynamic_cast(&secret); + if (!kv_secret) { + throw InvalidInputException("Invalid secret format for 'open_prompt' secret"); + } + Value secret_value; + if (kv_secret->TryGetValue(secret_key, secret_value)) { + // std::cout << "USING SECRET FOR " << var_name << "\n"; + return secret_value.ToString(); + } + } + } catch (...) { + // If secret lookup fails, fall back to user variables } - return value.ToString(); } + // Fall back to user variables if secret not found (using original var_name) + Value value; + auto &config = ClientConfig::GetConfig(context); + if (!config.GetUserVariable(var_name, value) || value.IsNull()) { + // std::cout << "USING SET FOR " << var_name << "\n"; + return default_value; + } + + // std::cout << "USING DEFAULT FOR " << var_name << "\n"; + return value.ToString(); + } + + static void SetConfigValue(DataChunk &args, ExpressionState &state, Vector &result, const string &var_name, const string &value_type) { UnaryExecutor::Execute(args.data[0], result, args.size(), @@ -356,6 +426,9 @@ namespace duckdb { LogicalType::VARCHAR, OpenPromptRequestFunction, OpenPromptBind)); + // Register Secret functions + CreateOpenPromptSecretFunctions::Register(instance); + ExtensionUtil::RegisterFunction(instance, open_prompt); ExtensionUtil::RegisterFunction(instance, ScalarFunction( diff --git 
a/src/open_prompt_secret.cpp b/src/open_prompt_secret.cpp new file mode 100644 index 0000000..4beeda1 --- /dev/null +++ b/src/open_prompt_secret.cpp @@ -0,0 +1,60 @@ +#include "open_prompt_secret.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/secret/secret.hpp" +#include "duckdb/main/extension_util.hpp" + +namespace duckdb { + +static void CopySecret(const std::string &key, const CreateSecretInput &input, KeyValueSecret &result) { + auto val = input.options.find(key); + if (val != input.options.end()) { + result.secret_map[key] = val->second; + } +} + +static void RegisterCommonSecretParameters(CreateSecretFunction &function) { + // Register open_prompt common parameters + function.named_parameters["api_token"] = LogicalType::VARCHAR; + function.named_parameters["api_url"] = LogicalType::VARCHAR; + function.named_parameters["model_name"] = LogicalType::VARCHAR; + function.named_parameters["api_timeout"] = LogicalType::VARCHAR; +} + +static void RedactCommonKeys(KeyValueSecret &result) { + // Redact sensitive information + result.redact_keys.insert("api_token"); +} + +static unique_ptr CreateOpenPromptSecretFromConfig(ClientContext &context, CreateSecretInput &input) { + auto scope = input.scope; + auto result = make_uniq(scope, input.type, input.provider, input.name); + + // Copy all relevant secrets + CopySecret("api_token", input, *result); + CopySecret("api_url", input, *result); + CopySecret("model_name", input, *result); + CopySecret("api_timeout", input, *result); + + // Redact sensitive keys + RedactCommonKeys(*result); + + return std::move(result); +} + +void CreateOpenPromptSecretFunctions::Register(DatabaseInstance &instance) { + string type = "open_prompt"; + + // Register the new type + SecretType secret_type; + secret_type.name = type; + secret_type.deserializer = KeyValueSecret::Deserialize; + secret_type.default_provider = "config"; + ExtensionUtil::RegisterSecretType(instance, secret_type); + + // Register the config secret 
provider + CreateSecretFunction config_function = {type, "config", CreateOpenPromptSecretFromConfig}; + RegisterCommonSecretParameters(config_function); + ExtensionUtil::RegisterFunction(instance, config_function); +} + +} // namespace duckdb diff --git a/test/sql/open_prompt.test b/test/sql/open_prompt.test new file mode 100644 index 0000000..276d7e5 --- /dev/null +++ b/test/sql/open_prompt.test @@ -0,0 +1,31 @@ +# name: test/sql/open_prompt.test +# description: test open_prompt extension +# group: [open_prompt] + +# Before we load the extension, this will fail +statement error +SELECT open_prompt('error'); +---- +Catalog Error: Scalar Function with name open_prompt does not exist! + +# Require statement will ensure the extension is loaded from now on +require open_prompt + +# Confirm the extension works by setting a secret +query I +CREATE SECRET IF NOT EXISTS open_prompt ( + TYPE open_prompt, + PROVIDER config, + api_token 'xxxxx', + api_url 'https://api.groq.com/openai/v1/chat/completions', + model_name 'llama-3.3-70b-versatile', + api_timeout '30' + ); +---- +true + +# Confirm the secret exists +query I +SELECT name FROM duckdb_secrets() WHERE name = 'open_prompt' ; +---- +open_prompt