# Implement openAI endpoint invoker for nuget #15797
```diff
@@ -5993,3 +5993,32 @@ https://github.com/tensorflow/tfjs
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
+
+——
+
+curl/curl
+
+https://github.com/curl
+
+COPYRIGHT AND PERMISSION NOTICE
+
+Copyright (C) Daniel Stenberg, <[email protected]>, and many
+contributors, see the THANKS file.
+
+All rights reserved.
+
+Permission to use, copy, modify, and distribute this software for any purpose
+with or without fee is hereby granted, provided that the above copyright
+notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. IN
+NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+OR OTHER DEALINGS IN THE SOFTWARE.
+
+Except as contained in this notice, the name of a copyright holder shall not
+be used in advertising or otherwise to promote the sale, use or other dealings
+in this Software without prior written authorization of the copyright holder.
```
Submodule file changes:

```
+21 −7   README.md
+0 −11   bazel/BUILD
+1 −23   bazel/emscripten_deps.bzl
+1 −1    bazel/emscripten_toolchain/emscripten_config
+0 −37   bazel/revisions.bzl
+2 −2    docker/Dockerfile
+2 −12   emscripten-releases-tags.json
+0 −2    emsdk.py
+7 −64   emsdk_manifest.json
+2 −3    scripts/update_node.py
+3 −3    test/test.py
+4 −0    test/test_activation.ps1
```
```diff
@@ -44,13 +44,11 @@ if (WIN32)
 vcpkg_install(re2)
 vcpkg_install(boost-interprocess)
 vcpkg_install(boost-stacktrace)
-vcpkg_install(zlib)
 vcpkg_install(pthread)
 vcpkg_install(b64)

 add_dependencies(getb64 getpthread)
```

> **Review comment:** They all need to be added to ThirdPartyNotices.txt, since the EP you built will contain code from the 3rd-party libraries.
>
> **Reply:** It is a removal.

```diff
-add_dependencies(getpthread getzlib)
-add_dependencies(getzlib getboost-stacktrace)
+add_dependencies(getpthread getboost-stacktrace)
 add_dependencies(getboost-stacktrace getboost-interprocess)
```

> **Review comment:** We already have boost. We cannot have two versions of the same lib.
>
> **Reply:** Like above, it is a removal.

```diff
 add_dependencies(getboost-interprocess getre2)
```

> **Review comment:** We already have re2. We cannot have two versions of the same lib.
>
> **Reply:** But triton won't build; it goes to the vcpkg repository for its dependencies, not to us.
>
> **Review comment:** Triton's http-client library is relatively simple. It only has a few source files and doesn't have many dependencies. You may write a CMakeLists.txt for it if theirs doesn't meet our needs. For example, for ONNX we have https://github.com/microsoft/onnxruntime/blob/main/cmake/external/onnx_minimal.cmake. And if you dig deeper into their source code, you might find it does just the same thing you are doing for the OpenAI EP. You might rewrite the code, extract the common parts, and reuse them for the two different types of endpoints.
>
> **Review comment:** For example, if you are writing a custom op, you are free to use any way to get your dependencies, because custom ops live in a separate DLL. We do not worry about their versions as long as our custom op API has a stable binary interface (ABI).

```diff
 add_dependencies(getre2 getrapidjson)
```
```diff
@@ -59,11 +57,11 @@ if (WIN32)

 ExternalProject_Add(triton
   GIT_REPOSITORY https://github.com/triton-inference-server/client.git
-  GIT_TAG r22.12
+  GIT_TAG r23.05
   PREFIX triton
   SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
   BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
-  CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${onnxruntime_target_platform}-windows -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON
+  CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${onnxruntime_target_platform}-windows -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
   INSTALL_COMMAND ""
   UPDATE_COMMAND "")

@@ -85,11 +83,11 @@ else()

 ExternalProject_Add(triton
   GIT_REPOSITORY https://github.com/triton-inference-server/client.git
-  GIT_TAG r22.12
+  GIT_TAG r23.05
   PREFIX triton
   SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
   BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
-  CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON
+  CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
   INSTALL_COMMAND ""
   UPDATE_COMMAND "")
```
```diff
@@ -2,7 +2,9 @@
 // Licensed under the MIT License.

 #ifdef USE_AZURE
+#define CURL_STATICLIB
 #include "http_client.h"
+#include "curl/curl.h"
 #include "core/common/common.h"
 #include "core/framework/cloud_invoker.h"
 #include "core/framework/ort_value.h"
```
```diff
@@ -18,13 +20,14 @@ namespace onnxruntime {

 namespace tc = triton::client;

-const char* kAzureUri = "azure.uri";
-const char* kAzureModelName = "azure.model_name";
-const char* kAzureModelVer = "azure.model_version";
-const char* kAzureVerbose = "azure.verbose";
-const char* kAzureEndpointType = "azure.endpoint_type";
-const char* kAzureAuthKey = "azure.auth_key";
-const char* kAzureTriton = "triton";
+constexpr const char* kAzureUri = "azure.uri";
+constexpr const char* kAzureModelName = "azure.model_name";
+constexpr const char* kAzureModelVer = "azure.model_version";
+constexpr const char* kAzureVerbose = "azure.verbose";
+constexpr const char* kAzureEndpointType = "azure.endpoint_type";
+constexpr const char* kAzureAuthKey = "azure.auth_key";
+constexpr const char* kAzureTriton = "triton";
+constexpr const char* kAzureOpenAI = "openai";

 CloudEndPointInvoker::CloudEndPointInvoker(const CloudEndPointConfig& config,
                                            const AllocatorPtr& allocator) : config_(config), allocator_(allocator) {
```
```diff
@@ -33,6 +36,163 @@ CloudEndPointInvoker::CloudEndPointInvoker(const CloudEndPointConfig& config,
   }
 }

+class CurlGlobal {
+ public:
+  static void Init() {
+    static CurlGlobal curl_global;
+  }
+
+ private:
+  CurlGlobal() {
+    // Thread-safety is a must since curl might also be initialized in triton client.
+    const auto* info = curl_version_info(CURLVERSION_NOW);
+    ORT_ENFORCE(info->features & CURL_VERSION_THREADSAFE, "curl global init not thread-safe, need to upgrade curl version!");
+    ORT_ENFORCE(curl_global_init(CURL_GLOBAL_DEFAULT) == CURLE_OK, "Failed to initialize curl global env!");
+  }
+  ~CurlGlobal() {
+    curl_global_cleanup();
+  }
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CurlGlobal);
+};
```

