Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat(azure-openai): Add support for Cognitive Search datasource #115

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ Cargo.lock

# directory used to store images
data

.env
.idea
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "openapi/openai-openapi"]
path = openapi/openai-openapi
url = [email protected]:openai/openai-openapi.git
6 changes: 5 additions & 1 deletion async-openai/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ impl<'c, C: Config> Chat<'c, C> {
"When stream is true, use Chat::create_stream".into(),
));
}
self.client.post("/chat/completions", request).await
if request.data_sources.is_none() {
self.client.post("/chat/completions", request).await
} else {
self.client.post("/extensions/chat/completions", request).await
}
}

/// Creates a completion for the chat message
Expand Down
38 changes: 34 additions & 4 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::fmt::{Display, Formatter};
use std::pin::Pin;
use std::str::from_utf8;

use futures::{stream::StreamExt, Stream};
use reqwest::header::{CONTENT_TYPE, HeaderValue};
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};

Expand All @@ -13,6 +16,7 @@ use crate::{
moderation::Moderations,
Audio, Chat, Completions, Embeddings, FineTunes, Models,
};
use crate::types::Stop::String;

#[derive(Debug, Clone)]
/// Client is a container for config, backoff and http_client
Expand Down Expand Up @@ -151,12 +155,38 @@ impl<C: Config> Client<C> {
O: DeserializeOwned,
{
let request_maker = || async {
let url = self.config.url(path);
let query = &self.config.query();
let mut headers = self.config.headers();

let body = serde_json::to_vec(&request)?;
if !headers.contains_key(CONTENT_TYPE) {
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
}

/// Helper used only for tracing: renders the raw request body as text.
/// Valid UTF-8 is shown verbatim; any other byte sequence falls back to a
/// fixed placeholder so the `Display` impl itself can never fail.
struct BodyDisplay {
    body: Vec<u8>,
}

impl Display for BodyDisplay {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Borrow the bytes as &str when they decode cleanly; otherwise
        // substitute the placeholder instead of erroring the formatter.
        let text = from_utf8(&self.body).unwrap_or("Cannot display body");
        f.write_str(text)
    }
}

tracing::debug!("url: {}", url);
tracing::debug!("query: {:?}", query);
tracing::debug!("headers: {:?}", headers);
tracing::debug!("body: {}", BodyDisplay{body: body.clone()});

Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.post(url)
.query(query)
.headers(headers)
.body(body)
.build()?)
};

Expand Down
10 changes: 10 additions & 0 deletions async-openai/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Errors originating from API calls, parsing responses, and reading-or-writing to the file system.
use serde::Deserialize;
use serde_json::Error;

#[derive(Debug, thiserror::Error)]
pub enum OpenAIError {
Expand All @@ -12,6 +13,9 @@ pub enum OpenAIError {
/// Error when a response cannot be deserialized into a Rust type
#[error("failed to deserialize api response: {0}")]
JSONDeserialize(serde_json::Error),
/// Error when serializing a request to send to the API
#[error("failed to serialize request for send to api: {0}")]
JSONSerialize(serde_json::Error),
/// Error on the client side when saving file to file system
#[error("failed to save file: {0}")]
FileSaveError(String),
Expand Down Expand Up @@ -49,3 +53,9 @@ pub(crate) fn map_deserialization_error(e: serde_json::Error, bytes: &[u8]) -> O
);
OpenAIError::JSONDeserialize(e)
}

impl From<serde_json::Error> for OpenAIError {
fn from(value: Error) -> Self {
OpenAIError::JSONSerialize(value)
}
}
57 changes: 57 additions & 0 deletions async-openai/src/types/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,56 @@ pub struct ChatCompletionFunctions {
pub parameters: Option<serde_json::Value>,
}

