diff --git a/go/ai/embedder.go b/go/ai/embedder.go index 041f56e76..bec1ff492 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -20,11 +20,9 @@ import ( "github.com/firebase/genkit/go/core" ) -// Embedder is the interface used to convert a document to a -// multidimensional vector. A [DocumentStore] will use a value of this type. -type Embedder interface { - Embed(context.Context, *EmbedRequest) ([]float32, error) -} +// EmbedderAction is used to convert a document to a +// multidimensional vector. +type EmbedderAction = core.Action[*EmbedRequest, []float32, struct{}] // EmbedRequest is the data we pass to convert a document // to a multidimensional vector. @@ -34,25 +32,22 @@ type EmbedRequest struct { } // DefineEmbedder registers the given embed function as an action, and returns an -// [Embedder] whose Embed method runs it. -func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) Embedder { - return embedder{core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed)} +// [EmbedderAction] that runs it. +func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *EmbedderAction { + return core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed) } -// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder]. -// It returns nil if the Embedder was not defined. -func LookupEmbedder(provider, name string) Embedder { +// LookupEmbedder looks up an [EmbedderAction] registered by [DefineEmbedder]. +// It returns nil if the embedder was not defined. +func LookupEmbedder(provider, name string) *EmbedderAction { action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](core.ActionTypeEmbedder, provider, name) if action == nil { return nil } - return embedder{action} -} - -type embedder struct { - embedAction *core.Action[*EmbedRequest, []float32, struct{}] + return action } -func (e embedder) Embed(ctx context.Context, req *EmbedRequest) ([]float32, error) { - return e.embedAction.Run(ctx, req, nil) +// Embed runs the given [EmbedderAction]. +func Embed(ctx context.Context, emb *EmbedderAction, req *EmbedRequest) ([]float32, error) { + return emb.Run(ctx, req, nil) } diff --git a/go/internal/fakeembedder/fakeembedder.go b/go/internal/fakeembedder/fakeembedder.go index 9b5b7da01..47c2e06b6 100644 --- a/go/internal/fakeembedder/fakeembedder.go +++ b/go/internal/fakeembedder/fakeembedder.go @@ -43,7 +43,6 @@ func (e *Embedder) Register(d *ai.Document, vals []float32) { e.registry[d] = vals } -// Embed implements genkit.Embedder. func (e *Embedder) Embed(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) { vals, ok := e.registry[req.Document] if !ok { diff --git a/go/internal/fakeembedder/fakeembedder_test.go b/go/internal/fakeembedder/fakeembedder_test.go index ea50dfe9e..85319054f 100644 --- a/go/internal/fakeembedder/fakeembedder_test.go +++ b/go/internal/fakeembedder/fakeembedder_test.go @@ -24,19 +24,17 @@ import ( func TestFakeEmbedder(t *testing.T) { embed := New() - + embedAction := ai.DefineEmbedder("fake", "embed", embed.Embed) d := ai.DocumentFromText("fakeembedder test", nil) vals := []float32{1, 2} embed.Register(d, vals) - var genkitEmbedder ai.Embedder - genkitEmbedder = embed req := &ai.EmbedRequest{ Document: d, } ctx := context.Background() - got, err := genkitEmbedder.Embed(ctx, req) + got, err := ai.Embed(ctx, embedAction, req) if err != nil { t.Fatal(err) } @@ -45,7 +43,7 @@ func TestFakeEmbedder(t *testing.T) { } req.Document = ai.DocumentFromText("missing document", nil) - if _, err = genkitEmbedder.Embed(ctx, req); err == nil { + if _, err = ai.Embed(ctx, embedAction, req); err == nil { t.Error("embedding unknown document succeeded unexpectedly") } } diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 01f06e470..098de5f15 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -121,9 +121,9 @@ func Model(name string) *ai.ModelAction { return ai.LookupModel(provider, name) } -// Embedder returns the embedder with the given name. +// Embedder returns the [ai.EmbedderAction] with the given name. // It returns nil if the embedder was not configured. -func Embedder(name string) ai.Embedder { +func Embedder(name string) *ai.EmbedderAction { return ai.LookupEmbedder(provider, name) } diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 83a677975..dc1656ca2 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -82,7 +82,7 @@ func TestLive(t *testing.T) { }, ) t.Run("embedder", func(t *testing.T) { - out, err := googleai.Embedder(embeddingModel).Embed(ctx, &ai.EmbedRequest{ + out, err := ai.Embed(ctx, googleai.Embedder(embeddingModel), &ai.EmbedRequest{ Document: ai.DocumentFromText("yellow banana", nil), }) if err != nil { diff --git a/go/plugins/localvec/localvec.go b/go/plugins/localvec/localvec.go index 428f31802..8213c1f3d 100644 --- a/go/plugins/localvec/localvec.go +++ b/go/plugins/localvec/localvec.go @@ -38,7 +38,7 @@ import ( // retriever with genkit, and also return it. // This retriever may only be used by a single goroutine at a time. // This is based on js/plugins/dev-local-vectorstore/src/index.ts. -func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.DocumentStore, error) { +func New(ctx context.Context, dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (ai.DocumentStore, error) { r, err := newDocStore(ctx, dir, name, embedder, embedderOptions) if err != nil { return nil, err @@ -50,7 +50,7 @@ func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOp // for a local vector database. type docStore struct { filename string - embedder ai.Embedder + embedder *ai.EmbedderAction embedderOptions any data map[string]dbValue } @@ -62,7 +62,7 @@ type dbValue struct { } // newDocStore returns a new ai.DocumentStore to register. -func newDocStore(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.DocumentStore, error) { +func newDocStore(ctx context.Context, dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (ai.DocumentStore, error) { if err := os.MkdirAll(dir, 0o755); err != nil { return nil, err } @@ -98,7 +98,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { Document: doc, Options: ds.embedderOptions, } - vals, err := ds.embedder.Embed(ctx, ereq) + vals, err := ai.Embed(ctx, ds.embedder, ereq) if err != nil { return fmt.Errorf("localvec index embedding failed: %v", err) } @@ -160,7 +160,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai Document: req.Document, Options: ds.embedderOptions, } - vals, err := ds.embedder.Embed(ctx, ereq) + vals, err := ai.Embed(ctx, ds.embedder, ereq) if err != nil { return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err) } diff --git a/go/plugins/localvec/localvec_test.go b/go/plugins/localvec/localvec_test.go index b48bcb648..8a81cc913 100644 --- a/go/plugins/localvec/localvec_test.go +++ b/go/plugins/localvec/localvec_test.go @@ -50,8 +50,8 @@ func TestLocalVec(t *testing.T) { embedder.Register(d1, v1) embedder.Register(d2, v2) embedder.Register(d3, v3) - - ds, err := newDocStore(ctx, t.TempDir(), "testLocalVec", embedder, nil) + embedAction := ai.DefineEmbedder("fake", "embedder1", embedder.Embed) + ds, err := newDocStore(ctx, t.TempDir(), "testLocalVec", embedAction, nil) if err != nil { t.Fatal(err) } @@ -110,10 +110,11 @@ func TestPersistentIndexing(t *testing.T) { embedder.Register(d1, v1) embedder.Register(d2, v2) embedder.Register(d3, v3) + embedAction := ai.DefineEmbedder("fake", "embedder2", embedder.Embed) tDir := t.TempDir() - ds, err := newDocStore(ctx, tDir, "testLocalVec", embedder, nil) + ds, err := newDocStore(ctx, tDir, "testLocalVec", embedAction, nil) if err != nil { t.Fatal(err) } @@ -144,7 +145,7 @@ func TestPersistentIndexing(t *testing.T) { t.Errorf("got %d results, expected 2", len(docs)) } - dsAnother, err := newDocStore(ctx, tDir, "testLocalVec", embedder, nil) + dsAnother, err := newDocStore(ctx, tDir, "testLocalVec", embedAction, nil) if err != nil { t.Fatal(err) } diff --git a/go/plugins/pinecone/genkit.go b/go/plugins/pinecone/genkit.go index 021d0270a..355fbb04d 100644 --- a/go/plugins/pinecone/genkit.go +++ b/go/plugins/pinecone/genkit.go @@ -50,7 +50,7 @@ const defaultTextKey = "_content" // // The textKey parameter is the metadata key to use to store document text // in Pinecone; the default is "_content". -func New(ctx context.Context, apiKey, host string, embedder ai.Embedder, embedderOptions any, textKey string) (ai.DocumentStore, error) { +func New(ctx context.Context, apiKey, host string, embedder *ai.EmbedderAction, embedderOptions any, textKey string) (ai.DocumentStore, error) { client, err := NewClient(ctx, apiKey) if err != nil { return nil, err @@ -90,7 +90,7 @@ type RetrieverOptions struct { // docStore implements the genkit [ai.DocumentStore] interface. type docStore struct { index *Index - embedder ai.Embedder + embedder *ai.EmbedderAction embedderOptions any textKey string } @@ -120,7 +120,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error { Document: doc, Options: ds.embedderOptions, } - vals, err := ds.embedder.Embed(ctx, ereq) + vals, err := ai.Embed(ctx, ds.embedder, ereq) if err != nil { return fmt.Errorf("pinecone index embedding failed: %v", err) } @@ -215,7 +215,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai Document: req.Document, Options: ds.embedderOptions, } - vals, err := ds.embedder.Embed(ctx, ereq) + vals, err := ai.Embed(ctx, ds.embedder, ereq) if err != nil { return nil, fmt.Errorf("pinecone retrieve embedding failed: %v", err) } diff --git a/go/plugins/pinecone/genkit_test.go b/go/plugins/pinecone/genkit_test.go index 73f01eb85..d2837918a 100644 --- a/go/plugins/pinecone/genkit_test.go +++ b/go/plugins/pinecone/genkit_test.go @@ -71,8 +71,9 @@ func TestGenkit(t *testing.T) { embedder.Register(d1, v1) embedder.Register(d2, v2) embedder.Register(d3, v3) + embedAction := ai.DefineEmbedder("fake", "embedder3", embedder.Embed) - r, err := New(ctx, *testAPIKey, indexData.Host, embedder, nil, "") + r, err := New(ctx, *testAPIKey, indexData.Host, embedAction, nil, "") if err != nil { t.Fatal(err) } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index f750e3e14..409cab13d 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -113,9 +113,9 @@ func Model(name string) *ai.ModelAction { return ai.LookupModel(provider, name) } -// Embedder returns the embedder with the given name. +// Embedder returns the [ai.EmbedderAction] with the given name. // It returns nil if the embedder was not configured. -func Embedder(name string) ai.Embedder { +func Embedder(name string) *ai.EmbedderAction { return ai.LookupEmbedder(provider, name) } diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index e5ae734ff..d1563c5aa 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -129,7 +129,7 @@ func TestLive(t *testing.T) { } }) t.Run("embedder", func(t *testing.T) { - out, err := vertexai.Embedder(embedderName).Embed(ctx, &ai.EmbedRequest{ + out, err := ai.Embed(ctx, vertexai.Embedder(embedderName), &ai.EmbedRequest{ Document: ai.DocumentFromText("time flies like an arrow", nil), }) if err != nil {