Skip to content

Commit

Permalink
Implement openAI endpoint invoker for nuget (#15797)
Browse files Browse the repository at this point in the history
Implement OpenAI audio endpoint, and enable NuGet packaging.

---------

Co-authored-by: Randy Shuai <[email protected]>
  • Loading branch information
2 people authored and Prathik Rao committed May 16, 2023
1 parent 1821f25 commit 898e440
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 25 deletions.
29 changes: 29 additions & 0 deletions ThirdPartyNotices.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5993,3 +5993,32 @@ https://github.com/tensorflow/tfjs
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

——

curl/curl

https://github.com/curl

COPYRIGHT AND PERMISSION NOTICE

Copyright (C) Daniel Stenberg, <[email protected]>, and many
contributors, see the THANKS file.

All rights reserved.

Permission to use, copy, modify, and distribute this software for any purpose
with or without fee is hereby granted, provided that the above copyright
notice and this permission notice appear in all copies.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. IN
NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.

Except as contained in this notice, the name of a copyright holder shall not
be used in advertising or otherwise to promote the sale, use or other dealings
in this Software without prior written authorization of the copyright holder.
10 changes: 10 additions & 0 deletions cgmanifests/generated/cgmanifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -448,5 +448,15 @@
"comments": "extensions"
}
},
{
"component": {
"type": "git",
"git": {
"commitHash": "b16d1fa8ee567b52c09a0f89940b07d8491b881d",
"repositoryUrl": "https://github.com/curl/curl.git"
},
"comments": "curl"
}
}
]
}
3 changes: 2 additions & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ pthread;https://sourceforge.net/projects/pthreads4w/files/pthreads4w-code-v3.0.0
triton;https://github.com/triton-inference-server/server/archive/refs/tags/v2.28.0.zip;4b305570aa1e889946e20e36050b6770e4108fee
# above are deps introduced by triton client, might remove after 1.14 release
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
eigen;https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip;ee201b07085203ea7bd8eb97cbcb31b07cfa3efb
eigen;https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip;ee201b07085203ea7bd8eb97cbcb31b07cfa3efb
curl;https://github.com/curl/curl/archive/refs/tags/curl-8_0_1.zip;b16d1fa8ee567b52c09a0f89940b07d8491b881d
12 changes: 5 additions & 7 deletions cmake/external/triton.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,11 @@ if (WIN32)
vcpkg_install(re2)
vcpkg_install(boost-interprocess)
vcpkg_install(boost-stacktrace)
vcpkg_install(zlib)
vcpkg_install(pthread)
vcpkg_install(b64)

add_dependencies(getb64 getpthread)
add_dependencies(getpthread getzlib)
add_dependencies(getzlib getboost-stacktrace)
add_dependencies(getpthread getboost-stacktrace)
add_dependencies(getboost-stacktrace getboost-interprocess)
add_dependencies(getboost-interprocess getre2)
add_dependencies(getre2 getrapidjson)
Expand All @@ -59,11 +57,11 @@ if (WIN32)

ExternalProject_Add(triton
GIT_REPOSITORY https://github.com/triton-inference-server/client.git
GIT_TAG r22.12
GIT_TAG r23.05
PREFIX triton
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${onnxruntime_target_platform}-windows -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON
CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${onnxruntime_target_platform}-windows -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
INSTALL_COMMAND ""
UPDATE_COMMAND "")

Expand All @@ -85,11 +83,11 @@ else()

ExternalProject_Add(triton
GIT_REPOSITORY https://github.com/triton-inference-server/client.git
GIT_TAG r22.12
GIT_TAG r23.05
PREFIX triton
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
INSTALL_COMMAND ""
UPDATE_COMMAND "")

