Skip to content

Commit

Permalink
feat: add support for custom OpenAI base URL
Browse files Browse the repository at this point in the history
Co-authored-by: Sam McLeod <[email protected]>

Close #310
  • Loading branch information
leeebo committed Nov 20, 2023
1 parent f27c58e commit a62dccb
Show file tree
Hide file tree
Showing 13 changed files with 68 additions and 43 deletions.
7 changes: 7 additions & 0 deletions components/openai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# ChangeLog

## v0.2.0 - 2023-11-15

### Enhancements

* Support config the OpenAI Base URL through API `OpenAIChangeBaseURL` or `menuconfig`. Thanks [@Sam McLeod](https://github.com/sammcj)
* Using MP3 format audio file in unit test

## v0.1.2 - 2023-8-4

### Docs
Expand Down
8 changes: 7 additions & 1 deletion components/openai/Kconfig
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
menu "OpenAI"

config DEFAULT_OPENAI_BASE_URL
string "Default Base URL"
default "https://api.openai.com/v1/"
help
Default Base URL for OpenAI API

config ENABLE_EMBEDDING
bool "Enable Embedding"
default y
help
Enable OpenAI Embedding

config ENABLE_MODERATION
bool "Enable Moderation"
default y
Expand Down
66 changes: 38 additions & 28 deletions components/openai/OpenAI.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

static const char *TAG = "OpenAI";

#define OPENAI_DEFAULT_BASE_URL CONFIG_DEFAULT_OPENAI_BASE_URL

#define OPENAI_ERROR_CHECK(a, str, ret) if(!(a)) { \
ESP_LOGE(TAG,"%s:%d (%s):%s", __FILE__, __LINE__, __FUNCTION__, str); \
return (ret); \
Expand Down Expand Up @@ -121,11 +123,12 @@ char *getJsonError(cJSON *json)
typedef struct {
OpenAI_t parent; /*!< Parent object */
char *api_key; /*!< API key for OpenAI */
char *base_url; /*!< Base URL for OpenAI or Other compatible API */

char *(*get)(const char *api_key, const char *endpoint); /*!< Perform an HTTP GET request. */
char *(*del)(const char *api_key, const char *endpoint); /*!< Perform an HTTP DELETE request. */
char *(*post)(const char *api_key, const char *endpoint, char *jsonBody); /*!< Perform an HTTP POST request. */
char *(*upload)(const char *api_key, const char *endpoint, const char *boundary, uint8_t *data, size_t len); /*!< Upload data using an HTTP request. */
char *(*get)(const char *base_url, const char *api_key, const char *endpoint); /*!< Perform an HTTP GET request. */
char *(*del)(const char *base_url, const char *api_key, const char *endpoint); /*!< Perform an HTTP DELETE request. */
char *(*post)(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody); /*!< Perform an HTTP POST request. */
char *(*upload)(const char *base_url, const char *api_key, const char *endpoint, const char *boundary, uint8_t *data, size_t len); /*!< Upload data using an HTTP request. */
} _OpenAI_t;

//
Expand Down Expand Up @@ -867,7 +870,7 @@ static OpenAI_StringResponse_t *OpenAI_CompletionPrompt(OpenAI_Completion_t *com
}
char *jsonBody = cJSON_Print(req);
cJSON_Delete(req);
char *res = _completion->oai->post(_completion->oai->api_key, endpoint, jsonBody);
char *res = _completion->oai->post(_completion->oai->base_url, _completion->oai->api_key, endpoint, jsonBody);
free(jsonBody);
OPENAI_ERROR_CHECK(res != NULL, "OpenAI API call failed", NULL);

Expand Down Expand Up @@ -1143,7 +1146,7 @@ OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *c
}
char *jsonBody = cJSON_Print(req);
cJSON_Delete(req);
char *res = _chatCompletion->oai->post(_chatCompletion->oai->api_key, endpoint, jsonBody);
char *res = _chatCompletion->oai->post(_chatCompletion->oai->base_url, _chatCompletion->oai->api_key, endpoint, jsonBody);
free(jsonBody);
OPENAI_ERROR_CHECK(res != NULL, "Empty result!", result);
if (save) {
Expand Down Expand Up @@ -1281,7 +1284,7 @@ OpenAI_StringResponse_t *OpenAI_EditProcess(OpenAI_Edit_t *edit, char *instructi
}
char *jsonBody = cJSON_Print(req);
cJSON_Delete(req);
char *res = _edit->oai->post(_edit->oai->api_key, endpoint, jsonBody);
char *res = _edit->oai->post(_edit->oai->base_url, _edit->oai->api_key, endpoint, jsonBody);
free(jsonBody);
OPENAI_ERROR_CHECK(res != NULL, "Empty result!", result);
return OpenAI_StringResponseCreate(res);
Expand Down Expand Up @@ -1403,7 +1406,7 @@ OpenAI_ImageResponse_t *OpenAI_ImageGenerationPrompt(OpenAI_ImageGeneration_t *i
}
char *jsonBody = cJSON_Print(req);
cJSON_Delete(req);
char *res = _imageGeneration->oai->post(_imageGeneration->oai->api_key, endpoint, jsonBody);
char *res = _imageGeneration->oai->post(_imageGeneration->oai->base_url, _imageGeneration->oai->api_key, endpoint, jsonBody);
free(jsonBody);
OPENAI_ERROR_CHECK(res != NULL, "Empty result!", result);
return OpenAI_ImageResponseCreate(res);
Expand Down Expand Up @@ -1544,7 +1547,7 @@ static OpenAI_ImageResponse_t *OpenAI_ImageVariationImage(OpenAI_ImageVariation_
free(reqBody);
free(reqEndBody);
free(itemPrefix);
char *res = _imageVariation->oai->upload(_imageVariation->oai->api_key, endpoint, boundary, data, len);
char *res = _imageVariation->oai->upload(_imageVariation->oai->base_url, _imageVariation->oai->api_key, endpoint, boundary, data, len);
free(data);
OPENAI_ERROR_CHECK(res != NULL, "Empty result!", NULL);
return OpenAI_ImageResponseCreate(res);
Expand Down Expand Up @@ -1722,7 +1725,7 @@ static OpenAI_ImageResponse_t *OpenAI_ImageEditImage(OpenAI_ImageEdit_t *imageEd
if (maskBody != NULL) {
free(maskBody);
}
char *res = _imageEdit->oai->upload(_imageEdit->oai->api_key, endpoint, boundary, data, len);
char *res = _imageEdit->oai->upload(_imageEdit->oai->base_url, _imageEdit->oai->api_key, endpoint, boundary, data, len);
free(data);
OPENAI_ERROR_CHECK(res != NULL, "Empty result!", NULL);
return OpenAI_ImageResponseCreate(res);
Expand Down Expand Up @@ -1898,7 +1901,7 @@ static char *OpenAI_AudioTranscriptionFile(OpenAI_AudioTranscription_t *audioTra
free(reqEndBody);
free(itemPrefix);

char *result = _audioTranscription->oai->upload(_audioTranscription->oai->api_key, endpoint, boundary, data, len);
char *result = _audioTranscription->oai->upload(_audioTranscription->oai->base_url, _audioTranscription->oai->api_key, endpoint, boundary, data, len);
free(data);
OPENAI_ERROR_CHECK(result != NULL, "Empty result!", NULL);
cJSON *json = cJSON_Parse(result);
Expand Down Expand Up @@ -2039,7 +2042,7 @@ static char *OpenAI_AudioTranslationFile(OpenAI_AudioTranslation_t *audioTransla
free(itemPrefix);
free(reqBody);
free(reqEndBody);
char *result = _audioTranslation->oai->upload(_audioTranslation->oai->api_key, endpoint, boundary, data, len);
char *result = _audioTranslation->oai->upload(_audioTranslation->oai->base_url, _audioTranslation->oai->api_key, endpoint, boundary, data, len);
free(data);
OPENAI_ERROR_CHECK(result != NULL, "Empty result!", NULL);
cJSON *json = cJSON_Parse(result);
Expand Down Expand Up @@ -2108,7 +2111,7 @@ OpenAI_EmbeddingResponse_t *OpenAI_EmbeddingCreate(OpenAI_t *openai, char *input
char *jsonBody = cJSON_Print(req);
cJSON_Delete(req);
_OpenAI_t *_openai = __containerof(openai, _OpenAI_t, parent);
char *response = _openai->post(_openai->api_key, endpoint, jsonBody);
char *response = _openai->post(_openai->base_url, _openai->api_key, endpoint, jsonBody);
free(jsonBody);
OPENAI_ERROR_CHECK(response != NULL, "Empty response!", NULL);
return OpenAI_EmbeddingResponseCreate(response);
Expand Down Expand Up @@ -2147,35 +2150,41 @@ OpenAI_ModerationResponse_t *OpenAI_ModerationCreate(OpenAI_t *openai, char *inp
char *jsonBody = cJSON_Print(req);
cJSON_Delete(req);
_OpenAI_t *_openai = __containerof(openai, _OpenAI_t, parent);
res = _openai->post(_openai->api_key, endpoint, jsonBody);
res = _openai->post(_openai->base_url, _openai->api_key, endpoint, jsonBody);
free(jsonBody);
OPENAI_ERROR_CHECK(res != NULL, "Empty result!", NULL);
return OpenAI_ModerationResponseCreate(res);
}

//
// Open AI
//

void OpenAIDelete(OpenAI_t *oai)
{
_OpenAI_t *_oai = __containerof(oai, _OpenAI_t, parent);
if (_oai != NULL) {
if (_oai->api_key != NULL) {
free(_oai->api_key);
free(_oai->base_url);
_oai->api_key = NULL;
}
free(_oai);
_oai = NULL;
}
}

static char *OpenAI_Request(const char *api_key, const char *endpoint, const char *content_type, esp_http_client_method_t method, const char *boundary, uint8_t *data, size_t len)
void OpenAIChangeBaseURL(OpenAI_t *oai, const char *baseURL)
{
_OpenAI_t *_oai = __containerof(oai, _OpenAI_t, parent);
if (_oai->base_url != NULL) {
free(_oai->base_url);
}
_oai->base_url = strdup(baseURL);
}

static char *OpenAI_Request(const char *base_url, const char *api_key, const char *endpoint, const char *content_type, esp_http_client_method_t method, const char *boundary, uint8_t *data, size_t len)
{
ESP_LOGD(TAG, "\"%s\", len=%u", endpoint, len);
char *url = NULL;
char *result = NULL;
asprintf(&url, "https://api.openai.com/v1/%s", endpoint);
asprintf(&url, "%s%s", base_url, endpoint);
OPENAI_ERROR_CHECK(url != NULL, "Failed to allocate url!", NULL);
esp_http_client_config_t config = {
.url = url,
Expand Down Expand Up @@ -2229,24 +2238,24 @@ static char *OpenAI_Request(const char *api_key, const char *endpoint, const cha
return result != NULL ? result : NULL;
}

static char *OpenAI_Upload(const char *api_key, const char *endpoint, const char *boundary, uint8_t *data, size_t len)
static char *OpenAI_Upload(const char *base_url, const char *api_key, const char *endpoint, const char *boundary, uint8_t *data, size_t len)
{
return OpenAI_Request(api_key, endpoint, "multipart/form-data", HTTP_METHOD_POST, boundary, data, len);
return OpenAI_Request(base_url, api_key, endpoint, "multipart/form-data", HTTP_METHOD_POST, boundary, data, len);
}

static char *OpenAI_Post(const char *api_key, const char *endpoint, char *jsonBody)
static char *OpenAI_Post(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody)
{
return OpenAI_Request(api_key, endpoint, "application/json", HTTP_METHOD_POST, NULL, (uint8_t *)jsonBody, strlen(jsonBody));
return OpenAI_Request(base_url, api_key, endpoint, "application/json", HTTP_METHOD_POST, NULL, (uint8_t *)jsonBody, strlen(jsonBody));
}

static char *OpenAI_Get(const char *api_key, const char *endpoint)
static char *OpenAI_Get(const char *base_url, const char *api_key, const char *endpoint)
{
return OpenAI_Request(api_key, endpoint, "application/json", HTTP_METHOD_GET, NULL, NULL, 0);
return OpenAI_Request(base_url, api_key, endpoint, "application/json", HTTP_METHOD_GET, NULL, NULL, 0);
}

static char *OpenAI_Del(const char *api_key, const char *endpoint)
static char *OpenAI_Del(const char *base_url, const char *api_key, const char *endpoint)
{
return OpenAI_Request(api_key, endpoint, "application/json", HTTP_METHOD_DELETE, NULL, NULL, 0);
return OpenAI_Request(base_url, api_key, endpoint, "application/json", HTTP_METHOD_DELETE, NULL, NULL, 0);
}

OpenAI_t *OpenAICreate(const char *api_key)
Expand All @@ -2255,6 +2264,7 @@ OpenAI_t *OpenAICreate(const char *api_key)
_OpenAI_t *_oai = (_OpenAI_t *)calloc(1, sizeof(_OpenAI_t));
OPENAI_ERROR_CHECK(_oai != NULL, "Failed to allocate _OpenAI!", NULL);
_oai->api_key = strdup(api_key);
_oai->base_url = strdup(OPENAI_DEFAULT_BASE_URL);

#if CONFIG_ENABLE_EMBEDDING
_oai->parent.embeddingCreate = &OpenAI_EmbeddingCreate;
Expand Down
2 changes: 1 addition & 1 deletion components/openai/idf_component.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version: "0.1.2"
version: "0.2.0"
description: OpenAI library compatible with ESP-IDF
url: https://github.com/espressif/esp-iot-solution
repository: https://github.com/espressif/esp-iot-solution.git
Expand Down
6 changes: 6 additions & 0 deletions components/openai/include/OpenAI.h
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,12 @@ OpenAI_t *OpenAICreate(const char *api_key);
*/
void OpenAIDelete(OpenAI_t *oai);

/**
* @brief Modify the Base URL of the OpenAI object
*
*/
void OpenAIChangeBaseURL(OpenAI_t *oai, const char *baseURL);

#ifdef __cplusplus
}
#endif
Binary file not shown.
Binary file not shown.
Binary file removed components/openai/test_apps/audio/turn_on_tv_en.wav
Binary file not shown.
Binary file removed components/openai/test_apps/audio/zhengzai_cn.wav
Binary file not shown.
2 changes: 1 addition & 1 deletion components/openai/test_apps/main/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
idf_component_register(SRC_DIRS "."
INCLUDE_DIRS "."
REQUIRES unity test_utils openai protocol_examples_common esp_netif nvs_flash esp_wifi driver
EMBED_FILES "../audio/turn_on_tv_en.wav" "../audio/zhengzai_cn.wav"
EMBED_FILES "../audio/turn_on_tv_en.mp3" "../audio/introduce_espressif.mp3"
)

add_definitions(-DCI_OPENAI_KEY="$ENV{CI_OPENAI_KEY}")
16 changes: 8 additions & 8 deletions components/openai/test_apps/main/test_openai.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ static const char *TAG = "openai_test";
#define TEST_MEMORY_LEAK_THRESHOLD (-3000)

static char *openai_key = CI_OPENAI_KEY;
extern const uint8_t turn_on_tv_en_wav_start[] asm("_binary_turn_on_tv_en_wav_start");
extern const uint8_t turn_on_tv_en_wav_end[] asm("_binary_turn_on_tv_en_wav_end");
extern const uint8_t zhengzai_cn_wav_start[] asm("_binary_zhengzai_cn_wav_start");
extern const uint8_t zhengzai_cn_wav_end[] asm("_binary_zhengzai_cn_wav_end");
extern const uint8_t turn_on_tv_en_mp3_start[] asm("_binary_turn_on_tv_en_mp3_start");
extern const uint8_t turn_on_tv_en_mp3_end[] asm("_binary_turn_on_tv_en_mp3_end");
extern const uint8_t introduce_espressif_mp3_start[] asm("_binary_introduce_espressif_mp3_start");
extern const uint8_t introduce_espressif_mp3_end[] asm("_binary_introduce_espressif_mp3_end");

TEST_CASE("test ChatCompletion", "[ChatCompletion]")
{
Expand Down Expand Up @@ -108,11 +108,11 @@ TEST_CASE("test AudioTranscription en", "[AudioTranscription]")
OpenAI_t *openai = OpenAICreate(openai_key);
OpenAI_AudioTranscription_t *audioTranscription = openai->audioTranscriptionCreate(openai);
TEST_ASSERT_NOT_NULL(audioTranscription);
size_t length = turn_on_tv_en_wav_end - turn_on_tv_en_wav_start;
size_t length = turn_on_tv_en_mp3_end - turn_on_tv_en_mp3_start;
audioTranscription->setResponseFormat(audioTranscription, OPENAI_AUDIO_RESPONSE_FORMAT_JSON);
audioTranscription->setTemperature(audioTranscription,0.2); //float between 0 and 1. Higher value gives more random results.
audioTranscription->setLanguage(audioTranscription,"en"); //Set to English to make GPT return faster and more accurate
char *text = audioTranscription->file(audioTranscription, (uint8_t *)turn_on_tv_en_wav_start, length, OPENAI_AUDIO_INPUT_FORMAT_WAV);
char *text = audioTranscription->file(audioTranscription, (uint8_t *)turn_on_tv_en_mp3_start, length, OPENAI_AUDIO_INPUT_FORMAT_MP3);
TEST_ASSERT_NOT_NULL(text);
ESP_LOGI(TAG, "Text: %s", text);
free(text);
Expand All @@ -133,12 +133,12 @@ TEST_CASE("test AudioTranscription cn", "[AudioTranscription]")
OpenAI_t *openai = OpenAICreate(openai_key);
OpenAI_AudioTranscription_t *audioTranscription = openai->audioTranscriptionCreate(openai);
TEST_ASSERT_NOT_NULL(audioTranscription);
size_t length = zhengzai_cn_wav_end - zhengzai_cn_wav_start;
size_t length = introduce_espressif_mp3_end - introduce_espressif_mp3_start;
audioTranscription->setResponseFormat(audioTranscription, OPENAI_AUDIO_RESPONSE_FORMAT_JSON);
audioTranscription->setPrompt(audioTranscription, "请回复简体中文"); //The default will return Traditional Chinese, here we add prompt to make GPT return Simplified Chinese
audioTranscription->setTemperature(audioTranscription,0.2); //float between 0 and 1. Higher value gives more random results.
audioTranscription->setLanguage(audioTranscription,"zh"); //Set to Chinese to make GPT return faster and more accurate
char *text = audioTranscription->file(audioTranscription, (uint8_t *)zhengzai_cn_wav_start, length, OPENAI_AUDIO_INPUT_FORMAT_WAV);
char *text = audioTranscription->file(audioTranscription, (uint8_t *)introduce_espressif_mp3_start, length, OPENAI_AUDIO_INPUT_FORMAT_MP3);
TEST_ASSERT_NOT_NULL(text);
ESP_LOGI(TAG, "Text: %s", text);
free(text);
Expand Down
2 changes: 0 additions & 2 deletions components/openai/test_apps/sdkconfig.ci.defaults
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ CONFIG_ESP32S2_DEFAULT_CPU_FREQ_240=y
CONFIG_ESP32S3_DEFAULT_CPU_FREQ_240=y
CONFIG_ESP_TASK_WDT=n

CONFIG_SOC_SPIRAM_SUPPORTED=y
CONFIG_PARTITION_TABLE_CUSTOM=y
CONFIG_PARTITION_TABLE_CUSTOM_FILENAME="partitions.csv"

Expand All @@ -18,4 +17,3 @@ CONFIG_ESP_MAIN_TASK_STACK_SIZE=5120

CONFIG_EXAMPLE_WIFI_SSID="${CI_TEST_WIFI_SSID_2_4G}"
CONFIG_EXAMPLE_WIFI_PASSWORD="${CI_TEST_WIFI_PSW_2_4G}"
CONFIG_SPIRAM=y
2 changes: 0 additions & 2 deletions components/openai/test_apps/sdkconfig.defaults
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@ CONFIG_ESP32S2_DEFAULT_CPU_FREQ_240=y
CONFIG_ESP32S3_DEFAULT_CPU_FREQ_240=y
CONFIG_ESP_TASK_WDT=n

CONFIG_SOC_SPIRAM_SUPPORTED=y
CONFIG_PARTITION_TABLE_CUSTOM=y
CONFIG_PARTITION_TABLE_CUSTOM_FILENAME="partitions.csv"

CONFIG_ESPTOOLPY_FLASHSIZE_4MB=y
CONFIG_ESP_MAIN_TASK_STACK_SIZE=5120
CONFIG_SPIRAM=y

0 comments on commit a62dccb

Please sign in to comment.