diff --git a/go/plugins/vertexai/embed.go b/go/plugins/vertexai/embed.go
index 856b3cc8a..53d04ec0b 100644
--- a/go/plugins/vertexai/embed.go
+++ b/go/plugins/vertexai/embed.go
@@ -17,13 +17,10 @@ package vertexai

 import (
 	"context"
 	"errors"
-	"fmt"
-	"runtime"

 	aiplatform "cloud.google.com/go/aiplatform/apiv1"
 	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
 	"github.com/firebase/genkit/go/ai"
-	"google.golang.org/api/option"
 	"google.golang.org/protobuf/types/known/structpb"
 )

@@ -37,49 +34,30 @@ type EmbedOptions struct {
 	TaskType string `json:"task_type,omitempty"`
 }

-// NewEmbedder returns an [ai.Embedder] that can compute the embedding
-// of an input document.
-func NewEmbedder(ctx context.Context, model, projectID, location string) (ai.Embedder, error) {
-	endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
-	numConns := max(runtime.GOMAXPROCS(0), 4)
-	o := []option.ClientOption{
-		option.WithEndpoint(endpoint),
-		option.WithGRPCConnectionPool(numConns),
+func embed(ctx context.Context, reqEndpoint string, client *aiplatform.PredictionClient, req *ai.EmbedRequest) ([]float32, error) {
+	preq, err := newPredictRequest(reqEndpoint, req)
+	if err != nil {
+		return nil, err
 	}
-
-	client, err := aiplatform.NewPredictionClient(ctx, o...)
+	resp, err := client.Predict(ctx, preq)
 	if err != nil {
 		return nil, err
 	}

-	reqEndpoint := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", projectID, location, model)
+	// TODO(ianlancetaylor): This can return multiple vectors.
+	// We just use the first one for now.

-	e := ai.DefineEmbedder("google-vertexai", model, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
-		preq, err := newPredictRequest(reqEndpoint, req)
-		if err != nil {
-			return nil, err
-		}
-		resp, err := client.Predict(ctx, preq)
-		if err != nil {
-			return nil, err
-		}
-
-		// TODO(ianlancetaylor): This can return multiple vectors.
-		// We just use the first one for now.
-
-		if len(resp.Predictions) < 1 {
-			return nil, errors.New("vertexai: embed request returned no values")
-		}
+	if len(resp.Predictions) < 1 {
+		return nil, errors.New("vertexai: embed request returned no values")
+	}

-		values := resp.Predictions[0].GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
-		ret := make([]float32, len(values))
-		for i, value := range values {
-			ret[i] = float32(value.GetNumberValue())
-		}
+	values := resp.Predictions[0].GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
+	ret := make([]float32, len(values))
+	for i, value := range values {
+		ret[i] = float32(value.GetNumberValue())
+	}

-		return ret, nil
-	})
-	return e, nil
+	return ret, nil
 }

 func newPredictRequest(endpoint string, req *ai.EmbedRequest) (*aiplatformpb.PredictRequest, error) {
diff --git a/go/plugins/vertexai/embed_test.go b/go/plugins/vertexai/embed_test.go
deleted file mode 100644
index 00f20d929..000000000
--- a/go/plugins/vertexai/embed_test.go
+++ /dev/null
@@ -1,56 +0,0 @@
-// Copyright 2024 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vertexai_test
-
-import (
-	"context"
-	"testing"
-
-	"github.com/firebase/genkit/go/ai"
-	"github.com/firebase/genkit/go/plugins/vertexai"
-)
-
-func TestEmbedder(t *testing.T) {
-	if *projectID == "" {
-		t.Skip("no -projectid provided")
-	}
-
-	ctx := context.Background()
-
-	e, err := vertexai.NewEmbedder(ctx, "textembedding-gecko", *projectID, *location)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	out, err := e.Embed(ctx, &ai.EmbedRequest{
-		Document: ai.DocumentFromText("time flies like an arrow", nil),
-	})
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	// There's not a whole lot we can test about the result.
-	// Just do a few sanity checks.
-	if len(out) < 100 {
-		t.Errorf("embedding vector looks too short: len(out)=%d", len(out))
-	}
-	var normSquared float32
-	for _, x := range out {
-		normSquared += x * x
-	}
-	if normSquared < 0.9 || normSquared > 1.1 {
-		t.Errorf("embedding vector not unit length: %f", normSquared)
-	}
-}
diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go
index 1b4e89273..f750e3e14 100644
--- a/go/plugins/vertexai/vertexai.go
+++ b/go/plugins/vertexai/vertexai.go
@@ -16,15 +16,107 @@ package vertexai

 import (
 	"context"
+	"errors"
 	"fmt"
+	"runtime"

+	aiplatform "cloud.google.com/go/aiplatform/apiv1"
 	"cloud.google.com/go/vertexai/genai"
 	"github.com/firebase/genkit/go/ai"
 	"github.com/firebase/genkit/go/plugins/internal/uri"
+	"google.golang.org/api/option"
 )

-func newClient(ctx context.Context, projectID, location string) (*genai.Client, error) {
-	return genai.NewClient(ctx, projectID, location)
+const provider = "google-genai"
+
+// Config provides configuration options for the Init function.
+type Config struct {
+	// The project holding the resources.
+	ProjectID string
+	// The location of the resources.
+	// Defaults to "us-central1".
+	Location string
+	// Generative models to provide.
+	Models []string
+	// Embedding models to provide.
+	Embedders []string
+}
+
+func Init(ctx context.Context, cfg Config) (err error) {
+	defer func() {
+		if err != nil {
+			err = fmt.Errorf("vertexai.Init: %w", err)
+		}
+	}()
+
+	if cfg.ProjectID == "" {
+		return errors.New("missing ProjectID")
+	}
+	if cfg.Location == "" {
+		cfg.Location = "us-central1"
+	}
+	// TODO(#345): call ListModels. See googleai.go.
+	if len(cfg.Models) == 0 && len(cfg.Embedders) == 0 {
+		return errors.New("need at least one model or embedder")
+	}
+
+	if err := initModels(ctx, cfg); err != nil {
+		return err
+	}
+	return initEmbedders(ctx, cfg)
+}
+
+func initModels(ctx context.Context, cfg Config) error {
+	// Client for Gemini SDK.
+	gclient, err := genai.NewClient(ctx, cfg.ProjectID, cfg.Location)
+	if err != nil {
+		return err
+	}
+	for _, name := range cfg.Models {
+		meta := &ai.ModelMetadata{
+			Label: "Vertex AI - " + name,
+			Supports: ai.ModelCapabilities{
+				Multiturn: true,
+			},
+		}
+		g := &generator{model: name, client: gclient}
+		ai.DefineModel(provider, name, meta, g.generate)
+	}
+	return nil
+}
+
+func initEmbedders(ctx context.Context, cfg Config) error {
+	// Client for prediction SDK.
+	endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", cfg.Location)
+	numConns := max(runtime.GOMAXPROCS(0), 4)
+	o := []option.ClientOption{
+		option.WithEndpoint(endpoint),
+		option.WithGRPCConnectionPool(numConns),
+	}
+
+	pclient, err := aiplatform.NewPredictionClient(ctx, o...)
+	if err != nil {
+		return err
+	}
+	for _, name := range cfg.Embedders {
+		fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", cfg.ProjectID, cfg.Location, name)
+		ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
+			return embed(ctx, fullName, pclient, req)
+		})
+	}
+	return nil
+}
+
+// Model returns the [ai.ModelAction] with the given name.
+// It returns nil if the model was not configured.
+func Model(name string) *ai.ModelAction {
+	return ai.LookupModel(provider, name)
+}
+
+// Embedder returns the embedder with the given name.
+// It returns nil if the embedder was not configured.
+func Embedder(name string) ai.Embedder {
+	return ai.LookupEmbedder(provider, name)
 }

 type generator struct {
@@ -171,23 +263,6 @@ func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse
 	return r
 }

-// NewModel returns an [ai.ModelAction] which sends a request to
-// the vertex AI model and returns the response.
-func NewModel(ctx context.Context, model, projectID, location string) (*ai.ModelAction, error) {
-	client, err := newClient(ctx, projectID, location)
-	if err != nil {
-		return nil, err
-	}
-	g := &generator{model: model, client: client}
-	meta := &ai.ModelMetadata{
-		Label: "Vertex AI - " + model,
-		Supports: ai.ModelCapabilities{
-			Multiturn: true,
-		},
-	}
-	return ai.DefineModel("google-vertexai", model, meta, g.generate), nil
-}
-
 // convertParts converts a slice of *ai.Part to a slice of genai.Part.
 func convertParts(parts []*ai.Part) ([]genai.Part, error) {
 	res := make([]genai.Part, 0, len(parts))
diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go
index 48eb29b63..e5ae734ff 100644
--- a/go/plugins/vertexai/vertexai_test.go
+++ b/go/plugins/vertexai/vertexai_test.go
@@ -33,45 +33,28 @@ import (

 var projectID = flag.String("projectid", "", "VertexAI project")
 var location = flag.String("location", "us-central1", "geographic location")

-func TestModel(t *testing.T) {
+func TestLive(t *testing.T) {
 	if *projectID == "" {
 		t.Skipf("no -projectid provided")
 	}
 	ctx := context.Background()
-	g, err := vertexai.NewModel(ctx, "gemini-1.0-pro", *projectID, *location)
+	const modelName = "gemini-1.0-pro"
+	const embedderName = "textembedding-gecko"
+	err := vertexai.Init(ctx, vertexai.Config{
+		ProjectID: *projectID,
+		Location:  *location,
+		Models:    []string{modelName},
+		Embedders: []string{embedderName},
+	})
 	if err != nil {
 		t.Fatal(err)
 	}
-	req := &ai.GenerateRequest{
-		Candidates: 1,
-		Messages: []*ai.Message{
-			&ai.Message{
-				Content: []*ai.Part{ai.NewTextPart("Which country was Napoleon the emperor of?")},
-				Role:    ai.RoleUser,
-			},
-		},
-	}
-	resp, err := ai.Generate(ctx, g, req, nil)
-	if err != nil {
-		t.Fatal(err)
-	}
-	out := resp.Candidates[0].Message.Content[0].Text
-	if !strings.Contains(out, "France") {
-		t.Errorf("got \"%s\", expecting it would contain \"France\"", out)
-	}
-	if resp.Request != req {
-		t.Error("Request field not set properly")
+	toolDef := &ai.ToolDefinition{
+		Name:         "exponentiation",
+		InputSchema:  map[string]any{"base": "float64", "exponent": "int"},
+		OutputSchema: map[string]any{"output": "float64"},
 	}
-}
-
-var toolDef = &ai.ToolDefinition{
-	Name:         "exponentiation",
-	InputSchema:  map[string]any{"base": "float64", "exponent": "int"},
-	OutputSchema: map[string]any{"output": "float64"},
-}
-
-func init() {
 	ai.RegisterTool(toolDef, nil,
 		func(ctx context.Context, input map[string]any) (map[string]any, error) {
 			baseAny, ok := input["base"]
@@ -100,35 +83,70 @@ func init() {
 			return r, nil
 		},
 	)
-}
+	t.Run("model", func(t *testing.T) {
+		req := &ai.GenerateRequest{
+			Candidates: 1,
+			Messages: []*ai.Message{
+				&ai.Message{
+					Content: []*ai.Part{ai.NewTextPart("Which country was Napoleon the emperor of?")},
+					Role:    ai.RoleUser,
+				},
+			},
+		}

-func TestModelTool(t *testing.T) {
-	if *projectID == "" {
-		t.Skip("no -projectid provided")
-	}
-	ctx := context.Background()
-	g, err := vertexai.NewModel(ctx, "gemini-1.0-pro", *projectID, *location)
-	if err != nil {
-		t.Fatal(err)
-	}
-	req := &ai.GenerateRequest{
-		Candidates: 1,
-		Messages: []*ai.Message{
-			&ai.Message{
-				Content: []*ai.Part{ai.NewTextPart("what is 3.5 squared? Use the tool provided.")},
-				Role:    ai.RoleUser,
+		resp, err := ai.Generate(ctx, vertexai.Model(modelName), req, nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		out := resp.Candidates[0].Message.Content[0].Text
+		if !strings.Contains(out, "France") {
+			t.Errorf("got \"%s\", expecting it would contain \"France\"", out)
+		}
+		if resp.Request != req {
+			t.Error("Request field not set properly")
+		}
+	})
+	t.Run("tool", func(t *testing.T) {
+		req := &ai.GenerateRequest{
+			Candidates: 1,
+			Messages: []*ai.Message{
+				&ai.Message{
+					Content: []*ai.Part{ai.NewTextPart("what is 3.5 squared? Use the tool provided.")},
+					Role:    ai.RoleUser,
+				},
 			},
-		},
-		Tools: []*ai.ToolDefinition{toolDef},
-	}
+			Tools: []*ai.ToolDefinition{toolDef},
+		}

-	resp, err := ai.Generate(ctx, g, req, nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+		resp, err := ai.Generate(ctx, vertexai.Model(modelName), req, nil)
+		if err != nil {
+			t.Fatal(err)
+		}

-	out := resp.Candidates[0].Message.Content[0].Text
-	if !strings.Contains(out, "12.25") {
-		t.Errorf("got %s, expecting it to contain \"12.25\"", out)
-	}
+		out := resp.Candidates[0].Message.Content[0].Text
+		if !strings.Contains(out, "12.25") {
+			t.Errorf("got %s, expecting it to contain \"12.25\"", out)
+		}
+	})
+	t.Run("embedder", func(t *testing.T) {
+		out, err := vertexai.Embedder(embedderName).Embed(ctx, &ai.EmbedRequest{
+			Document: ai.DocumentFromText("time flies like an arrow", nil),
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		// There's not a whole lot we can test about the result.
+		// Just do a few sanity checks.
+		if len(out) < 100 {
+			t.Errorf("embedding vector looks too short: len(out)=%d", len(out))
+		}
+		var normSquared float32
+		for _, x := range out {
+			normSquared += x * x
+		}
+		if normSquared < 0.9 || normSquared > 1.1 {
+			t.Errorf("embedding vector not unit length: %f", normSquared)
+		}
+	})
 }
diff --git a/go/samples/menu/main.go b/go/samples/menu/main.go
index 0db4aa69b..acf688a37 100644
--- a/go/samples/menu/main.go
+++ b/go/samples/menu/main.go
@@ -26,7 +26,10 @@ import (
 	"github.com/invopop/jsonschema"
 )

-const geminiPro = "gemini-1.0-pro"
+const (
+	geminiPro      = "gemini-1.0-pro"
+	embeddingGecko = "textembedding-gecko"
+)

 // menuItem is the data model for an item on the menu.
 type menuItem struct {
@@ -75,42 +78,36 @@ func main() {
 		os.Exit(1)
 	}

-	location := "us-central1"
-	if env := os.Getenv("GCLOUD_LOCATION"); env != "" {
-		location = env
-	}
-
-	gen, err := vertexai.NewModel(context.Background(), geminiPro, projectID, location)
-	if err != nil {
-		log.Fatal(err)
-	}
-	genVision, err := vertexai.NewModel(context.Background(), geminiPro+"-vision", projectID, location)
-
 	ctx := context.Background()
-	if err := setup01(ctx, gen); err != nil {
+	err := vertexai.Init(ctx, vertexai.Config{
+		ProjectID: projectID,
+		Location:  os.Getenv("GCLOUD_LOCATION"),
+		Models:    []string{geminiPro, geminiPro + "-vision"},
+		Embedders: []string{embeddingGecko},
+	})
+
+	model := vertexai.Model(geminiPro)
+	if err := setup01(ctx, model); err != nil {
 		log.Fatal(err)
 	}
-	if err := setup02(ctx, gen); err != nil {
+	if err := setup02(ctx, model); err != nil {
 		log.Fatal(err)
 	}
-	if err := setup03(ctx, gen); err != nil {
+	if err := setup03(ctx, model); err != nil {
 		log.Fatal(err)
 	}
-	embedder, err := vertexai.NewEmbedder(ctx, "textembedding-gecko", projectID, location)
-	if err != nil {
-		log.Fatal(err)
-	}
+	embedder := vertexai.Embedder(embeddingGecko)
 	ds, err := localvec.New(ctx, os.TempDir(), "go-menu-items", embedder, nil)
 	if err != nil {
 		log.Fatal(err)
 	}
-	if err := setup04(ctx, ds, gen); err != nil {
+	if err := setup04(ctx, ds, model); err != nil {
 		log.Fatal(err)
 	}
-	if err := setup05(ctx, gen, genVision); err != nil {
+	if err := setup05(ctx, model, vertexai.Model(geminiPro+"-vision")); err != nil {
 		log.Fatal(err)
 	}
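
Usage note (not part of the diff): a minimal sketch of how a caller wires up the plugin after this change, mirroring the updated menu sample and TestLive test above. The GCLOUD_PROJECT environment variable and the model names are illustrative, not mandated by the plugin.

    // Sketch only, under the assumptions stated above.
    package main

    import (
    	"context"
    	"log"
    	"os"

    	"github.com/firebase/genkit/go/ai"
    	"github.com/firebase/genkit/go/plugins/vertexai"
    )

    func main() {
    	ctx := context.Background()
    	// Init registers the listed models and embedders under one provider;
    	// ProjectID is required and Location defaults to "us-central1" when empty.
    	err := vertexai.Init(ctx, vertexai.Config{
    		ProjectID: os.Getenv("GCLOUD_PROJECT"),
    		Models:    []string{"gemini-1.0-pro"},
    		Embedders: []string{"textembedding-gecko"},
    	})
    	if err != nil {
    		log.Fatal(err)
    	}
    	// Registered actions are then looked up by name.
    	model := vertexai.Model("gemini-1.0-pro")
    	_ = model // pass to ai.Generate as in the tests above
    	embedder := vertexai.Embedder("textembedding-gecko")
    	out, err := embedder.Embed(ctx, &ai.EmbedRequest{
    		Document: ai.DocumentFromText("time flies like an arrow", nil),
    	})
    	if err != nil {
    		log.Fatal(err)
    	}
    	log.Printf("embedding has %d values", len(out))
    }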