Skip to content

Commit

Permalink
[azopenai] Updating to the 2023-07-01 API surface (#21169)
Browse files Browse the repository at this point in the history
* Updating to the 2023-07-01 API surface
- Adding in functions support and example.
- Added in accomodation for content filtering info.
- Make it so we can use separate service instances for some tests so we can test against the latest upcoming fixes/changes.
  • Loading branch information
richardpark-msft authored Jul 14, 2023
1 parent c39ac95 commit 9f67e71
Show file tree
Hide file tree
Showing 25 changed files with 1,426 additions and 151 deletions.
2 changes: 1 addition & 1 deletion sdk/cognitiveservices/azopenai/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/cognitiveservices/azopenai",
"Tag": "go/cognitiveservices/azopenai_2b6f93a94d"
"Tag": "go/cognitiveservices/azopenai_63852f374c"
}
60 changes: 52 additions & 8 deletions sdk/cognitiveservices/azopenai/autorest.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ go: true
use: "@autorest/[email protected]"
title: "OpenAI"
slice-elements-byval: true
remove-non-reference-schema: true
# can't use this since it removes an innererror type that we want ()
# remove-non-reference-schema: true
```

## Transformations
Expand Down Expand Up @@ -81,7 +82,7 @@ directive:
where: $.components.schemas["ImageOperation"].properties.status
transform: $["$ref"] = $.anyOf[0]["$ref"];delete $.anyOf;
- from: openapi-document
where: $.components.schemas["ImageGenerationOptions"].properties
where: $.components.schemas.ImageGenerationOptions.properties
transform: |
$.size["$ref"] = "#/components/schemas/ImageSize"; delete $.allOf;
$.response_format["$ref"] = "#/components/schemas/ImageGenerationResponseFormat"; delete $.allOf;
Expand All @@ -93,11 +94,12 @@ directive:
- from: openapi-document
where: $.components.schemas["ImageOperationStatus"].properties.status
transform: $["$ref"] = "#/components/schemas/State"; delete $.allOf;
- from: openapi-document
where: $.components.schemas["ContentFilterResult"].properties.severity
transform: $["$ref"] = "#/components/schemas/ContentFilterSeverity"; delete $.allOf;
- from: openapi-document
where: $.components.schemas["ChatChoice"].properties.finish_reason
transform: >
delete $.oneOf;
$["$ref"] = "#/components/schemas/CompletionsFinishReason";
transform: $["$ref"] = "#/components/schemas/CompletionsFinishReason"; delete $.oneOf;
# Fix "AutoGenerated" models
- from: openapi-document
where: $.components.schemas["ChatCompletions"].properties.usage
Expand Down Expand Up @@ -163,7 +165,7 @@ directive:
- client.go
- models.go
- options.go
- response_types.go
- response_types.go
where: $
transform: return $.replace(/Client(\w+)((?:Options|Response))/g, "$1$2");

Expand All @@ -172,10 +174,19 @@ directive:
where: $
transform: return $.replace(/runtime\.JoinPaths\(client.endpoint, urlPath\)/g, "client.formatURL(urlPath)");

# Some ImageGenerations hackery to represent the ImageLocation/ImagePayload polymorphism.
# - Remove the auto-generated ImageGenerationsDataItem.
# - Replace the ImageGenerations.Data type with []ImageGenerationDataItem
# - from: models.go
# where: $
# transform: |
# return $.replace(/type ImageGenerationsDataItem struct {[^}]+}/, "// ImageGenerationsDataItem represents an image URL or payload\ntype ImageGenerationsDataItem struct{\nImageLocation\nImagePayload\n}")
# $.replace(/(type ImageGenerations struct.+?)Data any/g, "$1Data []ImageGenerationsDataItem")

- from: models.go
where: $
transform: |
return $.replace(/type ImageGenerationsDataItem struct {[^}]+}/, "// ImageGenerationsDataItem represents an image URL or payload\ntype ImageGenerationsDataItem struct{\nImageLocation\nImagePayload\n}");
return $.replace(/(type ImageGenerations struct.+?)Data any/sg, "$1Data []ImageGenerationsDataItem")
# delete the auto-generated ImageGenerationsDataItem, we handle that custom
- from: models.go
Expand Down Expand Up @@ -218,6 +229,17 @@ directive:
.replace(/BeginAzureBatchImageGenerationInternal/g, "beginAzureBatchImageGeneration")
.replace(/BatchImageGenerationOperationResponse/g, "batchImageGenerationOperationResponse");
# BUG: ChatCompletionsOptionsFunctionCall is another one of those "here's mutually exclusive values" options...
- from:
- models.go
- models_serde.go
where: $
transform: |
return $
.replace(/populateAny\(objectMap, "function_call", c.FunctionCall\)/, 'populate(objectMap, "function_call", c.FunctionCall)')
.replace(/\/\/ ChatCompletionsOptionsFunctionCall.+?\n}/, "")
.replace(/FunctionCall any/, "FunctionCall *ChatCompletionsOptionsFunctionCall");
# fix some casing
- from:
- client.go
Expand All @@ -228,8 +250,30 @@ directive:
where: $
transform: return $.replace(/Logprobs/g, "LogProbs")

# remove PossibleazureOpenAIOperationStateValues, since we don't expose the poller
# delete ContentFilterResult in favor of our custom representation.
- from:
- models.go
- models_serde.go
where: $
transform: |
return $.replace(/\/\/ ContentFilterResult.+?\n}/s, "")
.replace(/\/\/ MarshalJSON implements the json.Marshaller interface for type ContentFilterResult.+?\n}/s, "")
.replace(/\/\/ UnmarshalJSON implements the json.Unmarshaller interface for type ContentFilterResult.+?\n}/s, "");
- from: constants.go
where: $
transform: return $.replace(/\/\/ PossibleazureOpenAIOperationStateValues returns.+?\n}/s, "");

# fix incorrect property name for content filtering
# TODO: I imagine we should able to fix this in the tsp?
- from: models_serde.go
where: $
transform: |
return $
.replace(/ case "selfHarm":/g, ' case "self_harm":')
.replace(/populate\(objectMap, "selfHarm", c.SelfHarm\)/g, 'populate(objectMap, "self_harm", c.SelfHarm)');
- from: client.go
where: $
transform: return $.replace(/runtime\.NewResponseError/sg, "client.newError");
```
26 changes: 13 additions & 13 deletions sdk/cognitiveservices/azopenai/client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 37 additions & 21 deletions sdk/cognitiveservices/azopenai/client_chat_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,33 @@ func TestClient_GetChatCompletions(t *testing.T) {
chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t))
require.NoError(t, err)