Expand Down
6 changes: 3 additions & 3 deletions cmake/onnxruntime_framework.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@ onnxruntime_add_static_library(onnxruntime_framework ${onnxruntime_framework_src
if (onnxruntime_USE_AZURE)

add_dependencies(onnxruntime_framework triton)
target_include_directories(onnxruntime_framework PRIVATE ${TRITON_BIN}/include)
target_include_directories(onnxruntime_framework PRIVATE ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include)
link_directories(${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64)

if (WIN32)

link_directories(${VCPKG_SRC}/installed/${onnxruntime_target_platform}-windows/lib)
target_link_libraries(onnxruntime_framework PRIVATE libcurl httpclient_static ws2_32 crypt32 Wldap32 zlib)
target_link_libraries(onnxruntime_framework PRIVATE libcurl httpclient_static ws2_32 crypt32 Wldap32)

else()

find_package(ZLIB REQUIRED)
find_package(OpenSSL REQUIRED)
target_link_libraries(onnxruntime_framework PRIVATE httpclient_static curl ZLIB::ZLIB OpenSSL::Crypto OpenSSL::SSL)
target_link_libraries(onnxruntime_framework PRIVATE httpclient_static curl OpenSSL::Crypto OpenSSL::SSL)

endif() #if (WIN32)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
<!-- only set the .net6 targets if we're building an ORT package.
we can add .net6 support to other packages later as needed -->
<PropertyGroup Condition="('$(OrtPackageId)' == 'Microsoft.ML.OnnxRuntime' OR
'$(OrtPackageId)' == 'Microsoft.ML.OnnxRuntime.Azure' OR
'$(OrtPackageId)' == 'Microsoft.ML.OnnxRuntime.Gpu')">
<Net6Targets>net6.0;net6.0-android;net6.0-ios;net6.0-macos</Net6Targets>
</PropertyGroup>
Expand Down
4 changes: 2 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt
/// <param name="providerOptions">Optional key/value pairs to specify execution provider options.</param>
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN")
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
{
throw new NotSupportedException(
"Only QNN, SNPE and XNNPACK execution providers can be enabled by this method.");
"Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method.");
}

if (providerOptions == null)
Expand Down
3 changes: 2 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ public enum OrtProvider {
ARM_NN("ArmNNExecutionProvider"),
ROCM("ROCMExecutionProvider"),
CORE_ML("CoreMLExecutionProvider"),
XNNPACK("XnnpackExecutionProvider");
XNNPACK("XnnpackExecutionProvider"),
AZURE("AzureExecutionProvider");

private static final Map<String, OrtProvider> valueMap = new HashMap<>(values().length);

Expand Down
177 changes: 170 additions & 7 deletions onnxruntime/core/framework/cloud_invoker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// Licensed under the MIT License.

#ifdef USE_AZURE
#define CURL_STATICLIB
#include "http_client.h"
#include "curl/curl.h"
#include "core/common/common.h"
#include "core/framework/cloud_invoker.h"
#include "core/framework/ort_value.h"
Expand All @@ -18,13 +20,14 @@ namespace onnxruntime {

namespace tc = triton::client;

const char* kAzureUri = "azure.uri";
const char* kAzureModelName = "azure.model_name";
const char* kAzureModelVer = "azure.model_version";
const char* kAzureVerbose = "azure.verbose";
const char* kAzureEndpointType = "azure.endpoint_type";
const char* kAzureAuthKey = "azure.auth_key";
const char* kAzureTriton = "triton";
constexpr const char* kAzureUri = "azure.uri";
constexpr const char* kAzureModelName = "azure.model_name";
constexpr const char* kAzureModelVer = "azure.model_version";
constexpr const char* kAzureVerbose = "azure.verbose";
constexpr const char* kAzureEndpointType = "azure.endpoint_type";
constexpr const char* kAzureAuthKey = "azure.auth_key";
constexpr const char* kAzureTriton = "triton";
constexpr const char* kAzureOpenAI = "openai";

CloudEndPointInvoker::CloudEndPointInvoker(const CloudEndPointConfig& config,
const AllocatorPtr& allocator) : config_(config), allocator_(allocator) {
Expand All @@ -33,6 +36,163 @@ CloudEndPointInvoker::CloudEndPointInvoker(const CloudEndPointConfig& config,
}
}

// RAII singleton that performs libcurl global initialization exactly once per
// process and releases it at process exit. Call CurlGlobal::Init() before
// creating any curl easy handles.
class CurlGlobal {
public:
// Thread-safe lazy initialization: the function-local static guarantees the
// private constructor (and hence curl_global_init) runs exactly once, even
// under concurrent callers (C++11 "magic static").
static void Init() {
static CurlGlobal curl_global;
}

private:
CurlGlobal() {
// Thread-safety is a must since curl might also be initialized in triton client.
// curl_global_init is only safe to call concurrently when curl was built with
// thread-safe init support, so verify that feature bit first.
const auto* info = curl_version_info(CURLVERSION_NOW);
ORT_ENFORCE(info->features & CURL_VERSION_THREADSAFE, "curl global init not thread-safe, need to upgrade curl version!");
ORT_ENFORCE(curl_global_init(CURL_GLOBAL_DEFAULT) == CURLE_OK, "Failed to initialize curl global env!");
}
// Runs when the static instance is destroyed at process exit.
~CurlGlobal() {
curl_global_cleanup();
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CurlGlobal);
};

// OpenAIInvoker
// Cloud endpoint invoker that uploads the first input tensor's raw bytes as a
// multipart/form-data file to an OpenAI-style HTTP endpoint and returns the
// response body as a single string tensor.
class OpenAIInvoker : public CloudEndPointInvoker {
public:
OpenAIInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator);
// Performs one synchronous HTTP request. input_names/output_names are unused
// by this endpoint type; ort_outputs is cleared and repopulated on success.
onnxruntime::Status Send(const CloudEndPointConfig& run_options,
const InlinedVector<std::string>& input_names,
gsl::span<const OrtValue> ort_inputs,
const InlinedVector<std::string>& output_names,
std::vector<OrtValue>& ort_outputs) const override;

private:
std::string uri_;  // endpoint URL, read from the "azure.uri" config entry
std::string model_name_;  // model identifier, read from "azure.model_name"
};

// Reads the required endpoint configuration and makes sure curl's global state
// is initialized before any request can be issued.
OpenAIInvoker::OpenAIInvoker(const CloudEndPointConfig& config,
const AllocatorPtr& allocator) : CloudEndPointInvoker(config, allocator) {
CurlGlobal::Init();
// ReadConfig is expected to populate the member (or record a failure status)
// from the config map captured by the base class.
ReadConfig(kAzureUri, uri_);
ReadConfig(kAzureModelName, model_name_);
}

// Accumulates the HTTP response body streamed in by the curl write callback.
// Non-copyable/non-movable because a pointer to it is handed to libcurl for
// the duration of the transfer.
struct StringBuffer {
StringBuffer() = default;
~StringBuffer() = default;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(StringBuffer);
std::stringstream ss_;
};

// apply the callback only when response is for sure to be a '/0' terminated string
static size_t WriteStringCallback(void* contents, size_t size, size_t nmemb, void* userp) {
try {
size_t realsize = size * nmemb;
auto buffer = reinterpret_cast<struct StringBuffer*>(userp);
buffer->ss_.write(reinterpret_cast<const char*>(contents), realsize);
return realsize;
} catch (...) {
// exception caught, abort write
return CURLcode::CURLE_WRITE_ERROR;
}
}

using CurlWriteCallBack = size_t (*)(void*, size_t, size_t, void*);

class CurlHandler {
public:
CurlHandler(CurlWriteCallBack call_back) : curl_(curl_easy_init(), curl_easy_cleanup),
headers_(nullptr, curl_slist_free_all),
from_holder_(from_, curl_formfree) {
curl_easy_setopt(curl_.get(), CURLOPT_BUFFERSIZE, 102400L);
curl_easy_setopt(curl_.get(), CURLOPT_NOPROGRESS, 1L);
curl_easy_setopt(curl_.get(), CURLOPT_USERAGENT, "curl/7.83.1");
curl_easy_setopt(curl_.get(), CURLOPT_MAXREDIRS, 50L);
curl_easy_setopt(curl_.get(), CURLOPT_FTP_SKIP_PASV_IP, 1L);
curl_easy_setopt(curl_.get(), CURLOPT_TCP_KEEPALIVE, 1L);
curl_easy_setopt(curl_.get(), CURLOPT_WRITEFUNCTION, call_back);
}
~CurlHandler() = default;

void AddHeader(const char* data) {
headers_.reset(curl_slist_append(headers_.release(), data));
}
template <typename... Args>
void AddForm(Args... args) {
curl_formadd(&from_, &last_, args...);
}
template <typename T>
void SetOption(CURLoption opt, T val) {
curl_easy_setopt(curl_.get(), opt, val);
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CurlHandler);
CURLcode Perform() {
SetOption(CURLOPT_HTTPHEADER, headers_.get());
SetOption(CURLOPT_HTTPPOST, from_);
return curl_easy_perform(curl_.get());
}

private:
std::unique_ptr<CURL, decltype(curl_easy_cleanup)*> curl_;
std::unique_ptr<curl_slist, decltype(curl_slist_free_all)*> headers_;
curl_httppost* from_{};
curl_httppost* last_{};
std::unique_ptr<curl_httppost, decltype(curl_formfree)*> from_holder_;
};

// Uploads ort_inputs[0]'s raw bytes as a multipart file part ("file"), along
// with the configured model name, to the OpenAI-style endpoint at uri_, and
// returns the response body as a single string tensor in ort_outputs.
// Requires "azure.auth_key" in run_options; "azure.verbose" != "0" enables
// curl's verbose logging.
onnxruntime::Status OpenAIInvoker::Send(const CloudEndPointConfig& run_options,
                                        const InlinedVector<std::string>& /*input_names*/,
                                        gsl::span<const OrtValue> ort_inputs,
                                        const InlinedVector<std::string>& /*output_names*/,
                                        std::vector<OrtValue>& ort_outputs) const {
  // Guard before ort_inputs[0] below — an empty span would be UB.
  if (ort_inputs.empty()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "openai invoker requires at least one input tensor");
  }
  const auto auth_key_iter = run_options.find(kAzureAuthKey);
  if (run_options.end() == auth_key_iter || auth_key_iter->second.empty()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "auth key must be specified for openai client");
  }
  long verbose = 0;
  const auto verbose_iter = run_options.find(kAzureVerbose);
  if (run_options.end() != verbose_iter) {
    verbose = verbose_iter->second != "0" ? 1L : 0L;
  }

  CurlHandler curl_handler(WriteStringCallback);
  StringBuffer string_buffer;

  std::string full_auth = std::string{"Authorization: Bearer "} + auth_key_iter->second;
  curl_handler.AddHeader(full_auth.c_str());
  curl_handler.AddHeader("Content-Type: multipart/form-data");

  const auto& tensor = ort_inputs[0].Get<Tensor>();
  auto data_size = tensor.SizeInBytes();
  curl_handler.AddForm(CURLFORM_COPYNAME, "model", CURLFORM_COPYCONTENTS, model_name_.c_str(), CURLFORM_END);
  curl_handler.AddForm(CURLFORM_COPYNAME, "response_format", CURLFORM_COPYCONTENTS, "text", CURLFORM_END);
  // The endpoint keys off the part's file name, not its contents, so any name
  // with the right extension works ("non_exist.wav" is a placeholder).
  // BUG FIX: CURLFORM_BUFFERLENGTH is read by curl_formadd as a `long` vararg;
  // passing a size_t directly is undefined on LLP64 targets (Windows), where
  // long is 4 bytes but size_t is 8 — cast explicitly.
  curl_handler.AddForm(CURLFORM_COPYNAME, "file", CURLFORM_BUFFER, "non_exist.wav", CURLFORM_BUFFERPTR, tensor.DataRaw(),
                       CURLFORM_BUFFERLENGTH, static_cast<long>(data_size), CURLFORM_END);

  curl_handler.SetOption(CURLOPT_URL, uri_.c_str());
  curl_handler.SetOption(CURLOPT_VERBOSE, verbose);
  curl_handler.SetOption(CURLOPT_WRITEDATA, static_cast<void*>(&string_buffer));

  auto curl_ret = curl_handler.Perform();
  if (CURLE_OK != curl_ret) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, curl_easy_strerror(curl_ret));
  }

  // make_unique throws on allocation failure rather than returning null, so no
  // null check is needed here.
  auto output_tensor = std::make_unique<Tensor>(onnxruntime::DataTypeImpl::GetType<std::string>(), TensorShape{1}, allocator_);

  auto* output_string = output_tensor->MutableData<std::string>();
  *output_string = string_buffer.ss_.str();
  auto tensor_type = DataTypeImpl::GetType<Tensor>();
  ort_outputs.clear();
  ort_outputs.emplace_back(output_tensor.release(), tensor_type, tensor_type->GetDeleteFunc());
  return Status::OK();
}

