Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] replace ai.DocumentStore with indexer and retriever actions #353

Merged
merged 1 commit into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 34 additions & 27 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,12 @@ import (
"github.com/firebase/genkit/go/core"
)

// DocumentStore supports adding documents to a database, and
// retrieving documents from the database that are similar to a query document.
// Vector databases will implement this interface.
type DocumentStore interface {
// Add a document to the database.
Index(context.Context, *IndexerRequest) error
// Retrieve matching documents from the database.
Retrieve(context.Context, *RetrieverRequest) (*RetrieverResponse, error)
}
type (
// An IndexerAction is used to index documents in a store.
IndexerAction = core.Action[*IndexerRequest, struct{}, struct{}]
// A RetrieverAction is used to retrieve indexed documents.
RetrieverAction = core.Action[*RetrieverRequest, *RetrieverResponse, struct{}]
)

// IndexerRequest is the data we pass to add documents to the database.
// The Options field is specific to the actual retriever implementation.
Expand All @@ -49,30 +46,40 @@ type RetrieverResponse struct {
Documents []*Document `json:"documents"`
}

// DefineDocumentStore takes index and retrieve functions that access a document store
// and returns a new [DocumentStore] that wraps them in registered actions.
func DefineDocumentStore(
name string,
index func(context.Context, *IndexerRequest) error,
retrieve func(context.Context, *RetrieverRequest) (*RetrieverResponse, error),
) DocumentStore {
ia := core.DefineAction("indexer", name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
// DefineIndexer registers the given index function as an action, and returns an
// [IndexerAction] that runs it.
func DefineIndexer(provider, name string, index func(context.Context, *IndexerRequest) error) *IndexerAction {
f := func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
return struct{}{}, index(ctx, req)
})
ra := core.DefineAction("retriever", name, core.ActionTypeRetriever, nil, retrieve)
return &docStore{ia, ra}
}
return core.DefineAction(provider, name, core.ActionTypeIndexer, nil, f)
}

// LookupIndexer looks up a [IndexerAction] registered by [DefineIndexer].
// It returns nil if the model was not defined.
func LookupIndexer(provider, name string) *IndexerAction {
return core.LookupActionFor[*IndexerRequest, struct{}, struct{}](core.ActionTypeIndexer, provider, name)
}