testGetChatCompletions(t, chatClient)
testGetChatCompletions(t, chatClient, true)
}

func TestClient_GetChatCompletionsStream(t *testing.T) {
cred, err := azopenai.NewKeyCredential(apiKey)
require.NoError(t, err)

chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t))
require.NoError(t, err)

testGetChatCompletionsStream(t, chatClient, true)
chatClient := newAzureOpenAIClientForTest(t, canaryChatCompletionsModelDeployment, true)
testGetChatCompletionsStream(t, chatClient)
}

func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}

chatClient := newOpenAIClientForTest(t)
testGetChatCompletions(t, chatClient)
testGetChatCompletions(t, chatClient, false)
}

func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
if testing.Short() {
t.Skip("Skipping OpenAI tests when attempting to do quick tests")
}

chatClient := newOpenAIClientForTest(t)
testGetChatCompletionsStream(t, chatClient, false)
testGetChatCompletionsStream(t, chatClient)
}

func testGetChatCompletions(t *testing.T, client *azopenai.Client) {
func testGetChatCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
expected := azopenai.ChatCompletions{
Choices: []azopenai.ChatChoice{
{
Expand All @@ -91,6 +94,15 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client) {
resp, err := client.GetChatCompletions(context.Background(), chatCompletionsRequest, nil)
require.NoError(t, err)

if isAzure {
// Azure also provides content-filtering. This particular prompt and responses
// will be considered safe.
expected.PromptAnnotations = []azopenai.PromptFilterResult{
{PromptIndex: to.Ptr[int32](0), ContentFilterResults: (*azopenai.PromptFilterResultContentFilterResults)(safeContentFilter)},
}
expected.Choices[0].ContentFilterResults = safeContentFilter
}

require.NotEmpty(t, resp.ID)
require.NotEmpty(t, resp.Created)

Expand All @@ -100,17 +112,10 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client) {
require.Equal(t, expected, resp.ChatCompletions)
}

func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, isAzure bool) {
func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client) {
streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil)
require.NoError(t, err)

if isAzure {
// there's a bug right now where the first event comes back empty
// Issue: https://github.com/Azure/azure-sdk-for-go/issues/21086
_, err := streamResp.ChatCompletionsStream.Read()
require.NoError(t, err)
}

// the data comes back differently for streaming
// 1. the text comes back in the ChatCompletion.Delta field
// 2. the role is only sent on the first streamed ChatCompletion
Expand All @@ -125,6 +130,18 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, isAzure
}

require.NoError(t, err)

if completion.PromptAnnotations != nil {
require.Equal(t, []azopenai.PromptFilterResult{
{PromptIndex: to.Ptr[int32](0), ContentFilterResults: (*azopenai.PromptFilterResultContentFilterResults)(safeContentFilter)},
}, completion.PromptAnnotations)
}

if len(completion.Choices) == 0 {
// you can get empty entries that contain just metadata (ie, prompt annotations)
continue
}

require.Equal(t, 1, len(completion.Choices))
choices = append(choices, completion.Choices[0])
}
Expand All @@ -140,7 +157,6 @@ func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, isAzure
}

require.Equal(t, expectedContent, message, "Ultimately, the same result as GetChatCompletions(), just sent across the .Delta field instead")

require.Equal(t, azopenai.ChatRoleAssistant, expectedRole)
}

Expand All @@ -167,7 +183,7 @@ func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) {
})
require.NoError(t, err)

testGetChatCompletions(t, chatClient)
testGetChatCompletions(t, chatClient, true)
}

func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
Expand Down
Loading

0 comments on commit 9f67e71

Please sign in to comment.