Skip to content

Commit

Permalink
[Go] replace Embedder interface with EmbedderAction
Browse files Browse the repository at this point in the history
Remove the Embedder interface. Replace it with a type
alias to an Action with the appropriate type parameters.
  • Loading branch information
jba committed Jun 7, 2024
1 parent e9f5a97 commit 17d32ce
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 44 deletions.
31 changes: 13 additions & 18 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@ import (
"github.com/firebase/genkit/go/core"
)

// Embedder is the interface used to convert a document to a
// multidimensional vector. A [DocumentStore] will use a value of this type.
type Embedder interface {
Embed(context.Context, *EmbedRequest) ([]float32, error)
}
// EmbedderAction is used to convert a document to a
// multidimensional vector.
type EmbedderAction = core.Action[*EmbedRequest, []float32, struct{}]

// EmbedRequest is the data we pass to convert a document
// to a multidimensional vector.
Expand All @@ -34,25 +32,22 @@ type EmbedRequest struct {
}

// DefineEmbedder registers the given embed function as an action, and returns an
// [Embedder] whose Embed method runs it.
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) Embedder {
return embedder{core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed)}
// [EmbedderAction] that runs it.
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *EmbedderAction {
return core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed)
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
// It returns nil if the Embedder was not defined.
func LookupEmbedder(provider, name string) Embedder {
// LookupEmbedder looks up an [EmbedderAction] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(provider, name string) *EmbedderAction {
action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](core.ActionTypeEmbedder, provider, name)
if action == nil {
return nil
}
return embedder{action}
}

type embedder struct {
embedAction *core.Action[*EmbedRequest, []float32, struct{}]
return action
}

