Skip to content

Commit

Permalink
[Go] change vertexai plugin to new style (#347)
Browse files Browse the repository at this point in the history
Same change as to googleai.
Add Init function, and Model and Embedder functions to retrieve them.
  • Loading branch information
jba authored Jun 7, 2024
1 parent e48487b commit b7be7fb
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 191 deletions.
54 changes: 16 additions & 38 deletions go/plugins/vertexai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@ package vertexai
import (
"context"
"errors"
"fmt"
"runtime"

aiplatform "cloud.google.com/go/aiplatform/apiv1"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
"github.com/firebase/genkit/go/ai"
"google.golang.org/api/option"
"google.golang.org/protobuf/types/known/structpb"
)

Expand All @@ -37,49 +34,30 @@ type EmbedOptions struct {
TaskType string `json:"task_type,omitempty"`
}

// NewEmbedder returns an [ai.Embedder] that can compute the embedding
// of an input document.
func NewEmbedder(ctx context.Context, model, projectID, location string) (ai.Embedder, error) {
endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
numConns := max(runtime.GOMAXPROCS(0), 4)
o := []option.ClientOption{
option.WithEndpoint(endpoint),
option.WithGRPCConnectionPool(numConns),
func embed(ctx context.Context, reqEndpoint string, client *aiplatform.PredictionClient, req *ai.EmbedRequest) ([]float32, error) {
preq, err := newPredictRequest(reqEndpoint, req)
if err != nil {
return nil, err
}

client, err := aiplatform.NewPredictionClient(ctx, o...)
resp, err := client.Predict(ctx, preq)
if err != nil {
return nil, err
}

reqEndpoint := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", projectID, location, model)
// TODO(ianlancetaylor): This can return multiple vectors.
// We just use the first one for now.

e := ai.DefineEmbedder("google-vertexai", model, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
preq, err := newPredictRequest(reqEndpoint, req)
if err != nil {
return nil, err
}
resp, err := client.Predict(ctx, preq)
if err != nil {
return nil, err
}

// TODO(ianlancetaylor): This can return multiple vectors.
// We just use the first one for now.

if len(resp.Predictions) < 1 {
return nil, errors.New("vertexai: embed request returned no values")
}
if len(resp.Predictions) < 1 {
return nil, errors.New("vertexai: embed request returned no values")
}

values := resp.Predictions[0].GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
ret := make([]float32, len(values))
for i, value := range values {
ret[i] = float32(value.GetNumberValue())
}
values := resp.Predictions[0].GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
ret := make([]float32, len(values))
for i, value := range values {
ret[i] = float32(value.GetNumberValue())
}

return ret, nil
})
return e, nil
return ret, nil
}

func newPredictRequest(endpoint string, req *ai.EmbedRequest) (*aiplatformpb.PredictRequest, error) {
Expand Down
56 changes: 0 additions & 56 deletions go/plugins/vertexai/embed_test.go

This file was deleted.

113 changes: 94 additions & 19 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,107 @@ package vertexai

import (
"context"
"errors"
"fmt"
"runtime"

aiplatform "cloud.google.com/go/aiplatform/apiv1"
"cloud.google.com/go/vertexai/genai"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/plugins/internal/uri"
"google.golang.org/api/option"
)

func newClient(ctx context.Context, projectID, location string) (*genai.Client, error) {
return genai.NewClient(ctx, projectID, location)
const provider = "google-genai"

// Config provides configuration options for the Init function.
type Config struct {
// The project holding the resources.
ProjectID string
// The location of the resources.
// Defaults to "us-central1".
Location string
// Generative models to provide.
Models []string
// Embedding models to provide.
Embedders []string
}

func Init(ctx context.Context, cfg Config) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("vertexai.Init: %w", err)
}
}()

if cfg.ProjectID == "" {
return errors.New("missing ProjectID")
}
if cfg.Location == "" {
cfg.Location = "us-central1"
}
// TODO(#345): call ListModels. See googleai.go.
if len(cfg.Models) == 0 && len(cfg.Embedders) == 0 {
return errors.New("need at least one model or embedder")
}

if err := initModels(ctx, cfg); err != nil {
return err
}
return initEmbedders(ctx, cfg)
}

func initModels(ctx context.Context, cfg Config) error {
// Client for Gemini SDK.
gclient, err := genai.NewClient(ctx, cfg.ProjectID, cfg.Location)
if err != nil {
return err
}
for _, name := range cfg.Models {
meta := &ai.ModelMetadata{
Label: "Vertex AI - " + name,
Supports: ai.ModelCapabilities{
Multiturn: true,
},
}
g := &generator{model: name, client: gclient}
ai.DefineModel(provider, name, meta, g.generate)
}
return nil
}

func initEmbedders(ctx context.Context, cfg Config) error {
// Client for prediction SDK.
endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", cfg.Location)
numConns := max(runtime.GOMAXPROCS(0), 4)
o := []option.ClientOption{
option.WithEndpoint(endpoint),
option.WithGRPCConnectionPool(numConns),
}

pclient, err := aiplatform.NewPredictionClient(ctx, o...)
if err != nil {
return err
}
for _, name := range cfg.Embedders {
fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", cfg.ProjectID, cfg.Location, name)
ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
return embed(ctx, fullName, pclient, req)
})
}
return nil
}

// Model returns the [ai.ModelAction] with the given name.
// It returns nil if the model was not configured.
func Model(name string) *ai.ModelAction {
return ai.LookupModel(provider, name)
}

// Embedder returns the embedder with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) ai.Embedder {
return ai.LookupEmbedder(provider, name)
}

type generator struct {
Expand Down Expand Up @@ -171,23 +263,6 @@ func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse
return r
}

// NewModel returns an [ai.ModelAction] which sends a request to
// the vertex AI model and returns the response.
func NewModel(ctx context.Context, model, projectID, location string) (*ai.ModelAction, error) {
client, err := newClient(ctx, projectID, location)
if err != nil {
return nil, err
}
g := &generator{model: model, client: client}
meta := &ai.ModelMetadata{
Label: "Vertex AI - " + model,
Supports: ai.ModelCapabilities{
Multiturn: true,
},
}
return ai.DefineModel("google-vertexai", model, meta, g.generate), nil
}

// convertParts converts a slice of *ai.Part to a slice of genai.Part.
func convertParts(parts []*ai.Part) ([]genai.Part, error) {
res := make([]genai.Part, 0, len(parts))
Expand Down
Loading

0 comments on commit b7be7fb

Please sign in to comment.