diff --git a/vertexai/genai/aiplatformpb_veneer.gen.go b/vertexai/genai/aiplatformpb_veneer.gen.go
index b9c74a672df0..1fe4b3c0408f 100644
--- a/vertexai/genai/aiplatformpb_veneer.gen.go
+++ b/vertexai/genai/aiplatformpb_veneer.gen.go
@@ -479,15 +479,15 @@ func (GenerateContentResponse) fromProto(p *pb.GenerateContentResponse) *Generat
 // GenerationConfig is generation config.
 type GenerationConfig struct {
 	// Optional. Controls the randomness of predictions.
-	Temperature float32
+	Temperature *float32
 	// Optional. If specified, nucleus sampling will be used.
-	TopP float32
+	TopP *float32
 	// Optional. If specified, top-k sampling will be used.
-	TopK float32
+	TopK *float32
 	// Optional. Number of candidates to generate.
-	CandidateCount int32
+	CandidateCount *int32
 	// Optional. The maximum number of output tokens to generate per message.
-	MaxOutputTokens int32
+	MaxOutputTokens *int32
 	// Optional. Stop sequences.
 	StopSequences []string
 }
@@ -497,11 +497,11 @@ func (v *GenerationConfig) toProto() *pb.GenerationConfig {
 		return nil
 	}
 	return &pb.GenerationConfig{
-		Temperature:     support.AddrOrNil(v.Temperature),
-		TopP:            support.AddrOrNil(v.TopP),
-		TopK:            support.AddrOrNil(v.TopK),
-		CandidateCount:  support.AddrOrNil(v.CandidateCount),
-		MaxOutputTokens: support.AddrOrNil(v.MaxOutputTokens),
+		Temperature:     v.Temperature,
+		TopP:            v.TopP,
+		TopK:            v.TopK,
+		CandidateCount:  v.CandidateCount,
+		MaxOutputTokens: v.MaxOutputTokens,
 		StopSequences:   v.StopSequences,
 	}
 }
@@ -511,11 +511,11 @@ func (GenerationConfig) fromProto(p *pb.GenerationConfig) *GenerationConfig {
 		return nil
 	}
 	return &GenerationConfig{
-		Temperature:     support.DerefOrZero(p.Temperature),
-		TopP:            support.DerefOrZero(p.TopP),
-		TopK:            support.DerefOrZero(p.TopK),
-		CandidateCount:  support.DerefOrZero(p.CandidateCount),
-		MaxOutputTokens: support.DerefOrZero(p.MaxOutputTokens),
+		Temperature:     p.Temperature,
+		TopP:            p.TopP,
+		TopK:            p.TopK,
+		CandidateCount:  p.CandidateCount,
+		MaxOutputTokens: p.MaxOutputTokens,
 		StopSequences:   p.StopSequences,
 	}
 }
diff --git a/vertexai/genai/client.go b/vertexai/genai/client.go
index 8f94efac0fec..9fb5db4ede84 100644
--- a/vertexai/genai/client.go
+++ b/vertexai/genai/client.go
@@ -90,10 +90,6 @@ const defaultMaxOutputTokens = 2048
 // GenerativeModel creates a new instance of the named model.
 func (c *Client) GenerativeModel(name string) *GenerativeModel {
 	return &GenerativeModel{
-		GenerationConfig: GenerationConfig{
-			MaxOutputTokens: defaultMaxOutputTokens,
-			TopK:            3,
-		},
 		c:        c,
 		name:     name,
 		fullName: fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", c.projectID, c.location, name),
diff --git a/vertexai/genai/client_test.go b/vertexai/genai/client_test.go
index 6a3bec412e00..f6f5dc9ae5d5 100644
--- a/vertexai/genai/client_test.go
+++ b/vertexai/genai/client_test.go
@@ -49,7 +49,7 @@ func TestLive(t *testing.T) {
 	}
 	defer client.Close()
 	model := client.GenerativeModel(*modelName)
-	model.Temperature = 0
+	model.Temperature = Ptr[float32](0)
 
 	t.Run("GenerateContent", func(t *testing.T) {
 		resp, err := model.GenerateContent(ctx, Text("What is the average size of a swallow?"))
@@ -104,16 +104,16 @@
 
 		checkMatch(t,
 			send("Which is best?", true),
-			"best", "air fryer", "Philips", "([Cc]onsider|research|compare)", "factors|features")
+			"best", "air fryer", "Philips", "([Cc]onsider|research|compare|preference)", "factors|features")
 
 		checkMatch(t,
 			send("Say that again.", false),
-			"best", "air fryer", "Philips", "([Cc]onsider|research|compare)", "factors|features")
+			"best", "air fryer", "Philips", "([Cc]onsider|research|compare|preference)", "factors|features")
 	})
 
 	t.Run("image", func(t *testing.T) {
 		vmodel := client.GenerativeModel(*modelName + "-vision")
-		vmodel.Temperature = 0
+		vmodel.Temperature = Ptr[float32](0)
 
 		data, err := os.ReadFile(filepath.Join("testdata", imageFile))
 		if err != nil {
@@ -168,8 +168,8 @@
 	})
 	t.Run("max-tokens", func(t *testing.T) {
 		maxModel := client.GenerativeModel(*modelName)
-		maxModel.Temperature = 0
-		maxModel.MaxOutputTokens = 10
+		maxModel.Temperature = Ptr(float32(0))
+		maxModel.SetMaxOutputTokens(10)
 		res, err := maxModel.GenerateContent(ctx, Text("What is a dog?"))
 		if err != nil {
 			t.Fatal(err)
@@ -182,8 +182,8 @@
 	})
 	t.Run("max-tokens-streaming", func(t *testing.T) {
 		maxModel := client.GenerativeModel(*modelName)
-		maxModel.Temperature = 0
-		maxModel.MaxOutputTokens = 10
+		maxModel.Temperature = Ptr[float32](0)
+		maxModel.MaxOutputTokens = Ptr[int32](10)
 		iter := maxModel.GenerateContentStream(ctx, Text("What is a dog?"))
 		var merged *GenerateContentResponse
 		for {
@@ -232,7 +232,7 @@
 		}},
 	}
 	model := client.GenerativeModel(*modelName)
-	model.Temperature = 0
+	model.SetTemperature(0)
 	model.Tools = []*Tool{weatherTool}
 	session := model.StartChat()
 	res, err := session.SendMessage(ctx, Text("What is the weather like in New York?"))
@@ -470,3 +470,19 @@
 		}
 	}
 }
+
+func TestTemperature(t *testing.T) {
+	m := &GenerativeModel{}
+	got := m.GenerationConfig.toProto().Temperature
+	if got != nil {
+		t.Errorf("got %v, want nil", got)
+	}
+	m.SetTemperature(0)
+	got = m.GenerationConfig.toProto().Temperature
+	if got == nil {
+		t.Fatal("got nil")
+	}
+	if g := *got; g != 0 {
+		t.Errorf("got %v, want 0", g)
+	}
+}
diff --git a/vertexai/genai/config.yaml b/vertexai/genai/config.yaml
index 062eff25a013..a124f7c0946b 100644
--- a/vertexai/genai/config.yaml
+++ b/vertexai/genai/config.yaml
@@ -59,17 +59,6 @@ types:
   FunctionResponse:
 
   GenerationConfig:
-    fields:
-      Temperature:
-        type: float32
-      TopP:
-        type: float32
-      TopK:
-        type: float32
-      CandidateCount:
-        type: int32
-      MaxOutputTokens:
-        type: int32
 
   SafetyRating:
     docVerb: 'is the'
diff --git a/vertexai/genai/content.go b/vertexai/genai/content.go
index 29fc787f0794..6d94d10904ff 100644
--- a/vertexai/genai/content.go
+++ b/vertexai/genai/content.go
@@ -113,3 +113,24 @@ func ImageData(format string, data []byte) Blob {
 		Data:     data,
 	}
 }
+
+// Ptr returns a pointer to its argument.
+// It can be used to initialize pointer fields:
+//
+//	model.Temperature = genai.Ptr[float32](0.1)
+func Ptr[T any](t T) *T { return &t }
+
+// SetCandidateCount sets the CandidateCount field.
+func (c *GenerationConfig) SetCandidateCount(x int32) { c.CandidateCount = &x }
+
+// SetMaxOutputTokens sets the MaxOutputTokens field.
+func (c *GenerationConfig) SetMaxOutputTokens(x int32) { c.MaxOutputTokens = &x }
+
+// SetTemperature sets the Temperature field.
+func (c *GenerationConfig) SetTemperature(x float32) { c.Temperature = &x }
+
+// SetTopP sets the TopP field.
+func (c *GenerationConfig) SetTopP(x float32) { c.TopP = &x }
+
+// SetTopK sets the TopK field.
+func (c *GenerationConfig) SetTopK(x float32) { c.TopK = &x }
diff --git a/vertexai/genai/example_test.go b/vertexai/genai/example_test.go
index 7442ac6bd54b..327237def1f2 100644
--- a/vertexai/genai/example_test.go
+++ b/vertexai/genai/example_test.go
@@ -42,7 +42,7 @@ func ExampleGenerativeModel_GenerateContent() {
 	defer client.Close()
 
 	model := client.GenerativeModel(model)
-	model.Temperature = 0.9
+	model.SetTemperature(0.9)
 	resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
 	if err != nil {
 		log.Fatal(err)
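
Usage sketch (illustrative, not part of the patch): after this change an unset GenerationConfig field is nil and is omitted from the request, so the service default applies; GenerativeModel no longer pre-fills MaxOutputTokens or TopK. Callers that want a specific value, including zero, set it explicitly with a Set* method or genai.Ptr. The project ID, region, and model name below are placeholders.

	package main

	import (
		"context"
		"fmt"
		"log"

		"cloud.google.com/go/vertexai/genai"
	)

	func main() {
		ctx := context.Background()
		// "my-project" and "us-central1" are placeholder values.
		client, err := genai.NewClient(ctx, "my-project", "us-central1")
		if err != nil {
			log.Fatal(err)
		}
		defer client.Close()

		model := client.GenerativeModel("gemini-pro") // placeholder model name

		// Leaving Temperature nil would let the service choose its default;
		// setting it pins a value, even zero (see TestTemperature above).
		model.SetTemperature(0)
		// Ptr works for direct field assignment; note TopK is a *float32 in this veneer.
		model.TopK = genai.Ptr[float32](3)
		// MaxOutputTokens is no longer defaulted to 2048 by GenerativeModel.
		model.SetMaxOutputTokens(2048)

		resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(resp.Candidates[0].Content.Parts[0])
	}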