Skip to content

Commit

Permalink
feat(vertexai): support model garden and tuned models names (googleap…
Browse files Browse the repository at this point in the history
  • Loading branch information
eliben authored May 17, 2024
1 parent 4dd2f4d commit d481e0e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
25 changes: 22 additions & 3 deletions vertexai/genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,36 @@ type GenerativeModel struct {
const defaultMaxOutputTokens = 2048

// GenerativeModel creates a new instance of the named model.
// name is a string model name like "gemini-1.0.-pro".
// name is a string model name like "gemini-1.0-pro" or "models/gemini-1.0-pro"
// for Google-published models.
// See https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning
// for details on model naming and versioning.
// for details on model naming and versioning, and
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-garden/explore-models
// for providing model garden names. The SDK isn't familiar with custom model
// garden models, and will pass your model name to the backend API server.
func (c *Client) GenerativeModel(name string) *GenerativeModel {
return &GenerativeModel{
c: c,
name: name,
fullName: fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", c.projectID, c.location, name),
fullName: inferFullModelName(c.projectID, c.location, name),
}
}

// inferFullModelName infers the full model name (with all the required prefixes)
func inferFullModelName(project, location, name string) string {
pubName := name
if !strings.Contains(name, "/") {
pubName = "publishers/google/models/" + name
} else if strings.HasPrefix(name, "models/") {
pubName = "publishers/google/" + name
}

if !strings.HasPrefix(pubName, "publishers/") {
return pubName
}
return fmt.Sprintf("projects/%s/locations/%s/%s", project, location, pubName)
}

// Name returns the name of the model.
func (m *GenerativeModel) Name() string {
return m.name
Expand Down
19 changes: 19 additions & 0 deletions vertexai/genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,3 +582,22 @@ func TestIntFloatConversions(t *testing.T) {
t.Errorf("got %v, want *1", goti)
}
}

func TestInferFullModelName(t *testing.T) {
for _, test := range []struct {
name string
want string
}{
{"xyz", "projects/proj/locations/loc/publishers/google/models/xyz"},
{"models/abc", "projects/proj/locations/loc/publishers/google/models/abc"},
{"publishers/foo/xyz", "projects/proj/locations/loc/publishers/foo/xyz"},
{"x/y/z", "x/y/z"},
} {
t.Run(test.name, func(t *testing.T) {
got := inferFullModelName("proj", "loc", test.name)
if got != test.want {
t.Errorf("got %q, want %q", got, test.want)
}
})
}
}
4 changes: 4 additions & 0 deletions vertexai/genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ const projectID = "your-project"
const location = "some-gcp-location"

// A model name like "gemini-1.0-pro"
// For custom models from different publishers, prepent the full publisher
// prefix for the model, e.g.:
//
// model = publishers/some-publisher/models/some-model-name
const model = "some-model"

func ExampleGenerativeModel_GenerateContent() {
Expand Down

0 comments on commit d481e0e

Please sign in to comment.