Skip to content

Commit

Permalink
Add AWS Bedrock support (dapr#3563)
Browse files Browse the repository at this point in the history
Signed-off-by: yaron2 <[email protected]>
Signed-off-by: Artur Souza <[email protected]>
Co-authored-by: yaron2 <[email protected]>
Signed-off-by: Elena Kolevska <[email protected]>
  • Loading branch information
2 people authored and elena-kolevska committed Oct 24, 2024
1 parent 9cf7574 commit 4664736
Show file tree
Hide file tree
Showing 22 changed files with 805 additions and 261 deletions.
2 changes: 1 addition & 1 deletion .build-tools/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/dapr/components-contrib/build-tools

go 1.23
go 1.23.0

require (
github.com/dapr/components-contrib v0.0.0
Expand Down
24 changes: 24 additions & 0 deletions common/authentication/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strconv"
"time"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
v2creds "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
Expand All @@ -37,6 +38,29 @@ type EnvironmentSettings struct {
Metadata map[string]string
}

func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) {
optFns := []func(*config.LoadOptions) error{}
if region != "" {
optFns = append(optFns, config.WithRegion(region))
}

if accessKey != "" && secretKey != "" {
provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken)
optFns = append(optFns, config.WithCredentialsProvider(provider))
}

awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...)
if err != nil {
return awsv2.Config{}, err
}

if endpoint != "" {
awsCfg.BaseEndpoint = &endpoint
}

return awsCfg, nil
}