// AzureTritonInvoker
class AzureTritonInvoker : public CloudEndPointInvoker {
public:
AzureTritonInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator);
Expand Down Expand Up @@ -287,6 +447,9 @@ Status CloudEndPointInvoker::CreateInvoker(const CloudEndPointConfig& config,
if (iter->second == kAzureTriton) {
invoker = std::make_unique<AzureTritonInvoker>(config, allocator);
return status;
} else if (iter->second == kAzureOpenAI) {
invoker = std::make_unique<OpenAIInvoker>(config, allocator);
return status;
} // else other endpoint types ...
}
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ parameters:
type: string
default: '0'

- name: NugetPackageSuffix
displayName: Suffix to append to nuget package
type: string
default: ''

- name: AdditionalBuildFlag
displayName: Build flags to append to build command
type: string
default: ''

resources:
repositories:
- repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step
Expand Down Expand Up @@ -100,9 +110,9 @@ stages:
DoCompliance: ${{ parameters.DoCompliance }}
DoEsrp: ${{ parameters.DoEsrp }}
IsReleaseBuild: ${{ parameters.IsReleaseBuild }}
OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime'
OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime${{ parameters.NugetPackageSuffix }}'
AdditionalBuildFlags: ''
AdditionalWinBuildFlags: '--enable_onnx_tests --enable_wcos'
AdditionalWinBuildFlags: '--enable_onnx_tests --enable_wcos ${{parameters.AdditionalBuildFlag}}'
BuildVariant: 'default'
SpecificArtifact: ${{ parameters.SpecificArtifact }}
BuildId: ${{ parameters.BuildId }}
Expand Down
Loading

0 comments on commit 898e440

Please sign in to comment.