func (e embedder) Embed(ctx context.Context, req *EmbedRequest) ([]float32, error) {
return e.embedAction.Run(ctx, req, nil)
// Embed runs the given [EmbedderAction].
func Embed(ctx context.Context, emb *EmbedderAction, req *EmbedRequest) ([]float32, error) {
return emb.Run(ctx, req, nil)
}
1 change: 0 additions & 1 deletion go/internal/fakeembedder/fakeembedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ func (e *Embedder) Register(d *ai.Document, vals []float32) {
e.registry[d] = vals
}

// Embed implements genkit.Embedder.
func (e *Embedder) Embed(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
vals, ok := e.registry[req.Document]
if !ok {
Expand Down
8 changes: 3 additions & 5 deletions go/internal/fakeembedder/fakeembedder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,17 @@ import (

func TestFakeEmbedder(t *testing.T) {
embed := New()

embedAction := ai.DefineEmbedder("fake", "embed", embed.Embed)
d := ai.DocumentFromText("fakeembedder test", nil)

vals := []float32{1, 2}
embed.Register(d, vals)

var genkitEmbedder ai.Embedder
genkitEmbedder = embed
req := &ai.EmbedRequest{
Document: d,
}
ctx := context.Background()
got, err := genkitEmbedder.Embed(ctx, req)
got, err := ai.Embed(ctx, embedAction, req)
if err != nil {
t.Fatal(err)
}
Expand All @@ -45,7 +43,7 @@ func TestFakeEmbedder(t *testing.T) {
}

req.Document = ai.DocumentFromText("missing document", nil)
if _, err = genkitEmbedder.Embed(ctx, req); err == nil {
if _, err = ai.Embed(ctx, embedAction, req); err == nil {
t.Error("embedding unknown document succeeded unexpectedly")
}
}
4 changes: 2 additions & 2 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ func Model(name string) *ai.ModelAction {
return ai.LookupModel(provider, name)
}

// Embedder returns the embedder with the given name.
// Embedder returns the [ai.EmbedderAction] with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) ai.Embedder {
func Embedder(name string) *ai.EmbedderAction {
return ai.LookupEmbedder(provider, name)
}

Expand Down
2 changes: 1 addition & 1 deletion go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestLive(t *testing.T) {
},
)
t.Run("embedder", func(t *testing.T) {
out, err := googleai.Embedder(embeddingModel).Embed(ctx, &ai.EmbedRequest{
out, err := ai.Embed(ctx, googleai.Embedder(embeddingModel), &ai.EmbedRequest{
Document: ai.DocumentFromText("yellow banana", nil),
})
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import (
// retriever with genkit, and also return it.
// This retriever may only be used by a single goroutine at a time.
// This is based on js/plugins/dev-local-vectorstore/src/index.ts.
func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.DocumentStore, error) {
func New(ctx context.Context, dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (ai.DocumentStore, error) {
r, err := newDocStore(ctx, dir, name, embedder, embedderOptions)
if err != nil {
return nil, err
Expand All @@ -50,7 +50,7 @@ func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOp
// for a local vector database.
type docStore struct {
filename string
embedder ai.Embedder
embedder *ai.EmbedderAction
embedderOptions any
data map[string]dbValue
}
Expand All @@ -62,7 +62,7 @@ type dbValue struct {
}

// newDocStore returns a new ai.DocumentStore to register.
func newDocStore(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) (ai.DocumentStore, error) {
func newDocStore(ctx context.Context, dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (ai.DocumentStore, error) {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, err
}
Expand Down Expand Up @@ -98,7 +98,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
Document: doc,
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
vals, err := ai.Embed(ctx, ds.embedder, ereq)
if err != nil {
return fmt.Errorf("localvec index embedding failed: %v", err)
}
Expand Down Expand Up @@ -160,7 +160,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
Document: req.Document,
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
vals, err := ai.Embed(ctx, ds.embedder, ereq)
if err != nil {
return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err)
}
Expand Down
9 changes: 5 additions & 4 deletions go/plugins/localvec/localvec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func TestLocalVec(t *testing.T) {
embedder.Register(d1, v1)
embedder.Register(d2, v2)
embedder.Register(d3, v3)

ds, err := newDocStore(ctx, t.TempDir(), "testLocalVec", embedder, nil)
embedAction := ai.DefineEmbedder("fake", "embedder1", embedder.Embed)
ds, err := newDocStore(ctx, t.TempDir(), "testLocalVec", embedAction, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -110,10 +110,11 @@ func TestPersistentIndexing(t *testing.T) {
embedder.Register(d1, v1)
embedder.Register(d2, v2)
embedder.Register(d3, v3)
embedAction := ai.DefineEmbedder("fake", "embedder2", embedder.Embed)

tDir := t.TempDir()

ds, err := newDocStore(ctx, tDir, "testLocalVec", embedder, nil)
ds, err := newDocStore(ctx, tDir, "testLocalVec", embedAction, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -144,7 +145,7 @@ func TestPersistentIndexing(t *testing.T) {
t.Errorf("got %d results, expected 2", len(docs))
}

dsAnother, err := newDocStore(ctx, tDir, "testLocalVec", embedder, nil)
dsAnother, err := newDocStore(ctx, tDir, "testLocalVec", embedAction, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
8 changes: 4 additions & 4 deletions go/plugins/pinecone/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ const defaultTextKey = "_content"
//
// The textKey parameter is the metadata key to use to store document text
// in Pinecone; the default is "_content".
func New(ctx context.Context, apiKey, host string, embedder ai.Embedder, embedderOptions any, textKey string) (ai.DocumentStore, error) {
func New(ctx context.Context, apiKey, host string, embedder *ai.EmbedderAction, embedderOptions any, textKey string) (ai.DocumentStore, error) {
client, err := NewClient(ctx, apiKey)
if err != nil {
return nil, err
Expand Down Expand Up @@ -90,7 +90,7 @@ type RetrieverOptions struct {
// docStore implements the genkit [ai.DocumentStore] interface.
type docStore struct {
index *Index
embedder ai.Embedder
embedder *ai.EmbedderAction
embedderOptions any
textKey string
}
Expand Down Expand Up @@ -120,7 +120,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
Document: doc,
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
vals, err := ai.Embed(ctx, ds.embedder, ereq)
if err != nil {
return fmt.Errorf("pinecone index embedding failed: %v", err)
}
Expand Down Expand Up @@ -215,7 +215,7 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
Document: req.Document,
Options: ds.embedderOptions,
}
vals, err := ds.embedder.Embed(ctx, ereq)
vals, err := ai.Embed(ctx, ds.embedder, ereq)
if err != nil {
return nil, fmt.Errorf("pinecone retrieve embedding failed: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion go/plugins/pinecone/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ func TestGenkit(t *testing.T) {
embedder.Register(d1, v1)
embedder.Register(d2, v2)
embedder.Register(d3, v3)
embedAction := ai.DefineEmbedder("fake", "embedder3", embedder.Embed)

r, err := New(ctx, *testAPIKey, indexData.Host, embedder, nil, "")
r, err := New(ctx, *testAPIKey, indexData.Host, embedAction, nil, "")
if err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ func Model(name string) *ai.ModelAction {
return ai.LookupModel(provider, name)
}

// Embedder returns the embedder with the given name.
// Embedder returns the [ai.EmbedderAction] with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) ai.Embedder {
func Embedder(name string) *ai.EmbedderAction {
return ai.LookupEmbedder(provider, name)
}

Expand Down
2 changes: 1 addition & 1 deletion go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func TestLive(t *testing.T) {
}
})
t.Run("embedder", func(t *testing.T) {
out, err := vertexai.Embedder(embedderName).Embed(ctx, &ai.EmbedRequest{
out, err := ai.Embed(ctx, vertexai.Embedder(embedderName), &ai.EmbedRequest{
Document: ai.DocumentFromText("time flies like an arrow", nil),
})
if err != nil {
Expand Down

0 comments on commit 17d32ce

Please sign in to comment.