Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] change vertexai plugin to new style #347

Merged
merged 2 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 16 additions & 38 deletions go/plugins/vertexai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@ package vertexai
import (
"context"
"errors"
"fmt"
"runtime"

aiplatform "cloud.google.com/go/aiplatform/apiv1"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
"github.com/firebase/genkit/go/ai"
"google.golang.org/api/option"
"google.golang.org/protobuf/types/known/structpb"
)

Expand All @@ -37,49 +34,30 @@ type EmbedOptions struct {
TaskType string `json:"task_type,omitempty"`
}

// NewEmbedder returns an [ai.Embedder] that can compute the embedding
// of an input document.
func NewEmbedder(ctx context.Context, model, projectID, location string) (ai.Embedder, error) {
endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
numConns := max(runtime.GOMAXPROCS(0), 4)
o := []option.ClientOption{
option.WithEndpoint(endpoint),
option.WithGRPCConnectionPool(numConns),
func embed(ctx context.Context, reqEndpoint string, client *aiplatform.PredictionClient, req *ai.EmbedRequest) ([]float32, error) {
preq, err := newPredictRequest(reqEndpoint, req)
if err != nil {
return nil, err
}

client, err := aiplatform.NewPredictionClient(ctx, o...)
resp, err := client.Predict(ctx, preq)
if err != nil {
return nil, err
}

reqEndpoint := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", projectID, location, model)
// TODO(ianlancetaylor): This can return multiple vectors.
// We just use the first one for now.

e := ai.DefineEmbedder("google-vertexai", model, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
preq, err := newPredictRequest(reqEndpoint, req)
if err != nil {
return nil, err
}
resp, err := client.Predict(ctx, preq)
if err != nil {
return nil, err
}

// TODO(ianlancetaylor): This can return multiple vectors.
// We just use the first one for now.

if len(resp.Predictions) < 1 {
return nil, errors.New("vertexai: embed request returned no values")
}
if len(resp.Predictions) < 1 {
return nil, errors.New("vertexai: embed request returned no values")
}

values := resp.Predictions[0].GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
ret := make([]float32, len(values))
for i, value := range values {
ret[i] = float32(value.GetNumberValue())
}
values := resp.Predictions[0].GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
ret := make([]float32, len(values))
for i, value := range values {
ret[i] = float32(value.GetNumberValue())
}

return ret, nil
})
return e, nil
return ret, nil
}

func newPredictRequest(endpoint string, req *ai.EmbedRequest) (*aiplatformpb.PredictRequest, error) {
Expand Down
56 changes: 0 additions & 56 deletions go/plugins/vertexai/embed_test.go

This file was deleted.

113 changes: 94 additions & 19 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,107 @@ package vertexai

import (
"context"
"errors"
"fmt"
"runtime"

aiplatform "cloud.google.com/go/aiplatform/apiv1"
"cloud.google.com/go/vertexai/genai"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/plugins/internal/uri"
"google.golang.org/api/option"
)

func newClient(ctx context.Context, projectID, location string) (*genai.Client, error) {
return genai.NewClient(ctx, projectID, location)
const provider = "google-genai"

// Config provides configuration options for the Init function.
type Config struct {
// The project holding the resources.
ProjectID string
// The location of the resources.
// Defaults to "us-central1".
Location string
// Generative models to provide.
Models []string
// Embedding models to provide.
Embedders []string
}

func Init(ctx context.Context, cfg Config) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("vertexai.Init: %w", err)
}
}()

if cfg.ProjectID == "" {
return errors.New("missing ProjectID")
}
if cfg.Location == "" {
cfg.Location = "us-central1"
}
// TODO(#345): call ListModels. See googleai.go.
if len(cfg.Models) == 0 && len(cfg.Embedders) == 0 {
return errors.New("need at least one model or embedder")
}

if err := initModels(ctx, cfg); err != nil {
return err
}
return initEmbedders(ctx, cfg)
}

func initModels(ctx context.Context, cfg Config) error {
// Client for Gemini SDK.
gclient, err := genai.NewClient(ctx, cfg.ProjectID, cfg.Location)
if err != nil {
return err
}
for _, name := range cfg.Models {
meta := &ai.ModelMetadata{
Label: "Vertex AI - " + name,
Supports: ai.ModelCapabilities{
Multiturn: true,
},
}
g := &generator{model: name, client: gclient}
ai.DefineModel(provider, name, meta, g.generate)
}
return nil
}

func initEmbedders(ctx context.Context, cfg Config) error {
// Client for prediction SDK.
endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", cfg.Location)
numConns := max(runtime.GOMAXPROCS(0), 4)
o := []option.ClientOption{
option.WithEndpoint(endpoint),
option.WithGRPCConnectionPool(numConns),
}

pclient, err := aiplatform.NewPredictionClient(ctx, o...)
if err != nil {
return err
}
for _, name := range cfg.Embedders {
fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", cfg.ProjectID, cfg.Location, name)
ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
return embed(ctx, fullName, pclient, req)
})
}
return nil
}

// Model returns the [ai.ModelAction] with the given name.
// It returns nil if the model was not configured.
func Model(name string) *ai.ModelAction {
return ai.LookupModel(provider, name)
}

// Embedder returns the embedder with the given name.
// It returns nil if the embedder was not configured.
func Embedder(name string) ai.Embedder {
return ai.LookupEmbedder(provider, name)
}

type generator struct {
Expand Down Expand Up @@ -171,23 +263,6 @@ func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse
return r
}

// NewModel returns an [ai.ModelAction] which sends a request to
// the vertex AI model and returns the response.
func NewModel(ctx context.Context, model, projectID, location string) (*ai.ModelAction, error) {
client, err := newClient(ctx, projectID, location)
if err != nil {
return nil, err
}
g := &generator{model: model, client: client}
meta := &ai.ModelMetadata{
Label: "Vertex AI - " + model,
Supports: ai.ModelCapabilities{
Multiturn: true,
},
}
return ai.DefineModel("google-vertexai", model, meta, g.generate), nil
}

// convertParts converts a slice of *ai.Part to a slice of genai.Part.
func convertParts(parts []*ai.Part) ([]genai.Part, error) {
res := make([]genai.Part, 0, len(parts))
Expand Down
Loading
Loading