func GetClient(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (*session.Session, error) {
awsConfig := aws.NewConfig()

Expand Down
147 changes: 147 additions & 0 deletions conversation/aws/bedrock/bedrock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package bedrock

import (
"context"
"reflect"

awsAuth "github.com/dapr/components-contrib/common/authentication/aws"
"github.com/dapr/components-contrib/conversation"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
kmeta "github.com/dapr/kit/metadata"

"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/bedrock"
)

type AWSBedrock struct {
model string
llm *bedrock.LLM

logger logger.Logger
}

type AWSBedrockMetadata struct {
Region string `json:"region"`
Endpoint string `json:"endpoint"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
SessionToken string `json:"sessionToken"`
Model string `json:"model"`
}

func NewAWSBedrock(logger logger.Logger) conversation.Conversation {
b := &AWSBedrock{
logger: logger,
}

return b
}

func convertRole(role conversation.Role) llms.ChatMessageType {
switch role {
case conversation.RoleSystem:
return llms.ChatMessageTypeSystem
case conversation.RoleUser:
return llms.ChatMessageTypeHuman
case conversation.RoleAssistant:
return llms.ChatMessageTypeAI
case conversation.RoleTool:
return llms.ChatMessageTypeTool
case conversation.RoleFunction:
return llms.ChatMessageTypeFunction
default:
return llms.ChatMessageTypeHuman
}
}

func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error {
m := AWSBedrockMetadata{}
err := kmeta.DecodeMetadata(meta.Properties, &m)
if err != nil {
return err
}

awsConfig, err := awsAuth.GetConfigV2(m.AccessKey, m.SecretKey, m.SessionToken, m.Region, m.Endpoint)
if err != nil {
return err
}

bedrockClient := bedrockruntime.NewFromConfig(awsConfig)

opts := []bedrock.Option{bedrock.WithClient(bedrockClient)}
if m.Model != "" {
opts = append(opts, bedrock.WithModel(m.Model))
}
b.model = m.Model

llm, err := bedrock.New(
opts...,
)
if err != nil {
return err
}

b.llm = llm
return nil
}

func (b *AWSBedrock) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
metadataStruct := AWSBedrockMetadata{}
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
return
}

func (b *AWSBedrock) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
messages := make([]llms.MessageContent, 0, len(r.Inputs))

for _, input := range r.Inputs {
role := convertRole(input.Role)

messages = append(messages, llms.MessageContent{
Role: role,
Parts: []llms.ContentPart{
llms.TextPart(input.Message),
},
})
}

resp, err := b.llm.GenerateContent(ctx, messages)
if err != nil {
return nil, err
}

outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))

for i := range resp.Choices {
outputs = append(outputs, conversation.ConversationResult{
Result: resp.Choices[i].Content,
Parameters: r.Parameters,
})
}

res = &conversation.ConversationResponse{
Outputs: outputs,
}

return res, nil
}

func (b *AWSBedrock) Close() error {
return nil
}
25 changes: 25 additions & 0 deletions conversation/aws/bedrock/bedrock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package bedrock

import (
"testing"

"github.com/dapr/components-contrib/conversation"

"github.com/stretchr/testify/assert"
"github.com/tmc/langchaingo/llms"
)

func TestConvertRole(t *testing.T) {
roles := map[string]string{
conversation.RoleSystem: string(llms.ChatMessageTypeSystem),
conversation.RoleAssistant: string(llms.ChatMessageTypeAI),
conversation.RoleFunction: string(llms.ChatMessageTypeFunction),
conversation.RoleUser: string(llms.ChatMessageTypeHuman),
conversation.RoleTool: string(llms.ChatMessageTypeTool),
}

for k, v := range roles {
r := convertRole(conversation.Role(k))
assert.Equal(t, v, string(r))
}
}
26 changes: 26 additions & 0 deletions conversation/aws/bedrock/metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# yaml-language-server: $schema=../../../component-metadata-schema.json
schemaVersion: v1
type: conversation
name: aws.bedrock
version: v1
status: alpha
title: "AWS Bedrock"
urls:
- title: Reference
url: https://docs.dapr.io/reference/components-reference/supported-conversation/setup-aws-bedrock/
builtinAuthenticationProfiles:
- name: "aws"
metadata:
- name: endpoint
required: false
description: |
AWS endpoint for the component to use, to connect to emulators.
Do not use this when running against production AWS.
example: '"http://localhost:4566"'
type: string
- name: model
required: false
description: |
The LLM to use. Defaults to Bedrock's default provider model from Amazon.
type: string
example: 'amazon.titan-text-express-v1'
17 changes: 16 additions & 1 deletion conversation/converse.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,13 @@ type Conversation interface {
io.Closer
}

type ConversationInput struct {
Message string `json:"string"`
Role Role `json:"role"`
}

type ConversationRequest struct {
Inputs []string `json:"inputs"`
Inputs []ConversationInput `json:"inputs"`
Parameters map[string]*anypb.Any `json:"parameters"`
ConversationContext string `json:"conversationContext"`

Expand All @@ -54,3 +59,13 @@ type ConversationResponse struct {
ConversationContext string `json:"conversationContext"`
Outputs []ConversationResult `json:"outputs"`
}

type Role string

const (
RoleSystem = "system"
RoleUser = "user"
RoleAssistant = "assistant"
RoleFunction = "function"
RoleTool = "tool"
)
4 changes: 2 additions & 2 deletions conversation/echo/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewEcho(logger logger.Logger) conversation.Conversation {

func (e *Echo) Init(ctx context.Context, meta conversation.Metadata) error {
r := &conversation.ConversationRequest{}
err := kmeta.DecodeMetadata(meta.Properties, &r)
err := kmeta.DecodeMetadata(meta.Properties, r)
if err != nil {
return err
}
Expand All @@ -62,7 +62,7 @@ func (e *Echo) Converse(ctx context.Context, r *conversation.ConversationRequest

for _, input := range r.Inputs {
outputs = append(outputs, conversation.ConversationResult{
Result: input,
Result: input.Message,
Parameters: r.Parameters,
})
}
Expand Down
28 changes: 28 additions & 0 deletions conversation/echo/echo_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package echo

import (
"context"
"testing"

"github.com/dapr/components-contrib/conversation"
"github.com/dapr/kit/logger"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestConverse(t *testing.T) {
e := NewEcho(logger.NewLogger("echo test"))
e.Init(context.Background(), conversation.Metadata{})

r, err := e.Converse(context.Background(), &conversation.ConversationRequest{
Inputs: []conversation.ConversationInput{
{
Message: "hello",
},
},
})
require.NoError(t, err)
assert.Len(t, r.Outputs, 1)
assert.Equal(t, "hello", r.Outputs[0].Result)
}
29 changes: 29 additions & 0 deletions conversation/openai/metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# yaml-language-server: $schema=../../../component-metadata-schema.json
schemaVersion: v1
type: conversation
name: openai
version: v1
status: alpha
title: "OpenAI"
urls:
- title: Reference
url: https://docs.dapr.io/reference/components-reference/supported-conversation/setup-openai/
authenticationProfiles:
- title: "API Key"
description: "Authenticate using an API key"
metadata:
- name: key
type: string
required: true
sensitive: true
description: |
API key for OpenAI.
example: "**********"
default: ""
metadata:
- name: model
required: false
description: |
The OpenAI LLM to use. Defaults to gpt-4o
type: string
example: 'gpt-4-turbo'
Loading

0 comments on commit 4664736

Please sign in to comment.