Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update all Azure components #3475

Merged
merged 1 commit into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 50 additions & 28 deletions bindings/azure/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ package openai
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"

"github.com/dapr/components-contrib/bindings"
azauth "github.com/dapr/components-contrib/common/authentication/azure"
Expand Down Expand Up @@ -120,10 +121,10 @@ func (p *AzOpenAI) Init(ctx context.Context, meta bindings.Metadata) error {

if m.APIKey != "" {
// use API key authentication
var keyCredential azopenai.KeyCredential
keyCredential, err = azopenai.NewKeyCredential(m.APIKey)
if err != nil {
return fmt.Errorf("error getting credentials object: %w", err)
var keyCredential *azcore.KeyCredential
keyCredential = azcore.NewKeyCredential(m.APIKey)
if keyCredential == nil {
return errors.New("error getting credentials object")
}

p.client, err = azopenai.NewClientWithKeyCredential(m.Endpoint, keyCredential, nil)
Expand Down Expand Up @@ -163,7 +164,7 @@ func (p *AzOpenAI) Operations() []bindings.OperationKind {
// Invoke handles all invoke operations.
func (p *AzOpenAI) Invoke(ctx context.Context, req *bindings.InvokeRequest) (resp *bindings.InvokeResponse, err error) {
if req == nil || len(req.Metadata) == 0 {
return nil, fmt.Errorf("invalid request: metadata is required")
return nil, errors.New("invalid request: metadata is required")
}

startTime := time.Now().UTC()
Expand Down Expand Up @@ -228,7 +229,7 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[
}

if prompt.Prompt == "" {
return nil, fmt.Errorf("prompt is required for completion operation")
return nil, errors.New("prompt is required for completion operation")
}

if prompt.DeploymentID == "" {
Expand All @@ -240,13 +241,13 @@ func (p *AzOpenAI) completion(ctx context.Context, message []byte, metadata map[
}

resp, err := p.client.GetCompletions(ctx, azopenai.CompletionsOptions{
Deployment: prompt.DeploymentID,
Prompt: []string{prompt.Prompt},
MaxTokens: &prompt.MaxTokens,
Temperature: &prompt.Temperature,
TopP: &prompt.TopP,
N: &prompt.N,
Stop: prompt.Stop,
DeploymentName: &prompt.DeploymentID,
Prompt: []string{prompt.Prompt},
MaxTokens: &prompt.MaxTokens,
Temperature: &prompt.Temperature,
TopP: &prompt.TopP,
N: &prompt.N,
Stop: prompt.Stop,
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting completion api: %w", err)
Expand Down Expand Up @@ -280,7 +281,7 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
}

if len(messages.Messages) == 0 {
return nil, fmt.Errorf("messages are required for chat-completion operation")
return nil, errors.New("messages are required for chat-completion operation")
}

if messages.DeploymentID == "" {
Expand All @@ -291,11 +292,32 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
messages.Stop = nil
}

messageReq := make([]azopenai.ChatMessage, len(messages.Messages))
messageReq := make([]azopenai.ChatRequestMessageClassification, len(messages.Messages))
for i, m := range messages.Messages {
messageReq[i] = azopenai.ChatMessage{
Role: to.Ptr(azopenai.ChatRole(m.Role)),
Content: to.Ptr(m.Message),
currentMsg := m.Message
switch azopenai.ChatRole(m.Role) {
case azopenai.ChatRoleUser:
messageReq[i] = &azopenai.ChatRequestUserMessage{
Content: azopenai.NewChatRequestUserMessageContent(currentMsg),
}
case azopenai.ChatRoleAssistant:
messageReq[i] = &azopenai.ChatRequestAssistantMessage{
Content: &currentMsg,
}
case azopenai.ChatRoleFunction:
messageReq[i] = &azopenai.ChatRequestFunctionMessage{
Content: &currentMsg,
}
case azopenai.ChatRoleSystem:
messageReq[i] = &azopenai.ChatRequestSystemMessage{
Content: &currentMsg,
}
case azopenai.ChatRoleTool:
messageReq[i] = &azopenai.ChatRequestToolMessage{
Content: &currentMsg,
}
default:
return nil, fmt.Errorf("invalid role: %s", m.Role)
}
}

Expand All @@ -305,13 +327,13 @@ func (p *AzOpenAI) chatCompletion(ctx context.Context, messageRequest []byte, me
}

res, err := p.client.GetChatCompletions(ctx, azopenai.ChatCompletionsOptions{
Deployment: messages.DeploymentID,
MaxTokens: maxTokens,
Temperature: &messages.Temperature,
TopP: &messages.TopP,
N: &messages.N,
Messages: messageReq,
Stop: messages.Stop,
DeploymentName: &messages.DeploymentID,
MaxTokens: maxTokens,
Temperature: &messages.Temperature,
TopP: &messages.TopP,
N: &messages.N,
Messages: messageReq,
Stop: messages.Stop,
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting chat completion api: %w", err)
Expand Down Expand Up @@ -343,8 +365,8 @@ func (p *AzOpenAI) getEmbedding(ctx context.Context, messageRequest []byte, meta
}

res, err := p.client.GetEmbeddings(ctx, azopenai.EmbeddingsOptions{
Deployment: message.DeploymentID,
Input: []string{message.Message},
DeploymentName: &message.DeploymentID,
Input: []string{message.Message},
}, nil)
if err != nil {
return nil, fmt.Errorf("error getting embedding api: %w", err)
Expand Down
32 changes: 16 additions & 16 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ require (
cloud.google.com/go/secretmanager v1.11.5
cloud.google.com/go/storage v1.36.0
dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.3.0
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig v1.1.0
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.0.1
github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.1.0
github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs v1.1.0
github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus v1.7.0
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.0.3
github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.2.0
github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs v1.2.1
github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus v1.7.1
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/eventgrid/armeventgrid/v2 v2.2.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/eventhub/armeventhub v1.2.0
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.1.0
Expand Down Expand Up @@ -51,7 +51,7 @@ require (
github.com/chebyrash/promise v0.0.0-20230709133807-42ec49ba1459
github.com/cinience/go_rocketmq v0.0.2
github.com/cloudevents/sdk-go/binding/format/protobuf/v2 v2.14.0
github.com/cloudevents/sdk-go/v2 v2.14.0
github.com/cloudevents/sdk-go/v2 v2.15.2
github.com/cloudwego/kitex v0.5.0
github.com/cloudwego/kitex-examples v0.1.1
github.com/cyphar/filepath-securejoin v0.2.4
Expand Down Expand Up @@ -117,10 +117,10 @@ require (
go.mongodb.org/mongo-driver v1.12.1
go.uber.org/multierr v1.11.0
go.uber.org/ratelimit v0.3.0
golang.org/x/crypto v0.22.0
golang.org/x/crypto v0.24.0
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a
golang.org/x/mod v0.14.0
golang.org/x/net v0.24.0
golang.org/x/mod v0.17.0
golang.org/x/net v0.26.0
golang.org/x/oauth2 v0.17.0
google.golang.org/api v0.162.0
google.golang.org/grpc v1.63.0
Expand All @@ -147,7 +147,7 @@ require (
github.com/99designs/keyring v1.2.1 // indirect
github.com/AthenZ/athenz v1.10.39 // indirect
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect
github.com/DataDog/zstd v1.5.2 // indirect
Expand Down Expand Up @@ -383,12 +383,12 @@ require (
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/zap v1.24.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/term v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/term v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.17.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de // indirect
Expand Down
Loading
Loading