Skip to content

Commit

Permalink
[azopenai] Make the deployment ID a per-request parameter, rather than a constructor parameter. (#21223)
Browse files Browse the repository at this point in the history

A rather big change.

Prior to this we required the deployment ID to be passed as a parameter when constructing the client. This had some advantages in that it matched the underlying physical properties for Azure (it's a route parameter, not a request parameter), but it broke the de-facto conventions established by OpenAI, where the model is considered a request parameter.

This is now changed, with the one concession being that the `Model` field has been renamed to `DeploymentID`, to match with the other Azure SDKs.
  • Loading branch information
richardpark-msft authored Jul 20, 2023
1 parent 054b45e commit eacd271
Show file tree
Hide file tree
Showing 23 changed files with 309 additions and 260 deletions.
2 changes: 1 addition & 1 deletion sdk/cognitiveservices/azopenai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Release History

## 0.1.0 (2023-07-19)
## 0.1.0 (2023-07-20)

* Initial release of the `azopenai` library
2 changes: 1 addition & 1 deletion sdk/cognitiveservices/azopenai/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/cognitiveservices/azopenai",
"Tag": "go/cognitiveservices/azopenai_e8362ae205"
"Tag": "go/cognitiveservices/azopenai_8fdad86997"
}
20 changes: 19 additions & 1 deletion sdk/cognitiveservices/azopenai/autorest.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ directive:
# allow interception of formatting the URL path
- from: client.go
where: $
transform: return $.replace(/runtime\.JoinPaths\(client.endpoint, urlPath\)/g, "client.formatURL(urlPath)");
transform: |
return $
.replace(/runtime\.JoinPaths\(client.endpoint, urlPath\)/g, "client.formatURL(urlPath, getDeploymentID(body))");
# Some ImageGenerations hackery to represent the ImageLocation/ImagePayload polymorphism.
# - Remove the auto-generated ImageGenerationsDataItem.
Expand Down Expand Up @@ -276,4 +278,20 @@ directive:
- from: client.go
where: $
transform: return $.replace(/runtime\.NewResponseError/sg, "client.newError");

#
# rename `Model` to `DeploymentID`
#
- from: models.go
where: $
transform: |
return $
.replace(/\/\/ The model name.*?Model \*string/sg, "// DeploymentID specifies the name of the deployment (for Azure OpenAI) or model (for OpenAI) to use for this request.\nDeploymentID string");
- from: models_serde.go
where: $
transform: |
return $
.replace(/populate\(objectMap, "model", (c|e).Model\)/g, 'populate(objectMap, "model", &$1.DeploymentID)')
.replace(/err = unpopulate\(val, "Model", &(c|e).Model\)/g, 'err = unpopulate(val, "Model", &$1.DeploymentID)');
```
8 changes: 4 additions & 4 deletions sdk/cognitiveservices/azopenai/client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

72 changes: 36 additions & 36 deletions sdk/cognitiveservices/azopenai/client_chat_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,36 @@ import (
"github.com/stretchr/testify/require"
)

var chatCompletionsRequest = azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatMessage{
{
Role: to.Ptr(azopenai.ChatRole("user")),
Content: to.Ptr("Count to 10, with a comma between each number, no newlines and a period at the end. E.g., 1, 2, 3, ..."),
func newTestChatCompletionOptions(tv testVars) azopenai.ChatCompletionsOptions {
return azopenai.ChatCompletionsOptions{
Messages: []azopenai.ChatMessage{
{
Role: to.Ptr(azopenai.ChatRole("user")),
Content: to.Ptr("Count to 10, with a comma between each number, no newlines and a period at the end. E.g., 1, 2, 3, ..."),
},
},
},
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
Model: &openAIChatCompletionsModel,
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
DeploymentID: tv.ChatCompletions,
}
}

var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10."
var expectedRole = azopenai.ChatRoleAssistant

func TestClient_GetChatCompletions(t *testing.T) {
cred, err := azopenai.NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(azureOpenAI.APIKey)
require.NoError(t, err)

chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, chatCompletionsModelDeployment, newClientOptionsForTest(t))
chatClient, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

testGetChatCompletions(t, chatClient, true)
testGetChatCompletions(t, chatClient, azureOpenAI)
}

func TestClient_GetChatCompletionsStream(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, canaryChatCompletionsModelDeployment, true)
testGetChatCompletionsStream(t, chatClient)
chatClient := newAzureOpenAIClientForTest(t, azureOpenAICanary)
testGetChatCompletionsStream(t, chatClient, azureOpenAICanary)
}

func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
Expand All @@ -58,7 +60,7 @@ func TestClient_OpenAI_GetChatCompletions(t *testing.T) {
}

chatClient := newOpenAIClientForTest(t)
testGetChatCompletions(t, chatClient, false)
testGetChatCompletions(t, chatClient, openAI)
}

func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
Expand All @@ -67,10 +69,10 @@ func TestClient_OpenAI_GetChatCompletionsStream(t *testing.T) {
}

chatClient := newOpenAIClientForTest(t)
testGetChatCompletionsStream(t, chatClient)
testGetChatCompletionsStream(t, chatClient, openAI)
}

func testGetChatCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
func testGetChatCompletions(t *testing.T, client *azopenai.Client, tv testVars) {
expected := azopenai.ChatCompletions{
Choices: []azopenai.ChatChoice{
{
Expand All @@ -91,10 +93,10 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client, isAzure bool)
},
}

resp, err := client.GetChatCompletions(context.Background(), chatCompletionsRequest, nil)
resp, err := client.GetChatCompletions(context.Background(), newTestChatCompletionOptions(tv), nil)
require.NoError(t, err)

if isAzure {
if tv.Azure {
// Azure also provides content-filtering. This particular prompt and responses
// will be considered safe.
expected.PromptAnnotations = []azopenai.PromptFilterResult{
Expand All @@ -112,8 +114,8 @@ func testGetChatCompletions(t *testing.T, client *azopenai.Client, isAzure bool)
require.Equal(t, expected, resp.ChatCompletions)
}

func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client) {
streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil)
func testGetChatCompletionsStream(t *testing.T, client *azopenai.Client, tv testVars) {
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(tv), nil)
require.NoError(t, err)

// the data comes back differently for streaming
Expand Down Expand Up @@ -178,19 +180,19 @@ func TestClient_GetChatCompletions_DefaultAzureCredential(t *testing.T) {
})
require.NoError(t, err)

chatClient, err := azopenai.NewClient(endpoint, dac, chatCompletionsModelDeployment, &azopenai.ClientOptions{
chatClient, err := azopenai.NewClient(azureOpenAI.Endpoint, dac, &azopenai.ClientOptions{
ClientOptions: policy.ClientOptions{Transport: recordingTransporter},
})
require.NoError(t, err)

testGetChatCompletions(t, chatClient, true)
testGetChatCompletions(t, chatClient, azureOpenAI)
}

func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
cred, err := azopenai.NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(azureOpenAI.APIKey)
require.NoError(t, err)

chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t))
chatClient, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

_, err = chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
Expand All @@ -200,8 +202,9 @@ func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
Content: to.Ptr("Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."),
},
},
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
MaxTokens: to.Ptr(int32(1024)),
Temperature: to.Ptr(float32(0.0)),
DeploymentID: "invalid model name",
}, nil)

var respErr *azcore.ResponseError
Expand All @@ -214,20 +217,17 @@ func TestClient_GetChatCompletionsStream_Error(t *testing.T) {
t.Skip()
}

doTest := func(t *testing.T, client *azopenai.Client) {
t.Helper()
streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil)
t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(azureOpenAI), nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
}

t.Run("AzureOpenAI", func(t *testing.T) {
client := newBogusAzureOpenAIClient(t, chatCompletionsModelDeployment)
doTest(t, client)
})

t.Run("OpenAI", func(t *testing.T) {
client := newBogusOpenAIClient(t)
doTest(t, client)
streamResp, err := client.GetChatCompletionsStream(context.Background(), newTestChatCompletionOptions(openAI), nil)
require.Empty(t, streamResp)
assertResponseIsError(t, err)
})
}
24 changes: 15 additions & 9 deletions sdk/cognitiveservices/azopenai/client_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ import (
)

func TestClient_GetCompletions_AzureOpenAI(t *testing.T) {
cred, err := azopenai.NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(azureOpenAI.APIKey)
require.NoError(t, err)

client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, completionsModelDeployment, newClientOptionsForTest(t))
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

testGetCompletions(t, client)
testGetCompletions(t, client, true)
}

func TestClient_GetCompletions_OpenAI(t *testing.T) {
Expand All @@ -31,15 +31,21 @@ func TestClient_GetCompletions_OpenAI(t *testing.T) {
}

client := newOpenAIClientForTest(t)
testGetCompletions(t, client)
testGetCompletions(t, client, false)
}

func testGetCompletions(t *testing.T, client *azopenai.Client) {
func testGetCompletions(t *testing.T, client *azopenai.Client, isAzure bool) {
deploymentID := openAI.Completions

if isAzure {
deploymentID = azureOpenAI.Completions
}

resp, err := client.GetCompletions(context.Background(), azopenai.CompletionsOptions{
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
Model: &openAICompletionsModel,
Prompt: []string{"What is Azure OpenAI?"},
MaxTokens: to.Ptr(int32(2048 - 127)),
Temperature: to.Ptr(float32(0.0)),
DeploymentID: deploymentID,
}, nil)
require.NoError(t, err)

Expand Down
24 changes: 11 additions & 13 deletions sdk/cognitiveservices/azopenai/client_embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ import (
)

func TestClient_GetEmbeddings_InvalidModel(t *testing.T) {
cred, err := azopenai.NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(azureOpenAI.APIKey)
require.NoError(t, err)

chatClient, err := azopenai.NewClientWithKeyCredential(endpoint, cred, "thisdoesntexist", newClientOptionsForTest(t))
chatClient, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

_, err = chatClient.GetEmbeddings(context.Background(), azopenai.EmbeddingsOptions{}, nil)
_, err = chatClient.GetEmbeddings(context.Background(), azopenai.EmbeddingsOptions{
DeploymentID: "thisdoesntexist",
}, nil)

var respErr *azcore.ResponseError
require.ErrorAs(t, err, &respErr)
Expand All @@ -32,21 +34,17 @@ func TestClient_OpenAI_GetEmbeddings(t *testing.T) {
}

client := newOpenAIClientForTest(t)
modelID := "text-similarity-curie-001"
testGetEmbeddings(t, client, modelID)
testGetEmbeddings(t, client, openAI.Embeddings)
}

func TestClient_GetEmbeddings(t *testing.T) {
// model deployment points to `text-similarity-curie-001`
deploymentID := "embedding"

cred, err := azopenai.NewKeyCredential(apiKey)
cred, err := azopenai.NewKeyCredential(azureOpenAI.APIKey)
require.NoError(t, err)

client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, deploymentID, newClientOptionsForTest(t))
client, err := azopenai.NewClientWithKeyCredential(azureOpenAI.Endpoint, cred, newClientOptionsForTest(t))
require.NoError(t, err)

testGetEmbeddings(t, client, deploymentID)
testGetEmbeddings(t, client, azureOpenAI.Embeddings)
}

func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentID string) {
Expand All @@ -71,8 +69,8 @@ func testGetEmbeddings(t *testing.T, client *azopenai.Client, modelOrDeploymentI
ctx: context.TODO(),
deploymentID: modelOrDeploymentID,
body: azopenai.EmbeddingsOptions{
Input: []string{"\"Your text string goes here\""},
Model: &modelOrDeploymentID,
Input: []string{"\"Your text string goes here\""},
DeploymentID: modelOrDeploymentID,
},
options: nil,
},
Expand Down
16 changes: 9 additions & 7 deletions sdk/cognitiveservices/azopenai/client_functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ func TestGetChatCompletions_usingFunctions(t *testing.T) {

t.Run("OpenAI", func(t *testing.T) {
chatClient := newOpenAIClientForTest(t)
testChatCompletionsFunctions(t, chatClient)
testChatCompletionsFunctions(t, chatClient, openAI)
})

t.Run("AzureOpenAI", func(t *testing.T) {
chatClient := newAzureOpenAIClientForTest(t, chatCompletionsModelDeployment, false)
testChatCompletionsFunctions(t, chatClient)
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
testChatCompletionsFunctions(t, chatClient, azureOpenAI)
})
}

func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client) {
resp, err := chatClient.GetChatCompletions(context.Background(), azopenai.ChatCompletionsOptions{
Model: to.Ptr("gpt-4-0613"),
func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, tv testVars) {
body := azopenai.ChatCompletionsOptions{
DeploymentID: tv.ChatCompletions,
Messages: []azopenai.ChatMessage{
{
Role: to.Ptr(azopenai.ChatRoleUser),
Expand Down Expand Up @@ -72,7 +72,9 @@ func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client) {
},
},
Temperature: to.Ptr[float32](0.0),
}, nil)
}

resp, err := chatClient.GetChatCompletions(context.Background(), body, nil)
require.NoError(t, err)

funcCall := resp.ChatCompletions.Choices[0].Message.FunctionCall
Expand Down
Loading

0 comments on commit eacd271

Please sign in to comment.