
Implement openAI endpoint invoker for nuget #15797

Merged: 25 commits into main from rashuai/AzureEP_1_15_nuget, May 12, 2023

Changes shown below are from the first 7 of the 25 commits.

Commits
a1feb54
implement openAi endpoint in nuget
RandyShuai May 3, 2023
dfc5461
make memory allocation RAII
RandyShuai May 4, 2023
8aaac2f
add curl to deps
RandyShuai May 4, 2023
6cdae88
Merge branch 'main' into rashuai/AzureEP_1_15_nuget
RandyShuai May 4, 2023
5c7c67c
wrap up curl handles
RandyShuai May 4, 2023
842f7fd
set curl dep commit id
RandyShuai May 4, 2023
1fec44a
address comments
RandyShuai May 4, 2023
a6f18e3
write sized data to stream
RandyShuai May 5, 2023
9c6bb2e
attach curl license
RandyShuai May 5, 2023
7e14fd6
merge main
RandyShuai May 9, 2023
c0e72cc
merge main
RandyShuai May 9, 2023
e558c32
remove zlib from cmake
RandyShuai May 9, 2023
6b88835
update 3rd party notice
RandyShuai May 10, 2023
dc7ae1f
merge main
RandyShuai May 10, 2023
94279c1
stick to triton client r23.05
RandyShuai May 10, 2023
2973579
Merge branch 'main' of https://github.com/microsoft/onnxruntime
RandyShuai May 10, 2023
952eb52
Merge branch 'main' into rashuai/AzureEP_1_15_nuget
RandyShuai May 10, 2023
ae7d534
fix comments
RandyShuai May 11, 2023
7e51a9c
fix header issue
RandyShuai May 11, 2023
fdad7c7
format code
RandyShuai May 11, 2023
c013693
Merge branch 'main' of https://github.com/microsoft/onnxruntime
RandyShuai May 11, 2023
e0dcb22
Merge branch 'main' into rashuai/AzureEP_1_15_nuget
RandyShuai May 11, 2023
866db7e
clear fetch before emplace
RandyShuai May 11, 2023
d099bd6
Merge branch 'main' of https://github.com/microsoft/onnxruntime
RandyShuai May 11, 2023
17ae944
resolve conflict
RandyShuai May 11, 2023
Files changed
2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -47,5 +47,5 @@ boost;https://github.com/boostorg/boost/archive/refs/tags/boost-1.81.0.zip;f6ab0
b64;https://github.com/libb64/libb64/archive/refs/tags/v2.0.0.1.zip;815b6d31d50d9e63df55b25ce555e7b787153c28
pthread;https://sourceforge.net/projects/pthreads4w/files/pthreads4w-code-v3.0.0.zip;3b9e417e4474c34542b76ad40529e396ac109fb4
triton;https://github.com/triton-inference-server/server/archive/refs/tags/v2.28.0.zip;4b305570aa1e889946e20e36050b6770e4108fee
# above are deps introduced by triton client, might remove after 1.14 release
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/81e7799c69044c745239202085eb0a98f102937b.zip;d53487035174a046628359289ad27aa0ac0380c9
curl;https://github.com/curl/curl/archive/refs/tags/curl-8_0_1.zip;b16d1fa8ee567b52c09a0f89940b07d8491b881d
2 changes: 1 addition & 1 deletion cmake/onnxruntime_framework.cmake
@@ -40,7 +40,7 @@ onnxruntime_add_static_library(onnxruntime_framework ${onnxruntime_framework_src
if (onnxruntime_USE_AZURE)

add_dependencies(onnxruntime_framework triton)
target_include_directories(onnxruntime_framework PRIVATE ${TRITON_BIN}/include)
target_include_directories(onnxruntime_framework PRIVATE ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include)
link_directories(${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64)

if (WIN32)
@@ -41,6 +41,7 @@
<!-- only set the .net6 targets if we're building an ORT package.
we can add .net6 support to other packages later as needed -->
<PropertyGroup Condition="('$(OrtPackageId)' == 'Microsoft.ML.OnnxRuntime' OR
'$(OrtPackageId)' == 'Microsoft.ML.OnnxRuntime.Azure' OR
'$(OrtPackageId)' == 'Microsoft.ML.OnnxRuntime.Gpu')">
<Net6Targets>net6.0;net6.0-android;net6.0-ios;net6.0-macos</Net6Targets>
</PropertyGroup>
4 changes: 2 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -375,10 +375,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt
/// <param name="providerOptions">Optional key/value pairs to specify execution provider options.</param>
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN")
if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
{
throw new NotSupportedException(
"Only QNN, SNPE and XNNPACK execution providers can be enabled by this method.");
"Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method.");
}

if (providerOptions == null)
3 changes: 2 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtProvider.java
@@ -23,7 +23,8 @@ public enum OrtProvider {
ARM_NN("ArmNNExecutionProvider"),
ROCM("ROCMExecutionProvider"),
CORE_ML("CoreMLExecutionProvider"),
XNNPACK("XnnpackExecutionProvider");
XNNPACK("XnnpackExecutionProvider"),
AZURE("AzureExecutionProvider");

private static final Map<String, OrtProvider> valueMap = new HashMap<>(values().length);

189 changes: 182 additions & 7 deletions onnxruntime/core/framework/cloud_invoker.cc
@@ -2,7 +2,9 @@
// Licensed under the MIT License.

#ifdef USE_AZURE
#define CURL_STATICLIB
#include "http_client.h"
#include "curl/curl.h"
#include "core/common/common.h"
#include "core/framework/cloud_invoker.h"
#include "core/framework/ort_value.h"
@@ -18,13 +20,14 @@ namespace onnxruntime {

namespace tc = triton::client;

const char* kAzureUri = "azure.uri";
const char* kAzureModelName = "azure.model_name";
const char* kAzureModelVer = "azure.model_version";
const char* kAzureVerbose = "azure.verbose";
const char* kAzureEndpointType = "azure.endpoint_type";
const char* kAzureAuthKey = "azure.auth_key";
const char* kAzureTriton = "triton";
constexpr const char* kAzureUri = "azure.uri";
constexpr const char* kAzureModelName = "azure.model_name";
constexpr const char* kAzureModelVer = "azure.model_version";
constexpr const char* kAzureVerbose = "azure.verbose";
constexpr const char* kAzureEndpointType = "azure.endpoint_type";
constexpr const char* kAzureAuthKey = "azure.auth_key";
constexpr const char* kAzureTriton = "triton";
constexpr const char* kAzureOpenAI = "openai";
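// Illustration added for this write-up, not part of the diff: with
// kAzureEndpointType set to kAzureOpenAI, a config along these lines would
// select the OpenAI invoker defined below. The URI and model name are
// hypothetical values for OpenAI's audio-transcription endpoint, and this
// assumes CloudEndPointConfig maps string keys to string values:
//
//   CloudEndPointConfig config{
//       {kAzureEndpointType, kAzureOpenAI},
//       {kAzureUri, "https://api.openai.com/v1/audio/transcriptions"},
//       {kAzureModelName, "whisper-1"},
//   };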

CloudEndPointInvoker::CloudEndPointInvoker(const CloudEndPointConfig& config,
const AllocatorPtr& allocator) : config_(config), allocator_(allocator) {
@@ -33,6 +36,175 @@ CloudEndPointInvoker::CloudEndPointInvoker(const CloudEndPointConfig& config,
}
}

class CurlGlobal {
public:
static void Init() {
static CurlGlobal curl_global;
}

private:
CurlGlobal() {
// Thread-safety is a must since curl might also be initialized in triton client.
const auto* info = curl_version_info(CURLVERSION_NOW);
ORT_ENFORCE(info->features & CURL_VERSION_THREADSAFE, "curl global init not thread-safe, need to upgrade curl version!");
ORT_ENFORCE(curl_global_init(CURL_GLOBAL_DEFAULT) == CURLE_OK, "Failed to initialize curl global env!");
@snnn (Member, May 10, 2023):
Even though curl_global_init is thread-safe, we still should not call it from any global variable's constructor, since at that point the static variables in libcurl itself might not have been initialized yet. So I think the best way to ensure this is to put the call in OrtEnv's constructor.

@RandySheriffH (Author, May 11, 2023):
I don't follow: how could curl's statics be uninitialized when OpenAIInvoker::OpenAIInvoker(...) runs, yet be safe from OrtEnv's constructor?

@snnn (Member):
Talked with Randy offline. The most difficult part is that this file also uses the Triton client library, which has similar code that initializes and deinitializes curl. I don't see an easy way to coordinate the two. Will leave the discussion for later.

}
~CurlGlobal() {
curl_global_cleanup();
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CurlGlobal);
};
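A minimal sketch of the alternative snnn raises above: tie curl's global init and cleanup to a long-lived environment object instead of a function-local static. OrtEnvLike is a hypothetical stand-in for OrtEnv, not code from this PR.

struct OrtEnvLike {
  OrtEnvLike() {
    // Constructed explicitly by the API user after main() starts, so libcurl's
    // own statics are already initialized by the time this runs.
    ORT_ENFORCE(curl_global_init(CURL_GLOBAL_DEFAULT) == CURLE_OK,
                "Failed to initialize curl global env!");
  }
  ~OrtEnvLike() {
    curl_global_cleanup();
  }
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OrtEnvLike);  // hypothetical, mirrors CurlGlobal above
};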

// OpenAIInvoker
class OpenAIInvoker : public CloudEndPointInvoker {
public:
OpenAIInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator);
onnxruntime::Status Send(const CloudEndPointConfig& run_options,
const InlinedVector<std::string>& input_names,
gsl::span<const OrtValue> ort_inputs,
const InlinedVector<std::string>& output_names,
std::vector<OrtValue>& ort_outputs) const override;

private:
std::string uri_;
std::string model_name_;
};

OpenAIInvoker::OpenAIInvoker(const CloudEndPointConfig& config,
const AllocatorPtr& allocator) : CloudEndPointInvoker(config, allocator) {
CurlGlobal::Init();
ReadConfig(kAzureUri, uri_);
ReadConfig(kAzureModelName, model_name_);
}

struct StringBuffer {
StringBuffer() = default;
~StringBuffer() = default;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(StringBuffer);
std::stringstream ss_;
@snnn (Member):
Here you need a std::string instead of a stringstream.

@RandySheriffH (Author):
A stringstream buffers extra space for incoming content, which is more suitable for this scenario.

@snnn (Member):
Doesn't std::string do the same thing? Every string has a size and a capacity: capacity is how much memory has been allocated for the string, and size is how much of it has been used.

@RandySheriffH (Author, May 11, 2023):
stringstream is designed for buffering content that grows dynamically.

};

// applies only when contents is a string
static size_t WriteStringCallback(void* contents, size_t size, size_t nmemb, void* userp) {
size_t realsize = size * nmemb;
@snnn (Member):
C-style callbacks should not throw exceptions, since C code cannot handle C++ exceptions. All such callbacks need to be marked as non-throwing.

@RandySheriffH (Author, May 11, 2023):
If a C++ exception were thrown from the callback, wouldn't ORT be the catcher? It's certainly none of curl's business. Also, if an exception really did occur in the write callback, wouldn't it be most appropriate to propagate it to the upper logic, since the response is totally busted?

@RandySheriffH (Author):
Done. It now returns CURLE_WRITE_ERROR on exception.

auto buffer = reinterpret_cast<struct StringBuffer*>(userp);
buffer->ss_ << reinterpret_cast<const char*>(contents);
return realsize;
}
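A sketch of the direction the thread above (and the later commit "write sized data to stream") points to, assuming only standard iostream and libcurl semantics: write exactly size * nmemb bytes instead of relying on null termination, and convert any exception into a short return count, which curl reports to the caller as CURLE_WRITE_ERROR.

static size_t WriteStringCallbackSafe(void* contents, size_t size, size_t nmemb, void* userp) noexcept {
  const size_t realsize = size * nmemb;
  try {
    auto* buffer = static_cast<StringBuffer*>(userp);
    // The payload is not null-terminated; copy exactly realsize bytes.
    buffer->ss_.write(static_cast<const char*>(contents), realsize);
    return realsize;
  } catch (...) {
    return 0;  // any count other than realsize makes curl fail with CURLE_WRITE_ERROR
  }
}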

using CurlWriteCallBack = size_t (*)(void*, size_t, size_t, void*);

class CurlHandler {
public:
CurlHandler(CurlWriteCallBack call_back) {
curl_ = curl_easy_init();
curl_easy_setopt(curl_, CURLOPT_BUFFERSIZE, 102400L);
curl_easy_setopt(curl_, CURLOPT_NOPROGRESS, 1L);
curl_easy_setopt(curl_, CURLOPT_USERAGENT, "curl/7.83.1");
curl_easy_setopt(curl_, CURLOPT_MAXREDIRS, 50L);
curl_easy_setopt(curl_, CURLOPT_FTP_SKIP_PASV_IP, 1L);
curl_easy_setopt(curl_, CURLOPT_TCP_KEEPALIVE, 1L);
curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION, call_back);
}
~CurlHandler() {
if (curl_) {
curl_easy_cleanup(curl_);
curl_ = {};
}
if (mime1_) {
curl_mime_free(mime1_);
mime1_ = {};
}
if (headers_) {
curl_slist_free_all(headers_);
headers_ = {};
}
if (from_) {
curl_formfree(from_);
from_ = {};
}
}
void AddHeader(const char* data) {
headers_ = curl_slist_append(headers_, data);
}
template <typename... Args>
void AddForm(Args... args) {
curl_formadd(&from_, &last_, args...);
}
template <typename T>
void SetOption(CURLoption opt, T val) {
curl_easy_setopt(curl_, opt, val);
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CurlHandler);
CURLcode Perform() {
SetOption(CURLOPT_HTTPHEADER, headers_);
SetOption(CURLOPT_HTTPPOST, from_);
return curl_easy_perform(curl_);
}

private:
CURL* curl_{};
@snnn (Member):
You may change it to a std::unique_ptr. See https://stackoverflow.com/questions/27440953/stdunique-ptr-for-c-functions-that-need-free for an example; instead of using std::free to free the memory, in your case you would use curl_easy_cleanup.

@RandySheriffH (Author, May 10, 2023):
That's fine, but we still have to call curl_easy_cleanup in the destructor either way.

@snnn (Member):
std::unique_ptr is safer than raw pointers, so your current code is not wrong, but using std::unique_ptr could make the code easier to read and verify. Just a suggestion; you don't have to take it.

@RandySheriffH (Author):
Done.


curl_mime* mime1_{};
struct curl_slist* headers_{};
struct curl_httppost* from_{};
struct curl_httppost* last_{};
};
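For reference, a sketch of the std::unique_ptr pattern suggested in the thread above, using only libcurl's own cleanup functions as deleters; aliases like these would replace the manual if/cleanup pairs in ~CurlHandler().

using CurlEasyPtr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
using CurlSlistPtr = std::unique_ptr<curl_slist, decltype(&curl_slist_free_all)>;
using CurlFormPtr = std::unique_ptr<curl_httppost, decltype(&curl_formfree)>;

// e.g. as a member: CurlEasyPtr curl_{curl_easy_init(), curl_easy_cleanup};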

onnxruntime::Status OpenAIInvoker::Send(const CloudEndPointConfig& run_options,
const InlinedVector<std::string>& /*input_names*/,
gsl::span<const OrtValue> ort_inputs,
const InlinedVector<std::string>& /*output_names*/,
std::vector<OrtValue>& ort_outputs) const {
const auto auth_key_iter = run_options.find(kAzureAuthKey);
if (run_options.end() == auth_key_iter || auth_key_iter->second.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"auth key must be specified for openai client");
}
long verbose = 0;
const auto verbose_iter = run_options.find(kAzureVerbose);
if (run_options.end() != verbose_iter) {
verbose = verbose_iter->second != "0" ? 1L : 0L;
}

CURLcode ret{};
CurlHandler curl_handler(WriteStringCallback);
StringBuffer string_buffer;

std::string full_auth = std::string{"Authorization: Bearer "} + auth_key_iter->second;
curl_handler.AddHeader(full_auth.c_str());
curl_handler.AddHeader("Content-Type: multipart/form-data");

const auto& tensor = ort_inputs[0].Get<Tensor>();
auto data_size = tensor.SizeInBytes();
curl_handler.AddForm(CURLFORM_COPYNAME, "model", CURLFORM_COPYCONTENTS, model_name_.c_str(), CURLFORM_END);
curl_handler.AddForm(CURLFORM_COPYNAME, "response_format", CURLFORM_COPYCONTENTS, "text", CURLFORM_END);
curl_handler.AddForm(CURLFORM_COPYNAME, "file", CURLFORM_BUFFER, "non_exist.wav", CURLFORM_BUFFERPTR, tensor.DataRaw(),
CURLFORM_BUFFERLENGTH, data_size, CURLFORM_END);

curl_handler.SetOption(CURLOPT_URL, uri_.c_str());
curl_handler.SetOption(CURLOPT_VERBOSE, verbose);
curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);

ret = curl_handler.Perform();
if (ret != CURLE_OK) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, curl_easy_strerror(ret));
}

auto output_tensor = std::make_unique<Tensor>(onnxruntime::DataTypeImpl::GetType<std::string>(), TensorShape{1}, allocator_);
if (!output_tensor) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor");
}

auto* output_string = output_tensor->MutableData<std::string>();
*output_string = string_buffer.ss_.str();

ort_outputs.resize(1);
@snnn (Member):
Can we just emplace_back output_tensor into the vector?

@RandySheriffH (Author):
Done.

auto tensor_type = DataTypeImpl::GetType<Tensor>();
ort_outputs[0].Init(output_tensor.release(), tensor_type, tensor_type->GetDeleteFunc());
return Status::OK();
}
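A sketch of the emplace_back variant discussed above, reusing the Init call from the diff; the later commit "clear fetch before emplace" suggests the merged code clears the vector first. Illustrative, not the final merged code.

ort_outputs.clear();
auto tensor_type = DataTypeImpl::GetType<Tensor>();
OrtValue ort_value;
ort_value.Init(output_tensor.release(), tensor_type, tensor_type->GetDeleteFunc());
ort_outputs.emplace_back(std::move(ort_value));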

// AzureTritonInvoker
class AzureTritonInvoker : public CloudEndPointInvoker {
public:
AzureTritonInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator);
@@ -287,6 +459,9 @@ Status CloudEndPointInvoker::CreateInvoker(const CloudEndPointConfig& config,
if (iter->second == kAzureTriton) {
invoker = std::make_unique<AzureTritonInvoker>(config, allocator);
return status;
} else if (iter->second == kAzureOpenAI) {
invoker = std::make_unique<OpenAIInvoker>(config, allocator);
return status;
} // else other endpoint types ...
}
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
@@ -24,6 +24,16 @@ parameters:
type: boolean
default: false

- name: NugetPackageSuffix
displayName: Suffix to append to nuget package
type: string
default: ''

- name: AdditionalBuildFlag
displayName: Build flags to append to build command
type: string
default: ''

resources:
repositories:
- repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step
@@ -43,9 +53,9 @@ stages:
DoCompliance: ${{ parameters.DoCompliance }}
DoEsrp: ${{ parameters.DoEsrp }}
IsReleaseBuild: ${{ parameters.IsReleaseBuild }}
OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime'
OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime${{ parameters.NugetPackageSuffix }}'
AdditionalBuildFlags: ''
AdditionalWinBuildFlags: '--enable_onnx_tests --enable_wcos'
AdditionalWinBuildFlags: '--enable_onnx_tests --enable_wcos ${{parameters.AdditionalBuildFlag}}'
BuildVariant: 'default'

- template: templates/ondevice-training-cpu-packaging-pipeline.yml
23 changes: 22 additions & 1 deletion tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
@@ -190,4 +190,25 @@
ORT_EP_NAME: CPU
GenerateDocumentation: false
WITH_CACHE: true
MachinePool: 'onnxruntime-Win2019-CPU-training-AMD'
MachinePool: 'onnxruntime-Win2019-CPU-training-AMD'

- stage: x64_release_azure
dependsOn: []
jobs:
- template: templates/win-ci-vs-2019.yml
parameters:
BuildConfig: 'RelWithDebInfo'
EnvSetupScript: setup_env_azure.bat
buildArch: x64
additionalBuildFlags: --use_azure
msbuildPlatform: x64
isX86: false
job_name_suffix: x64_release_azure
RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }}
RunStaticCodeAnalysis: false
EnablePython: false
isTraining: false
ORT_EP_NAME: CPU
GenerateDocumentation: false
WITH_CACHE: true
MachinePool: 'onnxruntime-Win-CPU-2019'