From 527f5ca3a4d3cde91c59c2149ea70143ae9446c6 Mon Sep 17 00:00:00 2001 From: lmangani Date: Sun, 20 Oct 2024 21:58:24 +0000 Subject: [PATCH] open_prompt --- .../workflows/MainDistributionPipeline.yml | 6 +- .gitignore | 1 + CMakeLists.txt | 4 +- extension_config.cmake | 4 +- src/http_client_extension.cpp | 216 ---------------- src/include/http_metadata_cache.hpp | 91 +++++++ src/include/http_state.hpp | 106 ++++++++ ...xtension.hpp => open_prompt_extension.hpp} | 2 +- src/open_prompt_extension.cpp | 237 ++++++++++++++++++ 9 files changed, 443 insertions(+), 224 deletions(-) delete mode 100644 src/http_client_extension.cpp create mode 100644 src/include/http_metadata_cache.hpp create mode 100644 src/include/http_state.hpp rename src/include/{http_client_extension.hpp => open_prompt_extension.hpp} (84%) create mode 100644 src/open_prompt_extension.cpp diff --git a/.github/workflows/MainDistributionPipeline.yml b/.github/workflows/MainDistributionPipeline.yml index 8d379e2..bb108bf 100644 --- a/.github/workflows/MainDistributionPipeline.yml +++ b/.github/workflows/MainDistributionPipeline.yml @@ -18,7 +18,7 @@ jobs: with: duckdb_version: main ci_tools_version: main - extension_name: http_client + extension_name: open_prompt duckdb-stable-build: name: Build extension binaries @@ -26,7 +26,7 @@ jobs: with: duckdb_version: v1.1.1 ci_tools_version: v1.1.1 - extension_name: http_client + extension_name: open_prompt duckdb-stable-deploy: name: Deploy extension binaries @@ -35,5 +35,5 @@ jobs: secrets: inherit with: duckdb_version: v1.1.1 - extension_name: http_client + extension_name: open_prompt deploy_latest: ${{ startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main' }} diff --git a/.gitignore b/.gitignore index b9f264b..261a7eb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +README.md build .idea cmake-build-debug diff --git a/CMakeLists.txt b/CMakeLists.txt index a6ede33..8551c08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 
@@ cmake_minimum_required(VERSION 3.5) # Set extension name here -set(TARGET_NAME http_client) +set(TARGET_NAME open_prompt) # DuckDB's extension distribution supports vcpkg. As such, dependencies can be added in ./vcpkg.json and then # used in cmake with find_package. Feel free to remove or replace with other dependencies. @@ -14,7 +14,7 @@ set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension) project(${TARGET_NAME}) include_directories(src/include duckdb/third_party/httplib) -set(EXTENSION_SOURCES src/http_client_extension.cpp) +set(EXTENSION_SOURCES src/open_prompt_extension.cpp) if(MINGW) set(OPENSSL_USE_STATIC_LIBS TRUE) diff --git a/extension_config.cmake b/extension_config.cmake index 225b9f4..1ffbf65 100644 --- a/extension_config.cmake +++ b/extension_config.cmake @@ -1,10 +1,10 @@ # This file is included by DuckDB's build system. It specifies which extension to load # Extension from this repo -duckdb_extension_load(http_client +duckdb_extension_load(open_prompt SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR} LOAD_TESTS ) # Any extra extensions that should be built -duckdb_extension_load(json) +# duckdb_extension_load(json) diff --git a/src/http_client_extension.cpp b/src/http_client_extension.cpp deleted file mode 100644 index 5e6f93c..0000000 --- a/src/http_client_extension.cpp +++ /dev/null @@ -1,216 +0,0 @@ -#define DUCKDB_EXTENSION_MAIN -#include "http_client_extension.hpp" -#include "duckdb.hpp" -#include "duckdb/function/scalar_function.hpp" -#include "duckdb/main/extension_util.hpp" -#include "duckdb/common/atomic.hpp" -#include "duckdb/common/exception/http_exception.hpp" -#include - -#define CPPHTTPLIB_OPENSSL_SUPPORT -#include "httplib.hpp" - -#include -#include - -namespace duckdb { - -// Helper function to parse URL and setup client -static std::pair SetupHttpClient(const std::string &url) { - std::string scheme, domain, path; - size_t pos = url.find("://"); - std::string mod_url = url; - if (pos != std::string::npos) { - scheme = 
mod_url.substr(0, pos); - mod_url.erase(0, pos + 3); - } - - pos = mod_url.find("/"); - if (pos != std::string::npos) { - domain = mod_url.substr(0, pos); - path = mod_url.substr(pos); - } else { - domain = mod_url; - path = "/"; - } - - // Create client and set a reasonable timeout (e.g., 10 seconds) - duckdb_httplib_openssl::Client client(domain.c_str()); - client.set_read_timeout(10, 0); // 10 seconds - client.set_follow_location(true); // Follow redirects - - return std::make_pair(std::move(client), path); -} - -static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std::string &request_type) { - std::string err_message = "HTTP " + request_type + " request failed. "; - - switch (res.error()) { - case duckdb_httplib_openssl::Error::Connection: - err_message += "Connection error."; - break; - case duckdb_httplib_openssl::Error::BindIPAddress: - err_message += "Failed to bind IP address."; - break; - case duckdb_httplib_openssl::Error::Read: - err_message += "Error reading response."; - break; - case duckdb_httplib_openssl::Error::Write: - err_message += "Error writing request."; - break; - case duckdb_httplib_openssl::Error::ExceedRedirectCount: - err_message += "Too many redirects."; - break; - case duckdb_httplib_openssl::Error::Canceled: - err_message += "Request was canceled."; - break; - case duckdb_httplib_openssl::Error::SSLConnection: - err_message += "SSL connection failed."; - break; - case duckdb_httplib_openssl::Error::SSLLoadingCerts: - err_message += "Failed to load SSL certificates."; - break; - case duckdb_httplib_openssl::Error::SSLServerVerification: - err_message += "SSL server verification failed."; - break; - case duckdb_httplib_openssl::Error::UnsupportedMultipartBoundaryChars: - err_message += "Unsupported characters in multipart boundary."; - break; - case duckdb_httplib_openssl::Error::Compression: - err_message += "Error during compression."; - break; - default: - err_message += "Unknown error."; - break; - } - 
throw std::runtime_error(err_message); -} - - -static void HTTPGetRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.data.size() == 1); - - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { - std::string url = input.GetString(); - - // Use helper to setup client and parse URL - auto client_and_path = SetupHttpClient(url); - auto &client = client_and_path.first; - auto &path = client_and_path.second; - - // Make the GET request - auto res = client.Get(path.c_str()); - if (res) { - if (res->status == 200) { - return StringVector::AddString(result, res->body); - } else { - throw std::runtime_error("HTTP GET error: " + std::to_string(res->status) + " - " + res->reason); - } - } else { - // Handle errors - HandleHttpError(res, "GET"); - } - // Ensure a return value in case of an error - return string_t(); - }); -} - -static void HTTPPostRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.data.size() == 3); - - auto &url_vector = args.data[0]; - auto &headers_vector = args.data[1]; - auto &body_vector = args.data[2]; - - TernaryExecutor::Execute( - url_vector, headers_vector, body_vector, result, args.size(), - [&](string_t url, string_t headers, string_t body) { - std::string url_str = url.GetString(); - - // Use helper to setup client and parse URL - auto client_and_path = SetupHttpClient(url_str); - auto &client = client_and_path.first; - auto &path = client_and_path.second; - - // Prepare headers - duckdb_httplib_openssl::Headers header_map; - std::istringstream header_stream(headers.GetString()); - std::string header; - while (std::getline(header_stream, header)) { - size_t colon_pos = header.find(':'); - if (colon_pos != std::string::npos) { - std::string key = header.substr(0, colon_pos); - std::string value = header.substr(colon_pos + 1); - // Trim leading and trailing whitespace - key.erase(0, key.find_first_not_of(" \t")); - 
key.erase(key.find_last_not_of(" \t") + 1); - value.erase(0, value.find_first_not_of(" \t")); - value.erase(value.find_last_not_of(" \t") + 1); - header_map.emplace(key, value); - } - } - - // Make the POST request with headers and body - auto res = client.Post(path.c_str(), header_map, body.GetString(), "application/json"); - if (res) { - if (res->status == 200) { - return StringVector::AddString(result, res->body); - } else { - throw std::runtime_error("HTTP POST error: " + std::to_string(res->status) + " - " + res->reason); - } - } else { - // Handle errors - HandleHttpError(res, "POST"); - } - // Ensure a return value in case of an error - return string_t(); - }); -} - - -static void LoadInternal(DatabaseInstance &instance) { - ScalarFunctionSet http_get("http_get"); - http_get.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, HTTPGetRequestFunction)); - ExtensionUtil::RegisterFunction(instance, http_get); - - ScalarFunctionSet http_post("http_post"); - http_post.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::VARCHAR, HTTPPostRequestFunction)); - ExtensionUtil::RegisterFunction(instance, http_post); -} - -void HttpClientExtension::Load(DuckDB &db) { - LoadInternal(*db.instance); -} - -std::string HttpClientExtension::Name() { - return "http_client"; -} - -std::string HttpClientExtension::Version() const { -#ifdef EXT_VERSION_HTTPCLIENT - return EXT_VERSION_HTTPCLIENT; -#else - return ""; -#endif -} - - -} // namespace duckdb - -extern "C" { -DUCKDB_EXTENSION_API void http_client_init(duckdb::DatabaseInstance &db) { - duckdb::DuckDB db_wrapper(db); - db_wrapper.LoadExtension(); -} - -DUCKDB_EXTENSION_API const char *http_client_version() { - return duckdb::DuckDB::LibraryVersion(); -} -} - -#ifndef DUCKDB_EXTENSION_MAIN -#error DUCKDB_EXTENSION_MAIN not defined -#endif - diff --git a/src/include/http_metadata_cache.hpp b/src/include/http_metadata_cache.hpp new file mode 
100644 index 0000000..73d032b --- /dev/null +++ b/src/include/http_metadata_cache.hpp @@ -0,0 +1,91 @@ +#pragma once + +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/chrono.hpp" +#include "duckdb/common/list.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_context_state.hpp" + +#include +#include + +namespace duckdb { + +struct HTTPMetadataCacheEntry { + idx_t length; + time_t last_modified; +}; + +// Simple cache with a max age for an entry to be valid +class HTTPMetadataCache : public ClientContextState { +public: + explicit HTTPMetadataCache(bool flush_on_query_end_p, bool shared_p) + : flush_on_query_end(flush_on_query_end_p), shared(shared_p) {}; + + void Insert(const string &path, HTTPMetadataCacheEntry val) { + if (shared) { + lock_guard parallel_lock(lock); + map[path] = val; + } else { + map[path] = val; + } + }; + + void Erase(string path) { + if (shared) { + lock_guard parallel_lock(lock); + map.erase(path); + } else { + map.erase(path); + } + }; + + bool Find(string path, HTTPMetadataCacheEntry &ret_val) { + if (shared) { + lock_guard parallel_lock(lock); + auto lookup = map.find(path); + if (lookup != map.end()) { + ret_val = lookup->second; + return true; + } else { + return false; + } + } else { + auto lookup = map.find(path); + if (lookup != map.end()) { + ret_val = lookup->second; + return true; + } else { + return false; + } + } + }; + + void Clear() { + if (shared) { + lock_guard parallel_lock(lock); + map.clear(); + } else { + map.clear(); + } + } + + //! 
Called by the ClientContext when the current query ends + void QueryEnd(ClientContext &context) override { + if (flush_on_query_end) { + Clear(); + } + } + +protected: + mutex lock; + unordered_map map; + bool flush_on_query_end; + bool shared; +}; + +} // namespace duckdb diff --git a/src/include/http_state.hpp b/src/include/http_state.hpp new file mode 100644 index 0000000..93f4abf --- /dev/null +++ b/src/include/http_state.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include "duckdb/common/file_opener.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/main/client_context_state.hpp" + +namespace duckdb { + +class CachedFileHandle; + +//! Represents a file that is intended to be fully downloaded, then used in parallel by multiple threads +class CachedFile : public enable_shared_from_this { + friend class CachedFileHandle; + +public: + unique_ptr GetHandle() { + auto this_ptr = shared_from_this(); + return make_uniq(this_ptr); + } + +private: + //! Cached Data + shared_ptr data; + //! Data capacity + uint64_t capacity = 0; + //! Size of file + idx_t size; + //! Lock for initializing the file + mutex lock; + //! When initialized is set to true, the file is safe for parallel reading without holding the lock + atomic initialized = {false}; +}; + +//! Handle to a CachedFile +class CachedFileHandle { +public: + explicit CachedFileHandle(shared_ptr &file_p); + + //! allocate a buffer for the file + void AllocateBuffer(idx_t size); + //! Indicate the file is fully downloaded and safe for parallel reading without lock + void SetInitialized(idx_t total_size); + //! Grow buffer to new size, copying over `bytes_to_copy` to the new buffer + void GrowBuffer(idx_t new_capacity, idx_t bytes_to_copy); + //! 
Write to the buffer + void Write(const char *buffer, idx_t length, idx_t offset = 0); + + bool Initialized() { + return file->initialized; + } + const char *GetData() { + return file->data.get(); + } + uint64_t GetCapacity() { + return file->capacity; + } + //! Return the size of the initialized file + idx_t GetSize() { + D_ASSERT(file->initialized); + return file->size; + } + +private: + unique_ptr> lock; + shared_ptr file; +}; + +class HTTPState : public ClientContextState { +public: + //! Reset all counters and cached files + void Reset(); + //! Get cache entry, create if not exists + shared_ptr &GetCachedFile(const string &path); + //! Helper functions to get the HTTP state + static shared_ptr TryGetState(ClientContext &context); + static shared_ptr TryGetState(optional_ptr opener); + + bool IsEmpty() { + return head_count == 0 && get_count == 0 && put_count == 0 && post_count == 0 && total_bytes_received == 0 && + total_bytes_sent == 0; + } + + atomic head_count {0}; + atomic get_count {0}; + atomic put_count {0}; + atomic post_count {0}; + atomic total_bytes_received {0}; + atomic total_bytes_sent {0}; + + //! Called by the ClientContext when the current query ends + void QueryEnd(ClientContext &context) override { + Reset(); + } + void WriteProfilingInformation(std::ostream &ss) override; + +private: + //! Mutex to lock when getting the cached file(Parallel Only) + mutex cached_files_mutex; + //! 
In case of fully downloading the file, the cached files of this query + unordered_map> cached_files; +}; + +} // namespace duckdb diff --git a/src/include/http_client_extension.hpp b/src/include/open_prompt_extension.hpp similarity index 84% rename from src/include/http_client_extension.hpp rename to src/include/open_prompt_extension.hpp index d592819..1cd3077 100644 --- a/src/include/http_client_extension.hpp +++ b/src/include/open_prompt_extension.hpp @@ -6,7 +6,7 @@ namespace duckdb { using HeaderMap = case_insensitive_map_t; -class HttpClientExtension : public Extension { +class OpenPromptExtension : public Extension { public: void Load(DuckDB &db) override; std::string Name() override; diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp new file mode 100644 index 0000000..f3a8e7b --- /dev/null +++ b/src/open_prompt_extension.cpp @@ -0,0 +1,292 @@ +#define DUCKDB_EXTENSION_MAIN +#include "open_prompt_extension.hpp" +#include "duckdb.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/main/extension_util.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/exception/http_exception.hpp" +#include + +#define CPPHTTPLIB_OPENSSL_SUPPORT +#include "httplib.hpp" + +#include +#include + +namespace duckdb {
+
+// Splits `url` into scheme, authority and path, and builds an httplib client for it.
+// Returns the client (10 s read timeout, redirects followed) and the request path;
+// the path defaults to "/" when the URL has none.
+static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(const std::string &url) {
+    std::string scheme, domain, path;
+    std::string mod_url = url;
+    size_t pos = mod_url.find("://");
+    if (pos != std::string::npos) {
+        scheme = mod_url.substr(0, pos);
+        mod_url.erase(0, pos + 3);
+    }
+
+    pos = mod_url.find('/');
+    if (pos != std::string::npos) {
+        domain = mod_url.substr(0, pos);
+        path = mod_url.substr(pos);
+    } else {
+        domain = mod_url;
+        path = "/";
+    }
+
+    // Hand the scheme back to httplib: given only "host[:port]" it assumes plain HTTP,
+    // so an https:// URL would otherwise be contacted without TLS.
+    const std::string client_url = scheme.empty() ? domain : scheme + "://" + domain;
+
+    duckdb_httplib_openssl::Client client(client_url.c_str());
+    client.set_read_timeout(10, 0);   // 10 seconds
+    client.set_follow_location(true); // Follow redirects
+
+    return std::make_pair(std::move(client), path);
+}
+
+// Translates a failed httplib result into a std::runtime_error naming the failure.
+// Always throws; `request_type` is the HTTP verb used in the message.
+static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std::string &request_type) {
+    std::string err_message = "HTTP " + request_type + " request failed. ";
+
+    switch (res.error()) {
+    case duckdb_httplib_openssl::Error::Connection:
+        err_message += "Connection error.";
+        break;
+    case duckdb_httplib_openssl::Error::BindIPAddress:
+        err_message += "Failed to bind IP address.";
+        break;
+    case duckdb_httplib_openssl::Error::Read:
+        err_message += "Error reading response.";
+        break;
+    case duckdb_httplib_openssl::Error::Write:
+        err_message += "Error writing request.";
+        break;
+    case duckdb_httplib_openssl::Error::ExceedRedirectCount:
+        err_message += "Too many redirects.";
+        break;
+    case duckdb_httplib_openssl::Error::Canceled:
+        err_message += "Request was canceled.";
+        break;
+    case duckdb_httplib_openssl::Error::SSLConnection:
+        err_message += "SSL connection failed.";
+        break;
+    case duckdb_httplib_openssl::Error::SSLLoadingCerts:
+        err_message += "Failed to load SSL certificates.";
+        break;
+    case duckdb_httplib_openssl::Error::SSLServerVerification:
+        err_message += "SSL server verification failed.";
+        break;
+    case duckdb_httplib_openssl::Error::UnsupportedMultipartBoundaryChars:
+        err_message += "Unsupported characters in multipart boundary.";
+        break;
+    case duckdb_httplib_openssl::Error::Compression:
+        err_message += "Error during compression.";
+        break;
+    default:
+        err_message += "Unknown error.";
+        break;
+    }
+    throw std::runtime_error(err_message);
+}
+
+// Open Prompt settings. They are process-wide and not synchronised.
+// NOTE(review): confirm the set_* functions are never run concurrently with open_prompt.
+static std::string api_url = "http://localhost:11434/v1/chat/completions";
+static std::string api_token = "";        // Store your API token here
+static std::string model_name = "llama2"; // Default model
+
+// Returns the configured API URL, or the built-in default when it is empty.
+static std::string GetApiUrl() {
+    return api_url.empty() ? "http://localhost:11434/v1/chat/completions" : api_url;
+}
+
+// Returns the configured API token (empty when none is set).
+static std::string GetApiToken() {
+    return api_token;
+}
+
+// Returns the configured model name, or "llama2" when it is empty.
+static std::string GetModelName() {
+    return model_name.empty() ? "llama2" : model_name;
+}
+
+// Stores the bearer token sent with every request.
+void SetApiToken(const std::string &token) {
+    api_token = token;
+}
+
+// Stores the chat-completions endpoint URL.
+void SetApiUrl(const std::string &url) {
+    api_url = url;
+}
+
+// Stores the model name placed in the request body.
+void SetModelName(const std::string &model) {
+    model_name = model;
+}
+
+// Escapes `input` so it can be embedded between double quotes in a JSON document.
+static std::string EscapeJsonString(const std::string &input) {
+    static const char hex_digits[] = "0123456789abcdef";
+    std::string escaped;
+    escaped.reserve(input.size() + 8);
+    for (const unsigned char c : input) {
+        switch (c) {
+        case '"':
+            escaped += "\\\"";
+            break;
+        case '\\':
+            escaped += "\\\\";
+            break;
+        case '\n':
+            escaped += "\\n";
+            break;
+        case '\r':
+            escaped += "\\r";
+            break;
+        case '\t':
+            escaped += "\\t";
+            break;
+        default:
+            if (c < 0x20) {
+                // Remaining control characters must be written as \u00XX.
+                escaped += "\\u00";
+                escaped += hex_digits[c >> 4];
+                escaped += hex_digits[c & 0x0F];
+            } else {
+                escaped += static_cast<char>(c);
+            }
+            break;
+        }
+    }
+    return escaped;
+}
+
+// open_prompt(prompt [, ignored]): POSTs each prompt to the configured chat-completions
+// endpoint and returns the raw response body. On any failure the prompt itself is returned.
+static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
+    // Registered with one or two VARCHAR arguments; only the first (the prompt) is read.
+    D_ASSERT(args.data.size() >= 1);
+
+    // The settings do not change per row, so read them once per chunk.
+    const std::string target_url = GetApiUrl();
+    const std::string bearer_token = GetApiToken();
+    const std::string model = GetModelName();
+
+    UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(), [&](string_t user_prompt) {
+        // Prepare the JSON body; user-controlled text is escaped so it cannot break the document.
+        std::string request_body = "{";
+        request_body += "\"model\":\"" + EscapeJsonString(model) + "\",";
+        request_body += "\"messages\":[";
+        request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
+        request_body += "{\"role\":\"user\",\"content\":\"" + EscapeJsonString(user_prompt.GetString()) + "\"}";
+        request_body += "]}";
+
+        try {
+            auto client_and_path = SetupHttpClient(target_url);
+            auto &client = client_and_path.first;
+            auto &path = client_and_path.second;
+
+            // Content-Type is supplied through Post's content-type argument, so only
+            // the optional bearer token is added here.
+            duckdb_httplib_openssl::Headers header_map;
+            if (!bearer_token.empty()) {
+                header_map.emplace("Authorization", "Bearer " + bearer_token);
+            }
+
+            auto res = client.Post(path.c_str(), header_map, request_body, "application/json");
+            if (!res) {
+                // No response at all: res holds no Response, so it must not be dereferenced.
+                HandleHttpError(res, "POST");
+            }
+            if (res->status != 200) {
+                throw std::runtime_error("HTTP POST error: " + std::to_string(res->status) + " - " + res->reason);
+            }
+            return StringVector::AddString(result, res->body);
+        } catch (std::exception &e) {
+            // In case of any error, return the original input text to avoid disruption
+            return StringVector::AddString(result, user_prompt);
+        }
+    });
+}
+
+// Shared body of the set_* functions: applies `setter` to every non-NULL input row
+// and writes `message` as that row's result.
+static void ApplySetting(DataChunk &args, Vector &result, void (*setter)(const std::string &),
+                         const char *message) {
+    UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(), [&](string_t value) {
+        setter(value.GetString());
+        return StringVector::AddString(result, message);
+    });
+}
+
+// Registers open_prompt and the set_api_token / set_api_url / set_model_name functions.
+static void LoadInternal(DatabaseInstance &instance) {
+    // open_prompt(prompt) matches the implementation; the two-argument form is kept so
+    // existing calls keep resolving (its second argument is not used).
+    ScalarFunctionSet open_prompt("open_prompt");
+    open_prompt.AddFunction(
+        ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
+    open_prompt.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR,
+                                           OpenPromptRequestFunction));
+    ExtensionUtil::RegisterFunction(instance, open_prompt);
+
+    // Function to set API token
+    ExtensionUtil::RegisterFunction(
+        instance, ScalarFunction("set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
+                                 [](DataChunk &args, ExpressionState &state, Vector &result) {
+                                     ApplySetting(args, result, SetApiToken, "API token set successfully.");
+                                 }));
+
+    // Function to set API URL
+    ExtensionUtil::RegisterFunction(
+        instance, ScalarFunction("set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
+                                 [](DataChunk &args, ExpressionState &state, Vector &result) {
+                                     ApplySetting(args, result, SetApiUrl, "API URL set successfully.");
+                                 }));
+
+    // Function to set model name
+    ExtensionUtil::RegisterFunction(
+        instance, ScalarFunction("set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
+                                 [](DataChunk &args, ExpressionState &state, Vector &result) {
+                                     ApplySetting(args, result, SetModelName, "Model name set successfully.");
+                                 }));
+}
+
+void OpenPromptExtension::Load(DuckDB &db) {
+    LoadInternal(*db.instance);
+}
+
+std::string OpenPromptExtension::Name() {
+    return "open_prompt";
+}
+
+std::string OpenPromptExtension::Version() const {
+#ifdef EXT_VERSION_OPENPROMPT
+    return EXT_VERSION_OPENPROMPT;
+#else
+    return "";
+#endif
+}
+
+} // namespace duckdb
+
+extern "C" {
+DUCKDB_EXTENSION_API void open_prompt_init(duckdb::DatabaseInstance &db) {
+    duckdb::DuckDB db_wrapper(db);
+    db_wrapper.LoadExtension<duckdb::OpenPromptExtension>();
+}
+
+DUCKDB_EXTENSION_API const char *open_prompt_version() {
+    return duckdb::DuckDB::LibraryVersion();
+}
+}
+
+#ifndef DUCKDB_EXTENSION_MAIN
+#error DUCKDB_EXTENSION_MAIN not defined
+#endif
+