Skip to content

Commit

Permalink
set_api_timeout
Browse files Browse the repository at this point in the history
Optional `set_api_timeout` settings function
  • Loading branch information
lmangani authored Nov 14, 2024
1 parent 23a4f5e commit b71e3b4
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result
SetConfigValue(args, state, result, "openprompt_model_name", "Model name");
}

static void SetApiTimeout(DataChunk &args, ExpressionState &state, Vector &result) {
SetConfigValue(args, state, result, "openprompt_api_timeout", "API timeout");
}

// Main Function
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
D_ASSERT(args.data.size() >= 1); // At least prompt required
Expand All @@ -187,6 +191,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
"http://localhost:11434/v1/chat/completions");
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");
std::string api_timeout = GetConfigValue(context, "openprompt_api_timeout", "");
std::string json_schema;
std::string system_prompt;

Expand Down Expand Up @@ -259,6 +264,10 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
headers.emplace("Authorization", "Bearer " + api_token);
}

if (!api_timeout.empty()){
client.set_read_timeout(stoi(api_timeout), 0);
}

auto res = client.Post(path.c_str(), headers, str_request_body, "application/json");

if (!res) {
Expand Down Expand Up @@ -349,6 +358,8 @@ static void LoadInternal(DatabaseInstance &instance) {
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_timeout", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiTimeout));
}

void OpenPromptExtension::Load(DuckDB &db) {
Expand Down

0 comments on commit b71e3b4

Please sign in to comment.