/// Connection settings for an Azure Cognitive Search index used as a chat
/// completion `dataSource`.
///
/// Serialized in camelCase (`index_name` -> `indexName`) per the
/// `#[serde(rename_all = "camelCase")]` attribute below.
///
/// Example of the wider parameter set from the original notes (YAML) — only
/// the three required fields below are currently modeled; the rest
/// (fieldsMapping, topNDocuments, queryType, …) are not yet supported:
///
/// ```yaml
/// endpoint: https://mysearchexample.search.windows.net
/// key: '***(admin key)'
/// indexName: my-chunk-index
/// fieldsMapping:
///   titleField: productName
///   urlField: productUrl
///   filepathField: productFilePath
///   contentFields:
///     - productDescription
///   contentFieldsSeparator: |2+
///
/// topNDocuments: 5
/// queryType: semantic
/// semanticConfiguration: defaultConfiguration
/// inScope: true
/// roleInformation: roleInformation
/// ```
#[derive(Clone, Serialize, Default, Debug, Builder, Deserialize, PartialEq)]
#[builder(name = "AzureCognitiveSearchParametersArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
#[serde(rename_all = "camelCase")]
pub struct AzureCognitiveSearchParameters {
    /// Search service endpoint URL, e.g. `https://<name>.search.windows.net`.
    pub endpoint: String,
    /// Key used to authenticate against the search service.
    pub key: String,
    /// Name of the index to query (serialized as `indexName`).
    pub index_name: String,
}

/// A single entry for the `dataSources` array of a chat completion request
/// (Azure OpenAI "on your data"; routed to `/extensions/chat/completions`
/// when present).
#[derive(Clone, Serialize, Debug, Builder, Deserialize, PartialEq)]
#[builder(name = "AzureCognitiveSearchDataSourceArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct AzureCognitiveSearchDataSource {
    /// Datasource discriminator, serialized as `type` (`type` is a Rust
    /// keyword, hence the `_type` field name). The `Default` impl fills in
    /// `"AzureCognitiveSearch"`.
    #[serde(rename = "type")]
    pub _type: String,
    /// Connection and index settings for the search service.
    pub parameters: AzureCognitiveSearchParameters,
}
impl Default for AzureCognitiveSearchDataSource {
fn default() -> Self {
Self {
_type: "AzureCognitiveSearch".to_string(),
parameters: Default::default(),
}
}
}

#[derive(Clone, Serialize, Default, Debug, Builder, Deserialize, PartialEq)]
#[builder(name = "CreateChatCompletionRequestArgs")]
#[builder(pattern = "mutable")]
Expand All @@ -771,11 +821,17 @@ pub struct ChatCompletionFunctions {
pub struct CreateChatCompletionRequest {
/// ID of the model to use.
/// See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.
#[serde(skip_serializing_if = "String::is_empty")]
pub model: String,

/// A list of messages comprising the conversation so far. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb).
pub messages: Vec<ChatCompletionRequestMessage>, // min: 1

/// The data sources to be used for the Azure OpenAI on your data feature
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "dataSources")]
pub data_sources: Option<Vec<AzureCognitiveSearchDataSource>>,

/// A list of functions the model may generate JSON inputs for.
#[serde(skip_serializing_if = "Option::is_none")]
pub functions: Option<Vec<ChatCompletionFunctions>>,
Expand Down Expand Up @@ -820,6 +876,7 @@ pub struct CreateChatCompletionRequest {
/// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the chat completion.
///
/// The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u16>,

/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
Expand Down
12 changes: 12 additions & 0 deletions openapi/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@


# [Official repository of OpenAPI for OpenAI](https://github.com/openai/openai-openapi)

That repository was copied into [openai-openapi](./openai-openapi) as a git submodule.


# [Official OpenAPI for Azure OpenAI](specification/cognitiveservices/data-plane/AzureOpenAI/inference/preview/)

versions:
- [azure-openai-openapi 2023-09-01-preview.yaml](./azure-openai-openapi_2023-09-01-preview.yaml)

Loading