Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] retrieve and index calls are always actions #291

Merged
merged 1 commit into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,32 @@ type RetrieverResponse struct {
Documents []*Document `json:"documents"`
}

// RegisterRetriever registers the actions for a specific retriever.
func RegisterRetriever(name string, retriever Retriever) {
core.RegisterAction(name,
core.NewAction(name, core.ActionTypeRetriever, nil, retriever.Retrieve))

core.RegisterAction(name,
core.NewAction(name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
err := retriever.Index(ctx, req)
return struct{}{}, err
}))
// DefineRetriever takes index and retrieve functions that access a vector DB
// and returns a new Retriever that wraps them in registered actions.
func DefineRetriever(
name string,
index func(context.Context, *IndexerRequest) error,
retrieve func(context.Context, *RetrieverRequest) (*RetrieverResponse, error),
) Retriever {
ia := core.NewAction(name, core.ActionTypeIndexer, nil, func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
return struct{}{}, index(ctx, req)
})
core.RegisterAction(name, ia)
Comment on lines +59 to +62
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: add core.DefineAction function that does this -- NewAction + RegisterAction.

Ex:

export function defineAction<

ra := core.NewAction(name, core.ActionTypeRetriever, nil, retrieve)
core.RegisterAction(name, ra)
return &retriever{ia, ra}
}

type retriever struct {
index *core.Action[*IndexerRequest, struct{}, struct{}]
retrieve *core.Action[*RetrieverRequest, *RetrieverResponse, struct{}]
}

func (r *retriever) Index(ctx context.Context, req *IndexerRequest) error {
_, err := r.index.Run(ctx, req, nil)
return err
}

func (r *retriever) Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
return r.retrieve.Run(ctx, req, nil)
}
12 changes: 1 addition & 11 deletions go/plugins/localvec/localvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,6 @@ import (
"github.com/firebase/genkit/go/core/logger"
)

// Init registers all the actions in this package with [ai]'s Register calls.
func Init(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOptions any) error {
r, err := New(ctx, dir, name, embedder, embedderOptions)
if err != nil {
return err
}
ai.RegisterRetriever("devLocalVectorStore/"+name, r)
return nil
}

// New returns a new local vector database. This will register a new
// retriever with genkit, and also return it.
// This retriever may only be used by a single goroutine at a time.
Expand All @@ -53,7 +43,7 @@ func New(ctx context.Context, dir, name string, embedder ai.Embedder, embedderOp
if err != nil {
return nil, err
}
return r, nil
return ai.DefineRetriever("devLocalVectorStore/"+name, r.Index, r.Retrieve), nil
}

// retriever implements the [ai.Retriever] interface
Expand Down
16 changes: 3 additions & 13 deletions go/plugins/pinecone/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,6 @@ import (
// documents in pinecone.
const defaultTextKey = "_content"

// Init registers all the actions in this package with [ai]'s Register calls.
//
// The arguments are as for [New].
func Init(ctx context.Context, apiKey, host string, embedder ai.Embedder, embedderOptions any, textKey string) error {
r, err := New(ctx, apiKey, host, embedder, embedderOptions, textKey)
if err != nil {
return err
}
ai.RegisterRetriever("pinecone", r)
return nil
}

// New returns an [ai.Retriever] that uses Pinecone.
//
// apiKey is the API key to use to access Pinecone.
Expand Down Expand Up @@ -80,7 +68,9 @@ func New(ctx context.Context, apiKey, host string, embedder ai.Embedder, embedde
embedderOptions: embedderOptions,
textKey: textKey,
}
return r, nil

// index := func(ctx context.Context, req *ai.IndexerRequest) error {
return ai.DefineRetriever("pinecone", r.Index, r.Retrieve), nil
}

// IndexerOptions may be passed in the Options field
Expand Down
Loading