From b71e3b44ea501186828a796de5ff3427619453b1 Mon Sep 17 00:00:00 2001 From: Lorenzo Mangani Date: Thu, 14 Nov 2024 12:52:31 +0100 Subject: [PATCH] set_api_timeout Optional `set_api_timeout` settings function --- src/open_prompt_extension.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index e6625ba..216524b 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -173,6 +173,10 @@ static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result SetConfigValue(args, state, result, "openprompt_model_name", "Model name"); } +static void SetApiTimeout(DataChunk &args, ExpressionState &state, Vector &result) { + SetConfigValue(args, state, result, "openprompt_api_timeout", "API timeout"); +} + // Main Function static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.data.size() >= 1); // At least prompt required @@ -187,6 +191,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V "http://localhost:11434/v1/chat/completions"); std::string api_token = GetConfigValue(context, "openprompt_api_token", ""); std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b"); + std::string api_timeout = GetConfigValue(context, "openprompt_api_timeout", ""); std::string json_schema; std::string system_prompt; @@ -259,6 +264,10 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V headers.emplace("Authorization", "Bearer " + api_token); } + if (!api_timeout.empty()){ + client.set_read_timeout(stoi(api_timeout), 0); + } + auto res = client.Post(path.c_str(), headers, str_request_body, "application/json"); if (!res) { @@ -349,6 +358,8 @@ static void LoadInternal(DatabaseInstance &instance) { "set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl)); ExtensionUtil::RegisterFunction(instance, ScalarFunction( "set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName)); + ExtensionUtil::RegisterFunction(instance, ScalarFunction( + "set_api_timeout", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiTimeout)); } void OpenPromptExtension::Load(DuckDB &db) {