From 23b246d1366f03fdf62e79907c8b0d54cd923c15 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 21 May 2024 15:49:06 -0400 Subject: [PATCH] [Go] POC: redo dotprompt API (#202) Attempt to make the dotprompt API a little more user-friendly. The main change is to replace the exported Frontmatter type with a type (called Config) that is easier to write a literal for. Miscellaneous other changes: - Unexport the Prompt.Hash field. - Add `New` to create a prompt without registering it, and use `Define` to mean create + register, as it does elsewhere. - Remove the generator as an argument; I believe it would only be used for tests. --- go/genkit/dotprompt/dotprompt.go | 190 +++++++++++++------------- go/genkit/dotprompt/dotprompt_test.go | 12 +- go/genkit/dotprompt/genkit.go | 24 ++-- go/genkit/dotprompt/genkit_test.go | 3 +- go/genkit/dotprompt/render.go | 4 +- go/samples/coffee-shop/main.go | 34 ++--- 6 files changed, 120 insertions(+), 147 deletions(-) diff --git a/go/genkit/dotprompt/dotprompt.go b/go/genkit/dotprompt/dotprompt.go index 48bad4f3f..ab251637e 100644 --- a/go/genkit/dotprompt/dotprompt.go +++ b/go/genkit/dotprompt/dotprompt.go @@ -60,24 +60,50 @@ type Prompt struct { // The name of the prompt. Optional unless the prompt is // registered as an action. Name string - // The name of the prompt variant, if any. - // Variants can be used for easy testing of a tweaked prompt. - Variant string - // A hash of the prompt contents. - Hash string + Config // The template for the prompt. Template *raymond.Template - // The prompt YAML frontmatter. - Frontmatter *Frontmatter + // A hash of the prompt contents. + hash string // A Generator to use. If not nil, this is used by the // [genkit.Action] returned by [Prompt.Action] to execute the prompt. generator ai.Generator } +// Config is optional configuration for a [Prompt]. +type Config struct { + // The prompt variant. + Variant string + // The name of the model for which the prompt is input. + Model string + + // TODO(iant): document + Tools []*ai.ToolDefinition + + // Number of candidates to generate when passing the prompt + // to a generator. If 0, uses 1. + Candidates int + + // Details for the model. + GenerationConfig *ai.GenerationCommonConfig + + InputSchema *jsonschema.Schema // schema for input variables + VariableDefaults map[string]any // default input variable values + + // Desired output format. + OutputFormat ai.OutputFormat + + // Desired output schema, for JSON output. + OutputSchema map[string]any // TODO: use *jsonschema.Schema + + // Arbitrary metadata. + Metadata map[string]any +} + // Open opens and parses a dotprompt file. // The name is a base file name, without the ".prompt" extension. func Open(name string) (*Prompt, error) { @@ -114,44 +140,9 @@ func OpenVariant(name, variant string) (*Prompt, error) { return Parse(name, variant, data) } -// Frontmatter is the data we may see, YAML encoded, at the -// start of a dotprompt file. It appears within --- lines. -// These fields are optional. -type Frontmatter struct { - // The name of the prompt. - Name string - // The prompt variant. - Variant string - // The name of the model for which the prompt is input. - Model string - - // TODO(iant): document - Tools []*ai.ToolDefinition - - // Number of candidates to generate when passing the prompt - // to a generator. If 0, uses 1. - Candidates int - - // Details for the model. - Config *ai.GenerationCommonConfig - - // Description of input data. - Input FrontmatterInput - - // Desired output format. - Output *ai.GenerateRequestOutput - - // Arbitrary metadata. - Metadata map[string]any -} - -// FrontmatterInput describes the input data. -type FrontmatterInput struct { - Schema *jsonschema.Schema // schema for input variables - Default map[string]any // default input variables -} - // frontmatterYAML is the type we use to unpack the frontmatter. +// (Frontmatter is the data we may see, YAML encoded, at the +// start of a dotprompt file. It appears within --- lines.) // We do it this way so that we can handle the input and output // fields as picoschema, while returning them as jsonschema.Schema. type frontmatterYAML struct { @@ -175,76 +166,75 @@ type frontmatterYAML struct { // Parse parses the contents of a dotprompt file. func Parse(name, variant string, data []byte) (*Prompt, error) { const header = "---\n" - var front *Frontmatter + var fmName string + var cfg Config if bytes.HasPrefix(data, []byte(header)) { var err error - front, data, err = parseFrontmatter(data[len(header):]) + fmName, cfg, data, err = parseFrontmatter(data[len(header):]) if err != nil { return nil, err } } + // The name argument takes precedence over the name in the frontmatter. + if name == "" { + name = fmName + } - return define(name, variant, front, fmt.Sprintf("%02x", sha256.Sum256(data)), string(data), nil) + return newPrompt(name, string(data), fmt.Sprintf("%02x", sha256.Sum256(data)), cfg) } -// define defines a new prompt from frontmatter and template text. -func define(name, variant string, frontmatter *Frontmatter, hash, text string, generator ai.Generator) (*Prompt, error) { - template, err := raymond.Parse(text) +// newPrompt creates a new prompt. +// templateText should be a handlebars template. +// hash is its SHA256 hash as a hex string. +func newPrompt(name, templateText, hash string, config Config) (*Prompt, error) { + template, err := raymond.Parse(templateText) if err != nil { return nil, fmt.Errorf("failed to parse template: %w", err) } template.RegisterHelpers(templateHelpers) - - prompt := &Prompt{ - Name: name, - Variant: variant, - Hash: hash, - Template: template, - Frontmatter: frontmatter, - generator: generator, - } - return prompt, nil + return &Prompt{ + Name: name, + Config: config, + hash: hash, + Template: template, + }, nil } // parseFrontmatter parses the initial YAML frontmatter of a dotprompt file. -// Along with the frontmatter itself, it returns the remaining data. -func parseFrontmatter(data []byte) (*Frontmatter, []byte, error) { +// It returns the frontmatter as a Config along with the remaining data. +func parseFrontmatter(data []byte) (name string, c Config, rest []byte, err error) { const footer = "\n---\n" end := bytes.Index(data, []byte(footer)) if end == -1 { - return nil, nil, errors.New("dotprompt: missing marker for end of frontmatter") + return "", Config{}, nil, errors.New("dotprompt: missing marker for end of frontmatter") } input := data[:end] var fy frontmatterYAML if err := yaml.Unmarshal(input, &fy); err != nil { - return nil, nil, fmt.Errorf("dotprompt: failed to parse YAML frontmatter: %w", err) + return "", Config{}, nil, fmt.Errorf("dotprompt: failed to parse YAML frontmatter: %w", err) } - ret := &Frontmatter{ - Name: fy.Name, - Variant: fy.Variant, - Model: fy.Model, - Tools: fy.Tools, - Candidates: fy.Candidates, - Config: fy.Config, - Metadata: fy.Metadata, + ret := Config{ + Variant: fy.Variant, + Model: fy.Model, + Tools: fy.Tools, + Candidates: fy.Candidates, + GenerationConfig: fy.Config, + VariableDefaults: fy.Input.Default, + Metadata: fy.Metadata, } inputSchema, err := picoschemaToJSONSchema(fy.Input.Schema) if err != nil { - return nil, nil, fmt.Errorf("dotprompt: can't parse input: %w", err) - } - ret.Input = FrontmatterInput{ - Schema: inputSchema, - Default: fy.Input.Default, + return "", Config{}, nil, fmt.Errorf("dotprompt: can't parse input: %w", err) } + ret.InputSchema = inputSchema outputSchema, err := picoschemaToJSONSchema(fy.Output.Schema) if err != nil { - return nil, nil, fmt.Errorf("dotprompt: can't parse output: %w", err) + return "", Config{}, nil, fmt.Errorf("dotprompt: can't parse output: %w", err) } - var generateOutputSchema map[string]any if outputSchema != nil { // We have a jsonschema.Schema and we want a map[string]any. // TODO(iant): This conversion is useless. @@ -255,43 +245,47 @@ func parseFrontmatter(data []byte) (*Frontmatter, []byte, error) { data, err := json.Marshal(outputSchema) if err != nil { - return nil, nil, fmt.Errorf("dotprompt: can't JSON marshal JSON schema: %w", err) + return "", Config{}, nil, fmt.Errorf("dotprompt: can't JSON marshal JSON schema: %w", err) } - if err := json.Unmarshal(data, &generateOutputSchema); err != nil { - return nil, nil, fmt.Errorf("dotprompt: can't unmarshal JSON schema: %w", err) + if err := json.Unmarshal(data, &ret.OutputSchema); err != nil { + return "", Config{}, nil, fmt.Errorf("dotprompt: can't unmarshal JSON schema: %w", err) } } // TODO(iant): The TypeScript codes supports media also, // but there is no ai.OutputFormatMedia. - var of ai.OutputFormat switch fy.Output.Format { case "": case string(ai.OutputFormatJSON): - of = ai.OutputFormatJSON + ret.OutputFormat = ai.OutputFormatJSON case string(ai.OutputFormatText): - of = ai.OutputFormatText + ret.OutputFormat = ai.OutputFormatText default: - return nil, nil, fmt.Errorf("dotprompt: unrecognized output format %q", fy.Output.Format) + return "", Config{}, nil, fmt.Errorf("dotprompt: unrecognized output format %q", fy.Output.Format) } + return fy.Name, ret, data[end+len(footer):], nil +} - if of != "" || generateOutputSchema != nil { - ret.Output = &ai.GenerateRequestOutput{ - Format: of, - Schema: generateOutputSchema, - } +// Define creates and registers a new Prompt. This can be called from code that +// doesn't have a prompt file. +func Define(name, templateText string, cfg *Config) (*Prompt, error) { + p, err := New(name, templateText, cfg) + if err != nil { + return nil, err } - - return ret, data[end+len(footer):], nil + p.Register() + return p, nil } -// Define defines a new Prompt. This can be called from code that -// doesn't have a prompt file. If not nil, the generator argument -// will implement the AI model given the substituted prompt text. +// New creates a new Prompt without registering it. // This may be used for testing or for direct calls not using the // genkit action and flow mechanisms. -func Define(name string, frontmatter *Frontmatter, text string, generator ai.Generator) (*Prompt, error) { - return define(name, "", frontmatter, fmt.Sprintf("%02x", sha256.Sum256([]byte(text))), text, generator) +func New(name, templateText string, cfg *Config) (*Prompt, error) { + if cfg == nil { + cfg = &Config{} + } + hash := fmt.Sprintf("%02x", sha256.Sum256([]byte(templateText))) + return newPrompt(name, templateText, hash, *cfg) } // sortSchemaSlices sorts the slices in a jsonschema to permit diff --git a/go/genkit/dotprompt/dotprompt_test.go b/go/genkit/dotprompt/dotprompt_test.go index 037fa37a6..e25654d33 100644 --- a/go/genkit/dotprompt/dotprompt_test.go +++ b/go/genkit/dotprompt/dotprompt_test.go @@ -110,23 +110,23 @@ func TestPrompts(t *testing.T) { t.Fatal(err) } - if prompt.Frontmatter.Model != test.model { - t.Errorf("got model %q want %q", prompt.Frontmatter.Model, test.model) + if prompt.Model != test.model { + t.Errorf("got model %q want %q", prompt.Model, test.model) } - if diff := cmpSchema(t, prompt.Frontmatter.Input.Schema, test.input); diff != "" { + if diff := cmpSchema(t, prompt.InputSchema, test.input); diff != "" { t.Errorf("input schema mismatch (-want, +got):\n%s", diff) } if test.output == "" { - if prompt.Frontmatter.Output != nil && prompt.Frontmatter.Output.Schema != nil { - t.Errorf("unexpected output schema: %v", prompt.Frontmatter.Output.Schema) + if prompt.OutputSchema != nil { + t.Errorf("unexpected output schema: %v", prompt.OutputSchema) } } else { var output map[string]any if err := json.Unmarshal([]byte(test.output), &output); err != nil { t.Fatalf("JSON unmarshal of %q failed: %v", test.output, err) } - if diff := cmp.Diff(output, prompt.Frontmatter.Output.Schema); diff != "" { + if diff := cmp.Diff(output, prompt.OutputSchema); diff != "" { t.Errorf("output schema mismatch (-want, +got):\n%s", diff) } } diff --git a/go/genkit/dotprompt/genkit.go b/go/genkit/dotprompt/genkit.go index 814885618..172f650e9 100644 --- a/go/genkit/dotprompt/genkit.go +++ b/go/genkit/dotprompt/genkit.go @@ -98,24 +98,19 @@ func (p *Prompt) buildRequest(input *ActionInput) (*ai.GenerateRequest, error) { } req.Candidates = input.Candidates - if req.Candidates == 0 && p.Frontmatter != nil { - req.Candidates = p.Frontmatter.Candidates + if req.Candidates == 0 { + req.Candidates = p.Candidates } if req.Candidates == 0 { req.Candidates = 1 } - if p.Frontmatter != nil { - req.Config = p.Frontmatter.Config - } - - if p.Frontmatter != nil { - req.Output = p.Frontmatter.Output - } - - if p.Frontmatter != nil { - req.Tools = p.Frontmatter.Tools + req.Config = p.GenerationConfig + req.Output = &ai.GenerateRequestOutput{ + Format: p.OutputFormat, + Schema: p.OutputSchema, } + req.Tools = p.Tools return req, nil } @@ -174,10 +169,7 @@ func (p *Prompt) Execute(ctx context.Context, input *ActionInput) (*ai.GenerateR generator := p.generator if generator == nil { - var model string - if p.Frontmatter != nil { - model = p.Frontmatter.Model - } + model := p.Model if input.Model != "" { model = input.Model } diff --git a/go/genkit/dotprompt/genkit_test.go b/go/genkit/dotprompt/genkit_test.go index d6604edf8..61d5148c1 100644 --- a/go/genkit/dotprompt/genkit_test.go +++ b/go/genkit/dotprompt/genkit_test.go @@ -45,10 +45,11 @@ func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb g } func TestExecute(t *testing.T) { - p, err := Define("TestExecute", &Frontmatter{}, "TestExecute", testGenerator{}) + p, err := New("TestExecute", "TestExecute", nil) if err != nil { t.Fatal(err) } + p.generator = testGenerator{} resp, err := p.Execute(context.Background(), &ActionInput{}) if err != nil { t.Fatal(err) diff --git a/go/genkit/dotprompt/render.go b/go/genkit/dotprompt/render.go index d7faff0bd..20fd9edeb 100644 --- a/go/genkit/dotprompt/render.go +++ b/go/genkit/dotprompt/render.go @@ -48,9 +48,9 @@ func (p *Prompt) RenderText(variables map[string]any) (string, error) { // RenderMessages executes the prompt's template and converts it into messages. func (p *Prompt) RenderMessages(variables map[string]any) ([]*ai.Message, error) { - if p.Frontmatter != nil && p.Frontmatter.Input.Default != nil { + if p.VariableDefaults != nil { nv := make(map[string]any) - maps.Copy(nv, p.Frontmatter.Input.Default) + maps.Copy(nv, p.VariableDefaults) maps.Copy(nv, variables) variables = nv } diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 4cc25c9f5..73171c397 100644 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -83,19 +83,12 @@ func main() { log.Fatal(err) } - simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", - &dotprompt.Frontmatter{ - Name: "simpleGreeting", - Model: "google-genai/gemini-1.0-pro", - Input: dotprompt.FrontmatterInput{ - Schema: jsonschema.Reflect(simpleGreetingInput{}), - }, - Output: &ai.GenerateRequestOutput{ - Format: ai.OutputFormatText, - }, + simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate, + &dotprompt.Config{ + Model: "google-genai/gemini-1.0-pro", + InputSchema: jsonschema.Reflect(simpleGreetingInput{}), + OutputFormat: ai.OutputFormatText, }, - simpleGreetingPromptTemplate, - nil, ) if err != nil { log.Fatal(err) @@ -118,19 +111,12 @@ func main() { return text, nil }) - greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory", - &dotprompt.Frontmatter{ - Name: "greetingWithHistory", - Model: "google-genai/gemini-1.0-pro", - Input: dotprompt.FrontmatterInput{ - Schema: jsonschema.Reflect(customerTimeAndHistoryInput{}), - }, - Output: &ai.GenerateRequestOutput{ - Format: ai.OutputFormatText, - }, + greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory", greetingWithHistoryPromptTemplate, + &dotprompt.Config{ + Model: "google-genai/gemini-1.0-pro", + InputSchema: jsonschema.Reflect(customerTimeAndHistoryInput{}), + OutputFormat: ai.OutputFormatText, }, - greetingWithHistoryPromptTemplate, - nil, ) if err != nil { log.Fatal(err)