Skip to content

Commit

Permalink
[Go] googleai: revised plugin design (#326)
Browse files Browse the repository at this point in the history
This illustrates the new plugin design:

- Plugins are initialized with an Init function that holds
  configuration.

- All models and other entities are defined in Init.

- Plugins have no state; all state is in the registry.

- Configured values can be retrieved by top-level functions that
  take strings as arguments.
  • Loading branch information
jba authored Jun 7, 2024
1 parent 2cdb4c8 commit 927730d
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 220 deletions.
10 changes: 10 additions & 0 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedReq
return embedder{core.DefineAction(provider, name, core.ActionTypeEmbedder, nil, embed)}
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
// It returns nil if the Embedder was not defined.
func LookupEmbedder(provider, name string) Embedder {
action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](core.ActionTypeEmbedder, provider, name)
if action == nil {
return nil
}
return embedder{action}
}

type embedder struct {
embedAction *core.Action[*EmbedRequest, []float32, struct{}]
}
Expand Down
27 changes: 10 additions & 17 deletions go/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ func DefineGenerator(provider, name string, metadata *GeneratorMetadata, generat
return generator{a}
}

// LookupGenerator looks up a [Generator] registered by [DefineGenerator].
// It returns nil if the Generator was not defined.
func LookupGenerator(provider, name string) Generator {
action := core.LookupActionFor[*GenerateRequest, *GenerateResponse, *Candidate](core.ActionTypeModel, provider, name)
if action == nil {
return nil
}
return generator{action}
}

type generator struct {
generateAction *core.Action[*GenerateRequest, *GenerateResponse, *Candidate]
}
Expand Down Expand Up @@ -114,23 +124,6 @@ func Generate(ctx context.Context, g Generator, req *GenerateRequest, cb Generat
}
}

// generatorActionType is the instantiated core.Action type registered
// by RegisterGenerator.
type generatorActionType = core.Action[*GenerateRequest, *GenerateResponse, *Candidate]

// LookupGenerator looks up a [Generator] registered by [DefineGenerator].
func LookupGenerator(provider, name string) (Generator, error) {
action := core.LookupAction(core.ActionTypeModel, provider, name)
if action == nil {
return nil, fmt.Errorf("LookupGenerator: no generator action named %q/%q", provider, name)
}
actionInst, ok := action.(*generatorActionType)
if !ok {
return nil, fmt.Errorf("LookupGenerator: generator action %q has type %T, want %T", name, action, &generatorActionType{})
}
return generator{actionInst}, nil
}

// conformOutput appends a message to the request indicating conformance to the expected schema.
func conformOutput(req *GenerateRequest) error {
if req.Output != nil && req.Output.Format == OutputFormatJSON && len(req.Messages) > 0 {
Expand Down
7 changes: 7 additions & 0 deletions go/core/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ func LookupAction(typ ActionType, provider, name string) action {
return globalRegistry.lookupAction(key)
}

// LookupActionFor returns the action for the given key in the global registry,
// or nil if there is none.
// It panics if the action is of the wrong type.
func LookupActionFor[In, Out, Stream any](typ ActionType, provider, name string) *Action[In, Out, Stream] {
return LookupAction(typ, provider, name).(*Action[In, Out, Stream])
}

// listActions returns a list of descriptions of all registered actions.
// The list is sorted by action name.
func (r *registry) listActions() []actionDesc {
Expand Down
7 changes: 4 additions & 3 deletions go/plugins/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package dotprompt
import (
"context"
"errors"
"fmt"
"reflect"
"strings"

Expand Down Expand Up @@ -156,9 +157,9 @@ func (p *Prompt) Generate(ctx context.Context, pr *ai.PromptRequest, cb func(con
return nil, errors.New("dotprompt model not in provider/name format")
}

generator, err = ai.LookupGenerator(provider, name)
if err != nil {
return nil, err
generator := ai.LookupGenerator(provider, name)
if generator == nil {
return nil, fmt.Errorf("no generator named %q for provider %q", name, provider)
}
}

Expand Down
43 changes: 0 additions & 43 deletions go/plugins/googleai/embed.go

This file was deleted.

119 changes: 100 additions & 19 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ package googleai

import (
"context"
"errors"
"fmt"
"path"
"slices"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/plugins/internal/uri"
Expand All @@ -25,8 +28,103 @@ import (
"google.golang.org/api/option"
)

func newClient(ctx context.Context, apiKey string) (*genai.Client, error) {
return genai.NewClient(ctx, option.WithAPIKey(apiKey))
const provider = "google-genai"

// Config configures the plugin.
type Config struct {
// API key. Required.
APIKey string
// Generative models to provide.
// If empty, a complete list will be obtained from the service.
Models []string
// Embedding models to provide.
// If empty, a complete list will be obtained from the service.
Embedders []string
}

func Init(ctx context.Context, cfg Config) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("googleai.Init: %w", err)
}
}()

if cfg.APIKey == "" {
return errors.New("missing API key")
}

client, err := genai.NewClient(ctx, option.WithAPIKey(cfg.APIKey))
if err != nil {
return err
}

needModels := len(cfg.Models) == 0
needEmbedders := len(cfg.Embedders) == 0
if needModels || needEmbedders {
iter := client.ListModels(ctx)
for {
mi, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
// Model names are of the form "models/name".
name := path.Base(mi.Name)
if needModels && slices.Contains(mi.SupportedGenerationMethods, "generateContent") {
cfg.Models = append(cfg.Models, name)
}
if needEmbedders && slices.Contains(mi.SupportedGenerationMethods, "embedContent") {
cfg.Embedders = append(cfg.Embedders, name)
}
}
}
for _, name := range cfg.Models {
defineModel(name, client)
}
for _, name := range cfg.Embedders {
defineEmbedder(name, client)
}
return nil
}

func defineModel(name string, client *genai.Client) {
meta := &ai.GeneratorMetadata{
Label: "Google AI - " + name,
Supports: ai.GeneratorCapabilities{
Multiturn: true,
},
}
g := generator{model: name, client: client}
ai.DefineGenerator(provider, name, meta, g.Generate)
}

func defineEmbedder(name string, client *genai.Client) {
ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) {
em := client.EmbeddingModel(name)
parts, err := convertParts(input.Document.Content)
if err != nil {
return nil, err
}
res, err := em.EmbedContent(ctx, parts...)
if err != nil {
return nil, err
}
return res.Embedding.Values, nil
})
}

// Generator returns the generator with the given name.
// It returns nil if the generator was not configured.
func Generator(name string) ai.Generator {
return ai.LookupGenerator(provider, name)
}

// Embedder returns the embedder with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) ai.Embedder {
return ai.LookupEmbedder(provider, name)
}

type generator struct {
Expand Down Expand Up @@ -195,23 +293,6 @@ func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse
return r
}

// NewGenerator returns an [ai.Generator] which sends a request to
// the google AI model and returns the response.
func NewGenerator(ctx context.Context, model, apiKey string) (ai.Generator, error) {
client, err := newClient(ctx, apiKey)
if err != nil {
return nil, err
}
meta := &ai.GeneratorMetadata{
Label: "Google AI - " + model,
Supports: ai.GeneratorCapabilities{
Multiturn: true,
},
}
g := generator{model: model, client: client}
return ai.DefineGenerator("google-genai", model, meta, g.Generate), nil
}

// convertParts converts a slice of *ai.Part to a slice of genai.Part.
func convertParts(parts []*ai.Part) ([]genai.Part, error) {
res := make([]genai.Part, 0, len(parts))
Expand Down
Loading

0 comments on commit 927730d

Please sign in to comment.