diff --git a/go/ai/generator.go b/go/ai/generator.go index 032221789..f066910fa 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -16,7 +16,9 @@ package ai import ( "context" + "errors" "fmt" + "strings" "github.com/google/genkit/go/genkit" ) @@ -61,3 +63,28 @@ type generatorAction struct { func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb genkit.NoStream) (*GenerateResponse, error) { return ga.action.Run(ctx, input, cb) } + +// Text returns the contents of the first candidate in a +// [GenerateResponse] as a string. It returns an error if there +// are no candidates or if the candidate has no message. +func (gr *GenerateResponse) Text() (string, error) { + if len(gr.Candidates) == 0 { + return "", errors.New("no candidates returned") + } + msg := gr.Candidates[0].Message + if msg == nil { + return "", errors.New("candidate with no message") + } + if len(msg.Content) == 0 { + return "", errors.New("candidate message has no content") + } + if len(msg.Content) == 1 { + return msg.Content[0].Text(), nil + } else { + var sb strings.Builder + for _, p := range msg.Content { + sb.WriteString(p.Text()) + } + return sb.String(), nil + } +} diff --git a/go/genkit/dotprompt/genkit.go b/go/genkit/dotprompt/genkit.go index 3b20ba4ac..ee5a7ad8c 100644 --- a/go/genkit/dotprompt/genkit.go +++ b/go/genkit/dotprompt/genkit.go @@ -17,6 +17,8 @@ package dotprompt import ( "context" "errors" + "reflect" + "strings" "github.com/google/genkit/go/ai" "github.com/google/genkit/go/genkit" @@ -37,6 +39,54 @@ type ActionInput struct { Model string `json:"model,omitempty"` } +// BuildVariables returns a map for [ActionInput.Variables] based +// on a pointer to a struct value. The struct value should have +// JSON tags that correspond to the Prompt's input schema. +// Only exported fields of the struct will be used. +func (p *Prompt) BuildVariables(input any) (map[string]any, error) { + v := reflect.ValueOf(input).Elem() + if v.Kind() != reflect.Struct { + return nil, errors.New("BuildVariables: not a pointer to a struct") + } + vt := v.Type() + + // TODO(ianlancetaylor): Verify the struct with p.Frontmatter.Schema. + + m := make(map[string]any) + +fieldLoop: + for i := 0; i < vt.NumField(); i++ { + ft := vt.Field(i) + if ft.PkgPath != "" { + continue + } + + jsonTag := ft.Tag.Get("json") + jsonName, rest, _ := strings.Cut(jsonTag, ",") + if jsonName == "" { + jsonName = ft.Name + } + + vf := v.Field(i) + + // If the field is the zero value, and omitempty is set, + // don't pass it as a prompt input variable. + if vf.IsZero() { + for rest != "" { + var key string + key, rest, _ = strings.Cut(rest, ",") + if key == "omitempty" { + continue fieldLoop + } + } + } + + m[jsonName] = vf.Interface() + } + + return m, nil +} + // buildRequest prepares an [ai.GenerateRequest] based on the prompt, // using the input variables and other information in the [ActionInput]. func (p *Prompt) buildRequest(input *ActionInput) (*ai.GenerateRequest, error) { diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go new file mode 100644 index 000000000..87d3a2d88 --- /dev/null +++ b/go/samples/coffee-shop/main.go @@ -0,0 +1,192 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This program can be manually tested like so: +// Start the server listening on port 3100: +// +// go run . & +// +// Tell it to run an action: +// +// curl -d '{"key":"/flow/testAllCoffeeFlows/testAllCoffeeFlows", "input":{"start": {"input":null}}}' http://localhost:3100/api/runAction +package main + +import ( + "context" + "fmt" + "log" + "os" + + "github.com/google/genkit/go/ai" + "github.com/google/genkit/go/genkit" + "github.com/google/genkit/go/genkit/dotprompt" + "github.com/google/genkit/go/plugins/googleai" + "github.com/invopop/jsonschema" +) + +const simpleGreetingPromptTemplate = ` +You're a barista at a nice coffee shop. +A regular customer named {{customerName}} enters. +Greet the customer in one sentence, and recommend a coffee drink. +` + +type simpleGreetingInput struct { + CustomerName string `json:"customerName"` +} + +const greetingWithHistoryPromptTemplate = ` +{{role "user"}} +Hi, my name is {{customerName}}. The time is {{currentTime}}. Who are you? + +{{role "model"}} +I am Barb, a barista at this nice underwater-themed coffee shop called Krabby Kooffee. +I know pretty much everything there is to know about coffee, +and I can cheerfully recommend delicious coffee drinks to you based on whatever you like. + +{{role "user"}} +Great. Last time I had {{previousOrder}}. +I want you to greet me in one sentence, and recommend a drink. +` + +type customerTimeAndHistoryInput struct { + CustomerName string `json:"customerName"` + CurrentTime string `json:"currentTime"` + PreviousOrder string `json:"previousOrder"` +} + +type testAllCoffeeFlowsOutput struct { + Pass bool `json:"pass"` + Replies []string `json:"replies,omitempty"` + Error string `json:"error,omitempty"` +} + +func main() { + apiKey := os.Getenv("GEMINI_API_KEY") + if apiKey == "" { + fmt.Fprintln(os.Stderr, "coffee-shop example requires setting GEMINI_API_KEY in the environment.") + fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.") + os.Exit(1) + } + + if err := googleai.Init(context.Background(), "gemini-pro", apiKey); err != nil { + log.Fatal(err) + } + + simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", + &dotprompt.Frontmatter{ + Name: "simpleGreeting", + Model: "google-genai", + Input: dotprompt.FrontmatterInput{ + Schema: jsonschema.Reflect(simpleGreetingInput{}), + }, + Output: &ai.GenerateRequestOutput{ + Format: ai.OutputFormatText, + }, + }, + simpleGreetingPromptTemplate, + nil, + ) + if err != nil { + log.Fatal(err) + } + + simpleGreetingFlow := genkit.DefineFlow("simpleGreeting", func(ctx context.Context, input *simpleGreetingInput, _ genkit.NoStream) (string, error) { + vars, err := simpleGreetingPrompt.BuildVariables(input) + if err != nil { + return "", err + } + ai := &dotprompt.ActionInput{ Variables: vars } + resp, err := simpleGreetingPrompt.Execute(ctx, ai) + if err != nil { + return "", err + } + text, err := resp.Text() + if err != nil { + return "", fmt.Errorf("simpleGreeting: %v", err) + } + return text, nil + }) + + greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory", + &dotprompt.Frontmatter{ + Name: "greetingWithHistory", + Model: "google-genai", + Input: dotprompt.FrontmatterInput{ + Schema: jsonschema.Reflect(customerTimeAndHistoryInput{}), + }, + Output: &ai.GenerateRequestOutput{ + Format: ai.OutputFormatText, + }, + }, + greetingWithHistoryPromptTemplate, + nil, + ) + if err != nil { + log.Fatal(err) + } + + greetingWithHistoryFlow := genkit.DefineFlow("greetingWithHistory", func(ctx context.Context, input *customerTimeAndHistoryInput, _ genkit.NoStream) (string, error) { + vars, err := greetingWithHistoryPrompt.BuildVariables(input) + if err != nil { + return "", err + } + ai := &dotprompt.ActionInput{ Variables: vars } + resp, err := greetingWithHistoryPrompt.Execute(ctx, ai) + if err != nil { + return "", err + } + text, err := resp.Text() + if err != nil { + return "", fmt.Errorf("greetingWithHistory: %v", err) + } + return text, nil + }) + + genkit.DefineFlow("testAllCoffeeFlows", func(ctx context.Context, _ struct{}, _ genkit.NoStream) (*testAllCoffeeFlowsOutput, error) { + test1, err := genkit.RunFlow(ctx, simpleGreetingFlow, &simpleGreetingInput{ + CustomerName: "Sam", + }) + if err != nil { + out := &testAllCoffeeFlowsOutput{ + Pass: false, + Error: err.Error(), + } + return out, nil + } + test2, err := genkit.RunFlow(ctx, greetingWithHistoryFlow, &customerTimeAndHistoryInput{ + CustomerName: "Sam", + CurrentTime: "09:45am", + PreviousOrder: "Caramel Macchiato", + }) + if err != nil { + out := &testAllCoffeeFlowsOutput{ + Pass: false, + Error: err.Error(), + } + return out, nil + } + out := &testAllCoffeeFlowsOutput{ + Pass: true, + Replies: []string{ + test1, + test2, + }, + } + return out, nil + }) + + if err := genkit.StartDevServer(""); err != nil { + log.Fatal(err) + } +}