// DefineRetriever registers the given retrieve function as an action, and returns a
// [RetrieverAction] that runs it.
func DefineRetriever(provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *RetrieverAction {
return core.DefineAction(provider, name, core.ActionTypeRetriever, nil, ret)
}

type docStore struct {
index *core.Action[*IndexerRequest, struct{}, struct{}]
retrieve *core.Action[*RetrieverRequest, *RetrieverResponse, struct{}]
// LookupRetriever looks up a [RetrieverAction] registered by [DefineRetriever].
// It returns nil if the model was not defined.
func LookupRetriever(provider, name string) *RetrieverAction {
return core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](core.ActionTypeRetriever, provider, name)
}

func (ds *docStore) Index(ctx context.Context, req *IndexerRequest) error {
_, err := ds.index.Run(ctx, req, nil)
// Index runs the given [IndexerAction].
func Index(ctx context.Context, indexer *IndexerAction, req *IndexerRequest) error {
_, err := indexer.Run(ctx, req, nil)
return err
}

func (ds *docStore) Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
return ds.retrieve.Run(ctx, req, nil)
// Retrieve runs the given [RetrieverAction].
func Retrieve(ctx context.Context, retriever *RetrieverAction, req *RetrieverRequest) (*RetrieverResponse, error) {
return retriever.Run(ctx, req, nil)
}
6 changes: 5 additions & 1 deletion go/core/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ func LookupAction(typ ActionType, provider, name string) action {
// or nil if there is none.
// It panics if the action is of the wrong type.
func LookupActionFor[In, Out, Stream any](typ ActionType, provider, name string) *Action[In, Out, Stream] {
return LookupAction(typ, provider, name).(*Action[In, Out, Stream])
a := LookupAction(typ, provider, name)
if a == nil {
return nil
}
return a.(*Action[In, Out, Stream])
}

// listActions returns a list of descriptions of all registered actions.
Expand Down
55 changes: 41 additions & 14 deletions go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,47 @@ import (
"github.com/firebase/genkit/go/core/logger"
)

// New returns a new local vector database. This will register a new
// retriever with genkit, and also return it.
const provider = "devLocalVectorStore"

// Config configures the plugin.
type Config struct {
// Where to store the data.
Dir string
Name string
Embedder *ai.EmbedderAction
EmbedderOptions any
}

// Init initializes a new local vector database. This will register a new
// indexer and retriever with genkit.
// This retriever may only be used by a single goroutine at a time.
// This is based on js/plugins/dev-local-vectorstore/src/index.ts.
func New(ctx context.Context, dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (ai.DocumentStore, error) {
r, err := newDocStore(ctx, dir, name, embedder, embedderOptions)
func Init(ctx context.Context, cfg Config) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("localvec.Init: %w", err)
}
}()
ds, err := newDocStore(cfg.Dir, cfg.Name, cfg.Embedder, cfg.EmbedderOptions)
if err != nil {
return nil, err
return err
}
return ai.DefineDocumentStore("devLocalVectorStore-"+name, r.Index, r.Retrieve), nil
ai.DefineIndexer(provider, cfg.Name, ds.index)
ai.DefineRetriever(provider, cfg.Name, ds.retrieve)
return nil
}

// Indexer returns the indexer with the given name.
func Indexer(name string) *ai.IndexerAction {
return ai.LookupIndexer(provider, name)
}

// docStore implements the [ai.DocumentStore] interface
// for a local vector database.
// Retriever returns the retriever with the given name.
func Retriever(name string) *ai.RetrieverAction {
return ai.LookupRetriever(provider, name)
}

// docStore implements a local vector database.
// This is based on js/plugins/dev-local-vectorstore/src/index.ts.
type docStore struct {
filename string
embedder *ai.EmbedderAction
Expand All @@ -62,7 +89,7 @@ type dbValue struct {
}

// newDocStore returns a new ai.DocumentStore to register.
func newDocStore(ctx context.Context, dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (ai.DocumentStore, error) {
func newDocStore(dir, name string, embedder *ai.EmbedderAction, embedderOptions any) (*docStore, error) {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, err
}
Expand Down Expand Up @@ -91,8 +118,8 @@ func newDocStore(ctx context.Context, dir, name string, embedder *ai.EmbedderAct
return ds, nil
}

// Index implements the genkit [ai.DocumentStore.Index] method.
func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
// index indexes a document.
func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error {
for _, doc := range req.Documents {
ereq := &ai.EmbedRequest{
Document: doc,
Expand Down Expand Up @@ -152,8 +179,8 @@ type RetrieverOptions struct {
K int `json:"k,omitempty"` // number of entries to return
}

// Retrieve implements the genkit [ai.DocumentStore.Retrieve] method.
func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// retrieve retrieves documents close to the argument.
func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// Use the embedder to convert the document we want to
// retrieve into a vector.
ereq := &ai.EmbedRequest{
Expand Down
18 changes: 9 additions & 9 deletions go/plugins/localvec/localvec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ func TestLocalVec(t *testing.T) {
embedder.Register(d2, v2)
embedder.Register(d3, v3)
embedAction := ai.DefineEmbedder("fake", "embedder1", embedder.Embed)
ds, err := newDocStore(ctx, t.TempDir(), "testLocalVec", embedAction, nil)
ds, err := newDocStore(t.TempDir(), "testLocalVec", embedAction, nil)
if err != nil {
t.Fatal(err)
}

indexerReq := &ai.IndexerRequest{
Documents: []*ai.Document{d1, d2, d3},
}
err = ds.Index(ctx, indexerReq)
err = ds.index(ctx, indexerReq)
if err != nil {
t.Fatalf("Index operation failed: %v", err)
}
Expand All @@ -72,7 +72,7 @@ func TestLocalVec(t *testing.T) {
Document: d1,
Options: retrieverOptions,
}
retrieverResp, err := ds.Retrieve(ctx, retrieverReq)
retrieverResp, err := ds.retrieve(ctx, retrieverReq)
if err != nil {
t.Fatalf("Retrieve operation failed: %v", err)
}
Expand Down Expand Up @@ -114,15 +114,15 @@ func TestPersistentIndexing(t *testing.T) {

tDir := t.TempDir()

ds, err := newDocStore(ctx, tDir, "testLocalVec", embedAction, nil)
ds, err := newDocStore(tDir, "testLocalVec", embedAction, nil)
if err != nil {
t.Fatal(err)
}

indexerReq := &ai.IndexerRequest{
Documents: []*ai.Document{d1, d2},
}
err = ds.Index(ctx, indexerReq)
err = ds.index(ctx, indexerReq)
if err != nil {
t.Fatalf("Index operation failed: %v", err)
}
Expand All @@ -135,7 +135,7 @@ func TestPersistentIndexing(t *testing.T) {
Document: d1,
Options: retrieverOptions,
}
retrieverResp, err := ds.Retrieve(ctx, retrieverReq)
retrieverResp, err := ds.retrieve(ctx, retrieverReq)
if err != nil {
t.Fatalf("Retrieve operation failed: %v", err)
}
Expand All @@ -145,15 +145,15 @@ func TestPersistentIndexing(t *testing.T) {
t.Errorf("got %d results, expected 2", len(docs))
}

dsAnother, err := newDocStore(ctx, tDir, "testLocalVec", embedAction, nil)
dsAnother, err := newDocStore(tDir, "testLocalVec", embedAction, nil)
if err != nil {
t.Fatal(err)
}

indexerReq = &ai.IndexerRequest{
Documents: []*ai.Document{d3},
}
err = dsAnother.Index(ctx, indexerReq)
err = dsAnother.index(ctx, indexerReq)
if err != nil {
t.Fatalf("Index operation failed: %v", err)
}
Expand All @@ -166,7 +166,7 @@ func TestPersistentIndexing(t *testing.T) {
Document: d1,
Options: retrieverOptions,
}
retrieverResp, err = dsAnother.Retrieve(ctx, retrieverReq)
retrieverResp, err = dsAnother.retrieve(ctx, retrieverReq)
if err != nil {
t.Fatalf("Retrieve operation failed: %v", err)
}
Expand Down
75 changes: 50 additions & 25 deletions go/plugins/pinecone/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"github.com/firebase/genkit/go/ai"
)

const provider = "pinecone"

// defaultTextKey is the default metadata key we use to store the
// document text as metadata of the documented stored in pinecone.
// This lets us map back from the document returned by a query to
Expand All @@ -36,40 +38,63 @@ import (
// documents in pinecone.
const defaultTextKey = "_content"

// New returns an [ai.DocumentStore] that uses Pinecone.
//
// apiKey is the API key to use to access Pinecone.
// If it is the empty string, it is read from the PINECONE_API_KEY
// environment variable.
//
// host is the controller host name. This implies which index to use.
// If it is the empty string, it is read from the PINECONE_CONTROLLER_HOST
// environment variable.
//
// The caller must pass an Embedder to use.
//
// The textKey parameter is the metadata key to use to store document text
// in Pinecone; the default is "_content".
func New(ctx context.Context, apiKey, host string, embedder *ai.EmbedderAction, embedderOptions any, textKey string) (ai.DocumentStore, error) {
client, err := NewClient(ctx, apiKey)
// Config configures the plugin.
type Config struct {
// API key for Pinecone.
// If it is the empty string, it is read from the PINECONE_API_KEY
// environment variable.
APIKey string
// The controller host name. This implies which index to use.
// If it is the empty string, it is read from the PINECONE_CONTROLLER_HOST
// environment variable.
Host string
// Embedder to use. Required.
Embedder *ai.EmbedderAction
EmbedderOptions any
// The metadata key to use to store document text
// in Pinecone; the default is "_content".
TextKey string
}

// Init initializes the Pinecone plugin.
func Init(ctx context.Context, cfg Config) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("pinecone.Init: %w", err)
}
}()

client, err := NewClient(ctx, cfg.APIKey)
if err != nil {
return nil, err
return err
}
index, err := client.Index(ctx, host)
index, err := client.Index(ctx, cfg.Host)
if err != nil {
return nil, err
return err
}
if textKey == "" {
textKey = defaultTextKey
if cfg.TextKey == "" {
cfg.TextKey = defaultTextKey
}
r := &docStore{
index: index,
embedder: embedder,
embedderOptions: embedderOptions,
textKey: textKey,
embedder: cfg.Embedder,
embedderOptions: cfg.EmbedderOptions,
textKey: cfg.TextKey,
}
name := index.host // the resolved host
ai.DefineIndexer(provider, name, r.Index)
ai.DefineRetriever(provider, name, r.Retrieve)
return nil
}

// Indexer returns the indexer with the given index name.
func Indexer(name string) *ai.IndexerAction {
return ai.LookupIndexer(provider, name)
}

return ai.DefineDocumentStore("pinecone", r.Index, r.Retrieve), nil
// Retriever returns the retriever with the given index name.
func Retriever(name string) *ai.RetrieverAction {
return ai.LookupRetriever(provider, name)
}

// IndexerOptions may be passed in the Options field
Expand Down
14 changes: 9 additions & 5 deletions go/plugins/pinecone/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,13 @@ func TestGenkit(t *testing.T) {
embedder.Register(d1, v1)
embedder.Register(d2, v2)
embedder.Register(d3, v3)
embedAction := ai.DefineEmbedder("fake", "embedder3", embedder.Embed)

r, err := New(ctx, *testAPIKey, indexData.Host, embedAction, nil, "")
if err != nil {
cfg := Config{
APIKey: *testAPIKey,
Host: indexData.Host,
Embedder: ai.DefineEmbedder("fake", "embedder3", embedder.Embed),
}
if err := Init(ctx, cfg); err != nil {
t.Fatal(err)
}

Expand All @@ -86,7 +89,8 @@ func TestGenkit(t *testing.T) {
Documents: []*ai.Document{d1, d2, d3},
Options: indexerOptions,
}
err = r.Index(ctx, indexerReq)
t.Logf("index flag = %q, indexData.Host = %q", *testIndex, indexData.Host)
err = ai.Index(ctx, Indexer(indexData.Host), indexerReq)
if err != nil {
t.Fatalf("Index operation failed: %v", err)
}
Expand Down Expand Up @@ -123,7 +127,7 @@ func TestGenkit(t *testing.T) {
Document: d1,
Options: retrieverOptions,
}
retrieverResp, err := r.Retrieve(ctx, retrieverReq)
retrieverResp, err := ai.Retrieve(ctx, Retriever(indexData.Host), retrieverReq)
if err != nil {
t.Fatalf("Retrieve operation failed: %v", err)
}
Expand Down
Loading
Loading