> **Review comment:** Even though curl_global_init is thread-safe, we still should not call it from any global variable's constructor, since at that time the static variables in libcurl itself might not have been initialized yet. So I think the best way to ensure this is to put the call in OrtEnv's constructor.
>
> **Reply:** I don't follow: how come curl statics are not initialized on OpenAIInvoker::OpenAIInvoker(...), while the OrtEnv constructor could avoid that?
>
> **Review comment:** Talked with Randy offline. The most difficult part is: this file also uses the TritonClient lib, which has similar code that initializes and deinitializes curl. I don't see an easy way to coordinate them. Will leave the discussion for later.
```diff
+// OpenAIInvoker
+class OpenAIInvoker : public CloudEndPointInvoker {
+ public:
+  OpenAIInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator);
+  onnxruntime::Status Send(const CloudEndPointConfig& run_options,
+                           const InlinedVector<std::string>& input_names,
+                           gsl::span<const OrtValue> ort_inputs,
+                           const InlinedVector<std::string>& output_names,
+                           std::vector<OrtValue>& ort_outputs) const override;
+
+ private:
+  std::string uri_;
+  std::string model_name_;
+};
+
+OpenAIInvoker::OpenAIInvoker(const CloudEndPointConfig& config,
+                             const AllocatorPtr& allocator) : CloudEndPointInvoker(config, allocator) {
+  CurlGlobal::Init();
+  ReadConfig(kAzureUri, uri_);
+  ReadConfig(kAzureModelName, model_name_);
+}
```
```diff
+struct StringBuffer {
+  StringBuffer() = default;
+  ~StringBuffer() = default;
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(StringBuffer);
+  std::stringstream ss_;
+};
```

> **Review comment:** Here you need a std::string instead of a stringstream.
>
> **Reply:** stringstream buffers extra space for incoming content, which is more suitable for the scenario here.
>
> **Review comment:** Doesn't std::string do the same thing? Every string has a capacity.
>
> **Reply:** stringstream is designed for buffering content that grows dynamically.

