Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(vertexai): use pointers for GenerationConfig fields #9182

Merged
merged 2 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions vertexai/genai/aiplatformpb_veneer.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,15 +479,15 @@ func (GenerateContentResponse) fromProto(p *pb.GenerateContentResponse) *Generat
// GenerationConfig is generation config.
type GenerationConfig struct {
// Optional. Controls the randomness of predictions.
Temperature float32
Temperature *float32
// Optional. If specified, nucleus sampling will be used.
TopP float32
TopP *float32
// Optional. If specified, top-k sampling will be used.
TopK float32
TopK *float32
// Optional. Number of candidates to generate.
CandidateCount int32
CandidateCount *int32
// Optional. The maximum number of output tokens to generate per message.
MaxOutputTokens int32
MaxOutputTokens *int32
// Optional. Stop sequences.
StopSequences []string
}
Expand All @@ -497,11 +497,11 @@ func (v *GenerationConfig) toProto() *pb.GenerationConfig {
return nil
}
return &pb.GenerationConfig{
Temperature: support.AddrOrNil(v.Temperature),
TopP: support.AddrOrNil(v.TopP),
TopK: support.AddrOrNil(v.TopK),
CandidateCount: support.AddrOrNil(v.CandidateCount),
MaxOutputTokens: support.AddrOrNil(v.MaxOutputTokens),
Temperature: v.Temperature,
TopP: v.TopP,
TopK: v.TopK,
CandidateCount: v.CandidateCount,
MaxOutputTokens: v.MaxOutputTokens,
StopSequences: v.StopSequences,
}
}
Expand All @@ -511,11 +511,11 @@ func (GenerationConfig) fromProto(p *pb.GenerationConfig) *GenerationConfig {
return nil
}
return &GenerationConfig{
Temperature: support.DerefOrZero(p.Temperature),
TopP: support.DerefOrZero(p.TopP),
TopK: support.DerefOrZero(p.TopK),
CandidateCount: support.DerefOrZero(p.CandidateCount),
MaxOutputTokens: support.DerefOrZero(p.MaxOutputTokens),
Temperature: p.Temperature,
TopP: p.TopP,
TopK: p.TopK,
CandidateCount: p.CandidateCount,
MaxOutputTokens: p.MaxOutputTokens,
StopSequences: p.StopSequences,
}
}
Expand Down
4 changes: 0 additions & 4 deletions vertexai/genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ const defaultMaxOutputTokens = 2048
// GenerativeModel creates a new instance of the named model.
func (c *Client) GenerativeModel(name string) *GenerativeModel {
return &GenerativeModel{
GenerationConfig: GenerationConfig{
MaxOutputTokens: defaultMaxOutputTokens,
TopK: 3,
},
c: c,
name: name,
fullName: fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", c.projectID, c.location, name),
Expand Down
34 changes: 25 additions & 9 deletions vertexai/genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestLive(t *testing.T) {
}
defer client.Close()
model := client.GenerativeModel(*modelName)
model.Temperature = 0
model.Temperature = Ptr[float32](0)

t.Run("GenerateContent", func(t *testing.T) {
resp, err := model.GenerateContent(ctx, Text("What is the average size of a swallow?"))
Expand Down Expand Up @@ -104,16 +104,16 @@ func TestLive(t *testing.T) {

checkMatch(t,
send("Which is best?", true),
"best", "air fryer", "Philips", "([Cc]onsider|research|compare)", "factors|features")
"best", "air fryer", "Philips", "([Cc]onsider|research|compare|preference)", "factors|features")

checkMatch(t,
send("Say that again.", false),
"best", "air fryer", "Philips", "([Cc]onsider|research|compare)", "factors|features")
"best", "air fryer", "Philips", "([Cc]onsider|research|compare|preference)", "factors|features")
})

t.Run("image", func(t *testing.T) {
vmodel := client.GenerativeModel(*modelName + "-vision")
vmodel.Temperature = 0
vmodel.Temperature = Ptr[float32](0)

data, err := os.ReadFile(filepath.Join("testdata", imageFile))
if err != nil {
Expand Down Expand Up @@ -168,8 +168,8 @@ func TestLive(t *testing.T) {
})
t.Run("max-tokens", func(t *testing.T) {
maxModel := client.GenerativeModel(*modelName)
maxModel.Temperature = 0
maxModel.MaxOutputTokens = 10
maxModel.Temperature = Ptr(float32(0))
maxModel.SetMaxOutputTokens(10)
res, err := maxModel.GenerateContent(ctx, Text("What is a dog?"))
if err != nil {
t.Fatal(err)
Expand All @@ -182,8 +182,8 @@ func TestLive(t *testing.T) {
})
t.Run("max-tokens-streaming", func(t *testing.T) {
maxModel := client.GenerativeModel(*modelName)
maxModel.Temperature = 0
maxModel.MaxOutputTokens = 10
maxModel.Temperature = Ptr[float32](0)
maxModel.MaxOutputTokens = Ptr[int32](10)
iter := maxModel.GenerateContentStream(ctx, Text("What is a dog?"))
var merged *GenerateContentResponse
for {
Expand Down Expand Up @@ -232,7 +232,7 @@ func TestLive(t *testing.T) {
}},
}
model := client.GenerativeModel(*modelName)
model.Temperature = 0
model.SetTemperature(0)
model.Tools = []*Tool{weatherTool}
session := model.StartChat()
res, err := session.SendMessage(ctx, Text("What is the weather like in New York?"))
Expand Down Expand Up @@ -470,3 +470,19 @@ func TestMatchString(t *testing.T) {
}
}
}

// TestTemperature verifies that an unset Temperature is omitted (nil) in the
// generated proto and that SetTemperature makes a zero value survive the
// round trip to the proto.
func TestTemperature(t *testing.T) {
	var m GenerativeModel
	if temp := m.GenerationConfig.toProto().Temperature; temp != nil {
		t.Errorf("got %v, want nil", temp)
	}
	m.SetTemperature(0)
	temp := m.GenerationConfig.toProto().Temperature
	switch {
	case temp == nil:
		t.Fatal("got nil")
	case *temp != 0:
		t.Errorf("got %v, want 0", *temp)
	}
}
11 changes: 0 additions & 11 deletions vertexai/genai/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,6 @@ types:
FunctionResponse:

GenerationConfig:
fields:
Temperature:
type: float32
TopP:
type: float32
TopK:
type: float32
CandidateCount:
type: int32
MaxOutputTokens:
type: int32

SafetyRating:
docVerb: 'is the'
Expand Down
21 changes: 21 additions & 0 deletions vertexai/genai/content.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,24 @@ func ImageData(format string, data []byte) Blob {
Data: data,
}
}

// Ptr returns a pointer to a copy of its argument.
// It is a convenience for populating optional pointer fields:
//
//	model.Temperature = genai.Ptr[float32](0.1)
func Ptr[T any](t T) *T {
	p := t
	return &p
}

// SetCandidateCount sets the CandidateCount field to a pointer holding x.
func (c *GenerationConfig) SetCandidateCount(x int32) {
	v := x
	c.CandidateCount = &v
}

// SetMaxOutputTokens sets the MaxOutputTokens field to a pointer holding x.
func (c *GenerationConfig) SetMaxOutputTokens(x int32) {
	v := x
	c.MaxOutputTokens = &v
}

// SetTemperature sets the Temperature field to a pointer holding x.
func (c *GenerationConfig) SetTemperature(x float32) {
	v := x
	c.Temperature = &v
}

// SetTopP sets the TopP field to a pointer holding x.
func (c *GenerationConfig) SetTopP(x float32) {
	v := x
	c.TopP = &v
}

// SetTopK sets the TopK field to a pointer holding x.
func (c *GenerationConfig) SetTopK(x float32) {
	v := x
	c.TopK = &v
}
2 changes: 1 addition & 1 deletion vertexai/genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func ExampleGenerativeModel_GenerateContent() {
defer client.Close()

model := client.GenerativeModel(model)
model.Temperature = 0.9
model.SetTemperature(0.9)
resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
if err != nil {
log.Fatal(err)
Expand Down