diff --git a/go.mod b/go.mod index 2a6e7bba1..166fa5604 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.23.2 require ( buf.build/gen/go/grpc-ecosystem/grpc-gateway/protocolbuffers/go v1.35.2-20240617172850-a48fcebcf8f1.1 + github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.7.1 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 github.com/Masterminds/semver/v3 v3.3.1 github.com/Masterminds/sprig/v3 v3.3.0 github.com/NYTimes/gziphandler v1.1.1 @@ -78,6 +80,7 @@ require ( github.com/Antonboom/errname v1.0.0 // indirect github.com/Antonboom/nilnil v1.0.0 // indirect github.com/Antonboom/testifylint v1.5.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c // indirect github.com/Crocmagnon/fatcontext v0.5.3 // indirect diff --git a/go.sum b/go.sum index f2c681393..98502a45b 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,14 @@ github.com/Antonboom/nilnil v1.0.0/go.mod h1:fDJ1FSFoLN6yoG65ANb1WihItf6qt9PJVTn github.com/Antonboom/testifylint v1.5.2 h1:4s3Xhuv5AvdIgbd8wOOEeo0uZG7PbDKQyKY5lGoQazk= github.com/Antonboom/testifylint v1.5.2/go.mod h1:vxy8VJ0bc6NavlYqjZfmp6EfqXMtBgQ4+mhCojwC1P8= github.com/Azure/azure-pipeline-go v0.2.3/go.mod h1:x841ezTBIMG6O3lAcl8ATHnsOPVl2bqk7S3ta6S6u4k= +github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.7.1 h1:6njivKrpo02SQ3CsaGKIFh0c5ZhQyzjVhBmLIl84h4Q= +github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.7.1/go.mod h1:W+7E7pJtvdzscy/I4tqL5C0/weLsa32wyTbHbPdkkv0= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod 
h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= github.com/Azure/azure-storage-blob-go v0.14.0/go.mod h1:SMqIBi+SuiQH32bvyjngEewEeXoPfKMgWlBDaYf6fck= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= @@ -72,6 +80,8 @@ github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSY github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c h1:pxW6RcqyfI9/kWtOwnv/G+AzdKuy2ZrqINhenH4HyNs= github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= @@ -396,6 +406,8 @@ github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= 
+github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/pkg/plugin/processor/builtin/impl/openaiembedding/config_paramgen.go b/pkg/plugin/processor/builtin/impl/openaiembedding/config_paramgen.go new file mode 100644 index 000000000..0a80158e9 --- /dev/null +++ b/pkg/plugin/processor/builtin/impl/openaiembedding/config_paramgen.go @@ -0,0 +1,71 @@ +// Code generated by paramgen. DO NOT EDIT. +// Source: github.com/ConduitIO/conduit-commons/tree/main/paramgen + +package openaiembedding + +import ( + "regexp" + + "github.com/conduitio/conduit-commons/config" +) + +const ( + procConfigApiKey = "apiKey" + procConfigEmbeddingEncodingFormat = "embeddingEncodingFormat" + procConfigEndpoint = "endpoint" + procConfigInputField = "inputField" + procConfigModel = "model" + procConfigOutputField = "outputField" +) + +func (procConfig) Parameters() map[string]config.Parameter { + return map[string]config.Parameter{ + procConfigApiKey: { + Default: "", + Description: "", + Type: config.ParameterTypeString, + Validations: []config.Validation{ + config.ValidationRequired{}, + }, + }, + procConfigEmbeddingEncodingFormat: { + Default: "float", + Description: "", + Type: config.ParameterTypeString, + Validations: []config.Validation{ + config.ValidationInclusion{List: []string{"float", "base64"}}, + }, + }, + procConfigEndpoint: { + Default: "https://api.openai.com/v1", + Description: "", + Type: config.ParameterTypeString, + Validations: []config.Validation{}, + }, + procConfigInputField: { + Default: ".Payload.After", + Description: "", + Type: config.ParameterTypeString, + Validations: 
[]config.Validation{ + config.ValidationRegex{Regex: regexp.MustCompile("^\\.(Payload|Key).*")}, + }, + }, + procConfigModel: { + Default: "", + Description: "", + Type: config.ParameterTypeString, + Validations: []config.Validation{ + config.ValidationRequired{}, + config.ValidationInclusion{List: []string{"text-embedding-3-small", "text-embedding-3-large"}}, + }, + }, + procConfigOutputField: { + Default: ".Payload.After.vectors", + Description: "", + Type: config.ParameterTypeString, + Validations: []config.Validation{ + config.ValidationRegex{Regex: regexp.MustCompile("^\\.(Payload|Key).*")}, + }, + }, + } +} diff --git a/pkg/plugin/processor/builtin/impl/openaiembedding/openai_embedding.go b/pkg/plugin/processor/builtin/impl/openaiembedding/openai_embedding.go new file mode 100644 index 000000000..e5f7de6a1 --- /dev/null +++ b/pkg/plugin/processor/builtin/impl/openaiembedding/openai_embedding.go @@ -0,0 +1,186 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+//go:generate paramgen -output=config_paramgen.go procConfig
+
+package openaiembedding
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
+	"github.com/conduitio/conduit-commons/config"
+	"github.com/conduitio/conduit-commons/opencdc"
+	sdk "github.com/conduitio/conduit-processor-sdk"
+	"github.com/conduitio/conduit/pkg/foundation/cerrors"
+	"github.com/conduitio/conduit/pkg/foundation/log"
+)
+
+// procConfig is the processor configuration; config_paramgen.go is generated from it.
+type procConfig struct {
+	APIKey                  string `json:"apiKey" validate:"required"`
+	Endpoint                string `json:"endpoint" default:"https://api.openai.com/v1"`
+	EmbeddingEncodingFormat string `json:"embeddingEncodingFormat" default:"float" validate:"inclusion=float|base64"`
+	Model                   string `json:"model" validate:"required,inclusion=text-embedding-3-small|text-embedding-3-large"`
+	InputField              string `json:"inputField" validate:"regex=^\\.(Payload|Key).*" default:".Payload.After"`
+	OutputField             string `json:"outputField" validate:"regex=^\\.(Payload|Key).*" default:".Payload.After.vectors"`
+}
+
+// processor requests OpenAI embeddings for a field of each record it processes.
+type processor struct {
+	sdk.UnimplementedProcessor
+
+	cfg                    procConfig
+	inputFieldRefResolver  sdk.ReferenceResolver
+	outputFieldRefResolver sdk.ReferenceResolver
+
+	client *azopenai.Client
+	logger log.CtxLogger
+}
+
+// NewProcessor returns an unconfigured embedding processor that logs through log.
+func NewProcessor(log log.CtxLogger) sdk.Processor {
+	return &processor{
+		logger: log.WithComponent("openai_embedding"),
+	}
+}
+
+// Specification describes the processor and the parameters it accepts.
+func (p *processor) Specification() (sdk.Specification, error) {
+	return sdk.Specification{
+		Name:        "openai.embedding",
+		Summary:     "Generate OpenAI embeddings.",
+		Description: "",
+		Version:     "v0.1.0",
+		Author:      "Meroxa, Inc.",
+		Parameters:  procConfig{}.Parameters(),
+	}, nil
+}
+
+// Configure parses c and builds both field resolvers; p is only modified if all of that succeeds.
+func (p *processor) Configure(ctx context.Context, c config.Config) error {
+	cfg := procConfig{}
+	err := sdk.ParseConfig(ctx, c, &cfg, procConfig{}.Parameters())
+	if err != nil {
+		return cerrors.Errorf("failed to parse configuration: %w", err)
+	}
+
+	inputResolver, err := sdk.NewReferenceResolver(cfg.InputField)
+	if err != nil {
+		return cerrors.Errorf(`failed to create a field resolver for %v parameter: %w`, cfg.InputField, err)
+	}
+	outputResolver, err := sdk.NewReferenceResolver(cfg.OutputField)
+	if err != nil {
+		return cerrors.Errorf(`failed to create a field resolver for %v parameter: %w`, cfg.OutputField, err)
+	}
+
+	p.cfg = cfg
+	p.inputFieldRefResolver = inputResolver
+	p.outputFieldRefResolver = outputResolver
+	return nil
+}
+
+// Open creates the API client for the configured endpoint, using the API key as credential.
+func (p *processor) Open(ctx context.Context) error {
+	keyCredential := azcore.NewKeyCredential(p.cfg.APIKey)
+	// NOTE: this constructor creates a client that connects to the public OpenAI endpoint.
+	// To connect to an Azure OpenAI endpoint, use azopenai.NewClient() or azopenai.NewClientWithKeyCredential.
+	client, err := azopenai.NewClientForOpenAI(p.cfg.Endpoint, keyCredential, nil)
+	if err != nil {
+		return cerrors.Errorf("failed to create OpenAI client: %w", err)
+	}
+
+	p.client = client
+	return nil
+}
+
+// Process fetches the embeddings of all records with a single request and stores
+// each one in the configured output field. The returned slice always has one entry
+// per input record, at the same index, including for records that failed.
+func (p *processor) Process(ctx context.Context, records []opencdc.Record) []sdk.ProcessedRecord {
+	out := make([]sdk.ProcessedRecord, len(records))
+
+	// Prepare request (embedding inputs). indexes[n] is the position in records of the
+	// record behind inputs[n], which keeps results aligned when some records fail here.
+	inputs := make([]string, 0, len(records))
+	indexes := make([]int, 0, len(records))
+	for i, record := range records {
+		inRef, err := p.inputFieldRefResolver.Resolve(&record)
+		if err != nil {
+			out[i] = sdk.ErrorRecord{Error: cerrors.Errorf("failed to resolve reference %v: %w", p.cfg.InputField, err)}
+			continue
+		}
+		inputs = append(inputs, p.getEmbeddingInput(inRef.Get()))
+		indexes = append(indexes, i)
+	}
+	if len(inputs) == 0 {
+		return out // nothing to embed, so don't send an empty request
+	}
+
+	// Execute request (get embeddings)
+	embeddings, err := p.client.GetEmbeddings(ctx, azopenai.EmbeddingsOptions{
+		Input:          inputs,
+		EncodingFormat: (*azopenai.EmbeddingEncodingFormat)(&p.cfg.EmbeddingEncodingFormat),
+		DeploymentName: &p.cfg.Model,
+	}, nil)
+	if err == nil && len(embeddings.Data) != len(inputs) {
+		err = cerrors.Errorf("expected %d embeddings, got %d", len(inputs), len(embeddings.Data))
+	}
+	// If the request failed, declare processing as failed for all records that were part of it
+	if err != nil {
+		for _, i := range indexes {
+			out[i] = sdk.ErrorRecord{Error: cerrors.Errorf("failed to get embeddings: %w", err)}
+		}
+		return out
+	}
+
+	p.logger.Trace(ctx).Any("embedding_input", inputs).Any("embedding_output", embeddings).Msg("got embeddings")
+
+	for n, i := range indexes {
+		record := records[i] // copy, as when ranging by value: the slice element is not replaced
+		outRef, err := p.outputFieldRefResolver.Resolve(&record)
+		if err != nil {
+			out[i] = sdk.ErrorRecord{Error: cerrors.Errorf("failed to resolve reference %v: %w", p.cfg.OutputField, err)}
+			continue
+		}
+		// todo if the encoding format is base64, this needs to change
+		err = outRef.Set(opencdc.StructuredData{"embeddings": embeddings.Data[n].Embedding})
+		if err != nil {
+			out[i] = sdk.ErrorRecord{Error: cerrors.Errorf("failed to set embeddings to %v: %w", p.cfg.OutputField, err)}
+			continue
+		}
+
+		// todo add metadata related to the embeddings
+		out[i] = sdk.SingleRecord(record)
+	}
+
+	return out
+}
+
+// Teardown is a no-op and always returns nil.
+func (p *processor) Teardown(ctx context.Context) error { return nil }
+
+// getEmbeddingInput converts a field value to text: raw data as is, structured data via Bytes(), else %v.
+func (p *processor) getEmbeddingInput(val any) string {
+	switch v := val.(type) {
+	case opencdc.RawData:
+		return string(v)
+	case opencdc.StructuredData:
+		return string(v.Bytes())
+	default:
+		return fmt.Sprintf("%v", v)
+	}
+}
diff --git a/pkg/plugin/processor/builtin/registry.go b/pkg/plugin/processor/builtin/registry.go
index b6e6a18f5..914596885 100644
--- a/pkg/plugin/processor/builtin/registry.go
+++ b/pkg/plugin/processor/builtin/registry.go
@@ -31,6 +31,7 @@ import (
 	"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/custom"
 	"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/field"
 	"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/json"
+	"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/openaiembedding"
 	"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/unwrap"
 	"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/webhook"
 	"github.com/conduitio/conduit/pkg/plugin/processor/procutils"
@@ -55,6 +56,7 @@ var DefaultBuiltinProcessors = map[string]ProcessorPluginConstructor{
 	"unwrap.kafkaconnect": unwrap.NewKafkaConnectProcessor,
 	"unwrap.opencdc":      unwrap.NewOpenCDCProcessor,
 	"webhook.http":        webhook.NewHTTPProcessor,
+	"openai.embedding":    openaiembedding.NewProcessor,
 }
 
 type schemaRegistryProcessor interface {