```diff
+// apply the callback only when the response is guaranteed to be a '\0'-terminated string
```
```diff
+static size_t WriteStringCallback(void* contents, size_t size, size_t nmemb, void* userp) {
+  try {
+    size_t realsize = size * nmemb;
+    auto buffer = reinterpret_cast<struct StringBuffer*>(userp);
+    buffer->ss_.write(reinterpret_cast<const char*>(contents), realsize);
+    return realsize;
+  } catch (...) {
+    // exception caught, abort write
+    return CURLcode::CURLE_WRITE_ERROR;
+  }
+}

+using CurlWriteCallBack = size_t (*)(void*, size_t, size_t, void*);
```
```diff
+class CurlHandler {
+ public:
+  CurlHandler(CurlWriteCallBack call_back) : curl_(curl_easy_init(), curl_easy_cleanup),
+                                             headers_(nullptr, curl_slist_free_all),
+                                             from_holder_(from_, curl_formfree) {
+    curl_easy_setopt(curl_.get(), CURLOPT_BUFFERSIZE, 102400L);
+    curl_easy_setopt(curl_.get(), CURLOPT_NOPROGRESS, 1L);
+    curl_easy_setopt(curl_.get(), CURLOPT_USERAGENT, "curl/7.83.1");
+    curl_easy_setopt(curl_.get(), CURLOPT_MAXREDIRS, 50L);
+    curl_easy_setopt(curl_.get(), CURLOPT_FTP_SKIP_PASV_IP, 1L);
+    curl_easy_setopt(curl_.get(), CURLOPT_TCP_KEEPALIVE, 1L);
+    curl_easy_setopt(curl_.get(), CURLOPT_WRITEFUNCTION, call_back);
+  }
+  ~CurlHandler() = default;
+
+  void AddHeader(const char* data) {
+    headers_.reset(curl_slist_append(headers_.release(), data));
+  }
+  template <typename... Args>
+  void AddForm(Args... args) {
+    curl_formadd(&from_, &last_, args...);
+  }
+  template <typename T>
+  void SetOption(CURLoption opt, T val) {
+    curl_easy_setopt(curl_.get(), opt, val);
+  }
+  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CurlHandler);
+  CURLcode Perform() {
+    SetOption(CURLOPT_HTTPHEADER, headers_.get());
+    SetOption(CURLOPT_HTTPPOST, from_);
+    return curl_easy_perform(curl_.get());
+  }
+
+ private:
+  std::unique_ptr<CURL, decltype(curl_easy_cleanup)*> curl_;
+  std::unique_ptr<curl_slist, decltype(curl_slist_free_all)*> headers_;
+  curl_httppost* from_{};
+  curl_httppost* last_{};
+  std::unique_ptr<curl_httppost, decltype(curl_formfree)*> from_holder_;
+};
```
```diff
+onnxruntime::Status OpenAIInvoker::Send(const CloudEndPointConfig& run_options,
+                                        const InlinedVector<std::string>& /*input_names*/,
+                                        gsl::span<const OrtValue> ort_inputs,
+                                        const InlinedVector<std::string>& /*output_names*/,
+                                        std::vector<OrtValue>& ort_outputs) const {
+  const auto auth_key_iter = run_options.find(kAzureAuthKey);
+  if (run_options.end() == auth_key_iter || auth_key_iter->second.empty()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "auth key must be specified for openai client");
+  }
+  long verbose = 0;
+  const auto verbose_iter = run_options.find(kAzureVerbose);
+  if (run_options.end() != verbose_iter) {
+    verbose = verbose_iter->second != "0" ? 1L : 0L;
+  }
+
+  CurlHandler curl_handler(WriteStringCallback);
+  StringBuffer string_buffer;
+
+  std::string full_auth = std::string{"Authorization: Bearer "} + auth_key_iter->second;
+  curl_handler.AddHeader(full_auth.c_str());
+  curl_handler.AddHeader("Content-Type: multipart/form-data");
+
+  const auto& tensor = ort_inputs[0].Get<Tensor>();
+  auto data_size = tensor.SizeInBytes();
+  curl_handler.AddForm(CURLFORM_COPYNAME, "model", CURLFORM_COPYCONTENTS, model_name_.c_str(), CURLFORM_END);
+  curl_handler.AddForm(CURLFORM_COPYNAME, "response_format", CURLFORM_COPYCONTENTS, "text", CURLFORM_END);
+  curl_handler.AddForm(CURLFORM_COPYNAME, "file", CURLFORM_BUFFER, "non_exist.wav", CURLFORM_BUFFERPTR, tensor.DataRaw(),
+                       CURLFORM_BUFFERLENGTH, data_size, CURLFORM_END);
+
+  curl_handler.SetOption(CURLOPT_URL, uri_.c_str());
+  curl_handler.SetOption(CURLOPT_VERBOSE, verbose);
+  curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);
+
+  auto curl_ret = curl_handler.Perform();
+  if (CURLE_OK != curl_ret) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, curl_easy_strerror(curl_ret));
+  }
+
+  auto output_tensor = std::make_unique<Tensor>(onnxruntime::DataTypeImpl::GetType<std::string>(), TensorShape{1}, allocator_);
+  if (!output_tensor) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor");
+  }
+
+  auto* output_string = output_tensor->MutableData<std::string>();
+  *output_string = string_buffer.ss_.str();
+  auto tensor_type = DataTypeImpl::GetType<Tensor>();
+
+  ort_outputs.clear();
+  ort_outputs.emplace_back(output_tensor.release(), tensor_type, tensor_type->GetDeleteFunc());
+  return Status::OK();
+}
```
```diff
+
+// AzureTritonInvoker
 class AzureTritonInvoker : public CloudEndPointInvoker {
  public:
   AzureTritonInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator);

@@ -287,6 +447,9 @@ Status CloudEndPointInvoker::CreateInvoker(const CloudEndPointConfig& config,
     if (iter->second == kAzureTriton) {
       invoker = std::make_unique<AzureTritonInvoker>(config, allocator);
       return status;
+    } else if (iter->second == kAzureOpenAI) {
+      invoker = std::make_unique<OpenAIInvoker>(config, allocator);
+      return status;
     }  // else other endpoint types ...
   }
   status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
```
> **Review comment:** Please note this library uses a non-standard license type that might need extra internal reviews.
>
> **Reply:** We are consuming curl from vcpkg, not from git directly. Would the vcpkg license type override the one seen here?