Skip to content

Commit

Permalink
OpenAI embeddings processor
Browse files Browse the repository at this point in the history
  • Loading branch information
hariso committed Dec 10, 2024
1 parent 0345e4b commit fea96e0
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 0 deletions.
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ go 1.23.2

require (
buf.build/gen/go/grpc-ecosystem/grpc-gateway/protocolbuffers/go v1.35.2-20240617172850-a48fcebcf8f1.1
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.7.1
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0
github.com/Masterminds/semver/v3 v3.3.1
github.com/Masterminds/sprig/v3 v3.3.0
github.com/NYTimes/gziphandler v1.1.1
Expand Down Expand Up @@ -78,6 +80,7 @@ require (
github.com/Antonboom/errname v1.0.0 // indirect
github.com/Antonboom/nilnil v1.0.0 // indirect
github.com/Antonboom/testifylint v1.5.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c // indirect
github.com/Crocmagnon/fatcontext v0.5.3 // indirect
Expand Down
12 changes: 12 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ github.com/Antonboom/nilnil v1.0.0/go.mod h1:fDJ1FSFoLN6yoG65ANb1WihItf6qt9PJVTn
github.com/Antonboom/testifylint v1.5.2 h1:4s3Xhuv5AvdIgbd8wOOEeo0uZG7PbDKQyKY5lGoQazk=
github.com/Antonboom/testifylint v1.5.2/go.mod h1:vxy8VJ0bc6NavlYqjZfmp6EfqXMtBgQ4+mhCojwC1P8=
github.com/Azure/azure-pipeline-go v0.2.3/go.mod h1:x841ezTBIMG6O3lAcl8ATHnsOPVl2bqk7S3ta6S6u4k=
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.7.1 h1:6njivKrpo02SQ3CsaGKIFh0c5ZhQyzjVhBmLIl84h4Q=
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.7.1/go.mod h1:W+7E7pJtvdzscy/I4tqL5C0/weLsa32wyTbHbPdkkv0=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY=
github.com/Azure/azure-storage-blob-go v0.14.0/go.mod h1:SMqIBi+SuiQH32bvyjngEewEeXoPfKMgWlBDaYf6fck=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
Expand All @@ -72,6 +80,8 @@ github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSY
github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k=
github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU=
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c h1:pxW6RcqyfI9/kWtOwnv/G+AzdKuy2ZrqINhenH4HyNs=
github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
Expand Down Expand Up @@ -396,6 +406,8 @@ github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E=
github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

186 changes: 186 additions & 0 deletions pkg/plugin/processor/builtin/impl/openaiembedding/openai_embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// Copyright © 2024 Meroxa, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:generate paramgen -output=config_paramgen.go procConfig

package openaiembedding

import (
"context"
"fmt"

"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/conduitio/conduit-commons/config"
"github.com/conduitio/conduit-commons/opencdc"
sdk "github.com/conduitio/conduit-processor-sdk"
"github.com/conduitio/conduit/pkg/foundation/cerrors"
"github.com/conduitio/conduit/pkg/foundation/log"
)

type procConfig struct {
APIKey string `json:"apiKey" validate:"required"`
Endpoint string `json:"endpoint" default:"https://api.openai.com/v1"`
EmbeddingEncodingFormat string `json:"embeddingEncodingFormat" default:"float" validate:"inclusion=float|base64"`
Model string `json:"model" validate:"required,inclusion=text-embedding-3-small|text-embedding-3-large"`
InputField string `json:"inputField" validate:"regex=^\\.(Payload|Key).*" default:".Payload.After"`
OutputField string `json:"outputField" validate:"regex=^\\.(Payload|Key).*" default:".Payload.After.vectors"`
}

type processor struct {
sdk.UnimplementedProcessor

cfg procConfig

inputFieldRefResolver sdk.ReferenceResolver
outputFieldRefResolver sdk.ReferenceResolver

client *azopenai.Client
logger log.CtxLogger
}

func NewProcessor(log log.CtxLogger) sdk.Processor {
return &processor{
logger: log.WithComponent("openai_embedding"),
}
}

func (p *processor) Specification() (sdk.Specification, error) {
return sdk.Specification{
Name: "openai.embedding",
Summary: "Generate OpenAI embeddings.",
Description: "",
Version: "v0.1.0",
Author: "Meroxa, Inc.",
Parameters: procConfig{}.Parameters(),
}, nil
}

func (p *processor) Configure(ctx context.Context, c config.Config) error {
cfg := procConfig{}
err := sdk.ParseConfig(ctx, c, &cfg, procConfig{}.Parameters())
if err != nil {
return cerrors.Errorf("failed to parse configuration: %w", err)
}

inputResolver, err := sdk.NewReferenceResolver(cfg.InputField)
if err != nil {
return cerrors.Errorf(`failed to create a field resolver for %v parameter: %w`, cfg.InputField, err)
}
p.inputFieldRefResolver = inputResolver

outputResolver, err := sdk.NewReferenceResolver(cfg.OutputField)
if err != nil {
return cerrors.Errorf(`failed to create a field resolver for %v parameter: %w`, cfg.OutputField, err)
}
p.outputFieldRefResolver = outputResolver

p.cfg = cfg
return nil
}

func (p *processor) Open(ctx context.Context) error {
keyCredential := azcore.NewKeyCredential(p.cfg.APIKey)

// NOTE: this constructor creates a client that connects to the public OpenAI endpoint.
// To connect to an Azure OpenAI endpoint, use azopenai.NewClient() or azopenai.NewClientWithyKeyCredential.
client, err := azopenai.NewClientForOpenAI("https://api.openai.com/v1", keyCredential, nil)
if err != nil {
return cerrors.Errorf("failed to create OpenAI client: %w", err)
}

p.client = client
return nil
}

func (p *processor) Process(ctx context.Context, records []opencdc.Record) []sdk.ProcessedRecord {
out := make([]sdk.ProcessedRecord, 0, len(records))

// Prepare request (embedding inputs)
var embeddingInputs []string
for _, record := range records {
inRef, err := p.inputFieldRefResolver.Resolve(&record)
if err != nil {
out = append(out, sdk.ErrorRecord{Error: fmt.Errorf("failed to resolve reference %v: %w", p.cfg.InputField, err)})
continue
}

embeddingInputs = append(embeddingInputs, p.getEmbeddingInput(inRef.Get()))
}

// Execute request (get embeddings)
embeddings, err := p.client.GetEmbeddings(
ctx,
azopenai.EmbeddingsOptions{
Input: embeddingInputs,
Dimensions: nil,
EncodingFormat: (*azopenai.EmbeddingEncodingFormat)(&p.cfg.EmbeddingEncodingFormat),
DeploymentName: &p.cfg.Model,
InputType: nil,
User: nil,
},
nil,
)
// If the request failed, declare processing for all records as failed
if err != nil {
for range len(records) {
out = append(out, sdk.ErrorRecord{Error: err})
}

return out
}

p.logger.Trace(ctx).
Any("embedding_input", embeddingInputs).
Any("embedding_output", embeddings).
Msg("got embeddings")

for i, record := range records {
outRef, err := p.outputFieldRefResolver.Resolve(&record)
if err != nil {
out = append(out, sdk.ErrorRecord{Error: cerrors.Errorf("failed to resolve reference %v: %w", p.cfg.OutputField, err)})
continue
}

embeddingsMap := opencdc.StructuredData{
// todo if the encoding format is base64, this needs to change
"embeddings": embeddings.Data[i].Embedding,
}
err = outRef.Set(embeddingsMap)
if err != nil {
out = append(out, sdk.ErrorRecord{Error: cerrors.Errorf("failed to set embeddings to %v: %w", p.cfg.OutputField, err)})
continue
}

// todo add metadata related to the embeddings
out = append(out, sdk.SingleRecord(record))
}

return out
}

func (p *processor) Teardown(ctx context.Context) error {
return nil
}

func (p *processor) getEmbeddingInput(val any) string {
switch v := val.(type) {
case opencdc.RawData:
return string(v)
case opencdc.StructuredData:
return string(v.Bytes())
default:
return fmt.Sprintf("%v", v)
}
}
2 changes: 2 additions & 0 deletions pkg/plugin/processor/builtin/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/custom"
"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/field"
"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/json"
"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/openaiembedding"
"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/unwrap"
"github.com/conduitio/conduit/pkg/plugin/processor/builtin/impl/webhook"
"github.com/conduitio/conduit/pkg/plugin/processor/procutils"
Expand All @@ -55,6 +56,7 @@ var DefaultBuiltinProcessors = map[string]ProcessorPluginConstructor{
"unwrap.kafkaconnect": unwrap.NewKafkaConnectProcessor,
"unwrap.opencdc": unwrap.NewOpenCDCProcessor,
"webhook.http": webhook.NewHTTPProcessor,
"openai.embedding": openaiembedding.NewProcessor,
}

type schemaRegistryProcessor interface {
Expand Down

0 comments on commit fea96e0

Please sign in to comment.