Skip to content

Commit

Permalink
[Go] POC: redo dotprompt API (#202)
Browse files Browse the repository at this point in the history
Attempt to make the dotprompt API a little more user-friendly.

The main change is to replace the exported Frontmatter type
with a type (called Config) that is easier to write a literal for.

Miscellaneous other changes:

- Unexport the Prompt.Hash field.

- Add `New` to create a prompt without registering it,
  and use `Define` to mean create + register, as it does elsewhere.

- Remove the generator as an argument; I believe it would only be
  used for tests.
  • Loading branch information
jba authored May 21, 2024
1 parent cd9be86 commit 23b246d
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 147 deletions.
190 changes: 92 additions & 98 deletions go/genkit/dotprompt/dotprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,50 @@ type Prompt struct {
// The name of the prompt. Optional unless the prompt is
// registered as an action.
Name string
// The name of the prompt variant, if any.
// Variants can be used for easy testing of a tweaked prompt.
Variant string

// A hash of the prompt contents.
Hash string
Config

// The template for the prompt.
Template *raymond.Template

// The prompt YAML frontmatter.
Frontmatter *Frontmatter
// A hash of the prompt contents.
hash string

// A Generator to use. If not nil, this is used by the
// [genkit.Action] returned by [Prompt.Action] to execute the prompt.
generator ai.Generator
}

// Config is optional configuration for a [Prompt].
// It holds the settings that a dotprompt file expresses in its YAML
// frontmatter, as a flat struct that is easy to write as a literal.
// Config is embedded in [Prompt], so its fields are accessible
// directly on a Prompt value (for example, prompt.Model).
type Config struct {
// The prompt variant.
// Variants can be used for easy testing of a tweaked prompt.
Variant string
// The name of the model for which the prompt is input.
// A non-empty model name in the action input takes precedence
// over this value when the prompt is executed.
Model string

// Tool definitions; these are copied into the Tools field of the
// generate request built from the prompt.
// TODO(iant): document
Tools []*ai.ToolDefinition

// Number of candidates to generate when passing the prompt
// to a generator. If 0, uses 1.
// A non-zero candidate count in the action input takes precedence.
Candidates int

// Details for the model.
// Copied into the Config field of the generate request.
GenerationConfig *ai.GenerationCommonConfig

// InputSchema and VariableDefaults describe the template's input
// variables. When messages are rendered, the defaults are applied
// first and then overridden by the variables the caller supplies.
InputSchema *jsonschema.Schema // schema for input variables
VariableDefaults map[string]any // default input variable values

// Desired output format.
// Copied into the Output.Format field of the generate request.
OutputFormat ai.OutputFormat

// Desired output schema, for JSON output.
// Copied into the Output.Schema field of the generate request.
OutputSchema map[string]any // TODO: use *jsonschema.Schema

// Arbitrary metadata.
Metadata map[string]any
}

// Open opens and parses a dotprompt file.
// The name is a base file name, without the ".prompt" extension.
func Open(name string) (*Prompt, error) {
Expand Down Expand Up @@ -114,44 +140,9 @@ func OpenVariant(name, variant string) (*Prompt, error) {
return Parse(name, variant, data)
}

// Frontmatter is the data we may see, YAML encoded, at the
// start of a dotprompt file. It appears within --- lines.
// These fields are optional.
type Frontmatter struct {
// The name of the prompt.
Name string
// The prompt variant.
Variant string
// The name of the model for which the prompt is input.
Model string

// TODO(iant): document
Tools []*ai.ToolDefinition

// Number of candidates to generate when passing the prompt
// to a generator. If 0, uses 1.
Candidates int

// Details for the model.
Config *ai.GenerationCommonConfig

// Description of input data.
Input FrontmatterInput

// Desired output format.
Output *ai.GenerateRequestOutput

// Arbitrary metadata.
Metadata map[string]any
}

// FrontmatterInput describes the input data.
type FrontmatterInput struct {
Schema *jsonschema.Schema // schema for input variables
Default map[string]any // default input variables
}

// frontmatterYAML is the type we use to unpack the frontmatter.
// (Frontmatter is the data we may see, YAML encoded, at the
// start of a dotprompt file. It appears within --- lines.)
// We do it this way so that we can handle the input and output
// fields as picoschema, while returning them as jsonschema.Schema.
type frontmatterYAML struct {
Expand All @@ -175,76 +166,75 @@ type frontmatterYAML struct {
// Parse parses the contents of a dotprompt file.
func Parse(name, variant string, data []byte) (*Prompt, error) {
const header = "---\n"
var front *Frontmatter
var fmName string
var cfg Config
if bytes.HasPrefix(data, []byte(header)) {
var err error
front, data, err = parseFrontmatter(data[len(header):])
fmName, cfg, data, err = parseFrontmatter(data[len(header):])
if err != nil {
return nil, err
}
}
// The name argument takes precedence over the name in the frontmatter.
if name == "" {
name = fmName
}

return define(name, variant, front, fmt.Sprintf("%02x", sha256.Sum256(data)), string(data), nil)
return newPrompt(name, string(data), fmt.Sprintf("%02x", sha256.Sum256(data)), cfg)
}

// define defines a new prompt from frontmatter and template text.
func define(name, variant string, frontmatter *Frontmatter, hash, text string, generator ai.Generator) (*Prompt, error) {
template, err := raymond.Parse(text)
// newPrompt creates a new prompt.
// templateText should be a handlebars template.
// hash is its SHA256 hash as a hex string.
func newPrompt(name, templateText, hash string, config Config) (*Prompt, error) {
template, err := raymond.Parse(templateText)
if err != nil {
return nil, fmt.Errorf("failed to parse template: %w", err)
}
template.RegisterHelpers(templateHelpers)

prompt := &Prompt{
Name: name,
Variant: variant,
Hash: hash,
Template: template,
Frontmatter: frontmatter,
generator: generator,
}
return prompt, nil
return &Prompt{
Name: name,
Config: config,
hash: hash,
Template: template,
}, nil
}

// parseFrontmatter parses the initial YAML frontmatter of a dotprompt file.
// Along with the frontmatter itself, it returns the remaining data.
func parseFrontmatter(data []byte) (*Frontmatter, []byte, error) {
// It returns the frontmatter as a Config along with the remaining data.
func parseFrontmatter(data []byte) (name string, c Config, rest []byte, err error) {
const footer = "\n---\n"
end := bytes.Index(data, []byte(footer))
if end == -1 {
return nil, nil, errors.New("dotprompt: missing marker for end of frontmatter")
return "", Config{}, nil, errors.New("dotprompt: missing marker for end of frontmatter")
}
input := data[:end]
var fy frontmatterYAML
if err := yaml.Unmarshal(input, &fy); err != nil {
return nil, nil, fmt.Errorf("dotprompt: failed to parse YAML frontmatter: %w", err)
return "", Config{}, nil, fmt.Errorf("dotprompt: failed to parse YAML frontmatter: %w", err)
}

ret := &Frontmatter{
Name: fy.Name,
Variant: fy.Variant,
Model: fy.Model,
Tools: fy.Tools,
Candidates: fy.Candidates,
Config: fy.Config,
Metadata: fy.Metadata,
ret := Config{
Variant: fy.Variant,
Model: fy.Model,
Tools: fy.Tools,
Candidates: fy.Candidates,
GenerationConfig: fy.Config,
VariableDefaults: fy.Input.Default,
Metadata: fy.Metadata,
}

inputSchema, err := picoschemaToJSONSchema(fy.Input.Schema)
if err != nil {
return nil, nil, fmt.Errorf("dotprompt: can't parse input: %w", err)
}
ret.Input = FrontmatterInput{
Schema: inputSchema,
Default: fy.Input.Default,
return "", Config{}, nil, fmt.Errorf("dotprompt: can't parse input: %w", err)
}
ret.InputSchema = inputSchema

outputSchema, err := picoschemaToJSONSchema(fy.Output.Schema)
if err != nil {
return nil, nil, fmt.Errorf("dotprompt: can't parse output: %w", err)
return "", Config{}, nil, fmt.Errorf("dotprompt: can't parse output: %w", err)
}

var generateOutputSchema map[string]any
if outputSchema != nil {
// We have a jsonschema.Schema and we want a map[string]any.
// TODO(iant): This conversion is useless.
Expand All @@ -255,43 +245,47 @@ func parseFrontmatter(data []byte) (*Frontmatter, []byte, error) {

data, err := json.Marshal(outputSchema)
if err != nil {
return nil, nil, fmt.Errorf("dotprompt: can't JSON marshal JSON schema: %w", err)
return "", Config{}, nil, fmt.Errorf("dotprompt: can't JSON marshal JSON schema: %w", err)
}
if err := json.Unmarshal(data, &generateOutputSchema); err != nil {
return nil, nil, fmt.Errorf("dotprompt: can't unmarshal JSON schema: %w", err)
if err := json.Unmarshal(data, &ret.OutputSchema); err != nil {
return "", Config{}, nil, fmt.Errorf("dotprompt: can't unmarshal JSON schema: %w", err)
}
}

// TODO(iant): The TypeScript code also supports media,
// but there is no ai.OutputFormatMedia.
var of ai.OutputFormat
switch fy.Output.Format {
case "":
case string(ai.OutputFormatJSON):
of = ai.OutputFormatJSON
ret.OutputFormat = ai.OutputFormatJSON
case string(ai.OutputFormatText):
of = ai.OutputFormatText
ret.OutputFormat = ai.OutputFormatText
default:
return nil, nil, fmt.Errorf("dotprompt: unrecognized output format %q", fy.Output.Format)
return "", Config{}, nil, fmt.Errorf("dotprompt: unrecognized output format %q", fy.Output.Format)
}
return fy.Name, ret, data[end+len(footer):], nil
}

if of != "" || generateOutputSchema != nil {
ret.Output = &ai.GenerateRequestOutput{
Format: of,
Schema: generateOutputSchema,
}
// Define creates and registers a new Prompt. This can be called from code that
// doesn't have a prompt file.
func Define(name, templateText string, cfg *Config) (*Prompt, error) {
p, err := New(name, templateText, cfg)
if err != nil {
return nil, err
}

return ret, data[end+len(footer):], nil
p.Register()
return p, nil
}

// Define defines a new Prompt. This can be called from code that
// doesn't have a prompt file. If not nil, the generator argument
// will implement the AI model given the substituted prompt text.
// New creates a new Prompt without registering it.
// This may be used for testing or for direct calls not using the
// genkit action and flow mechanisms.
func Define(name string, frontmatter *Frontmatter, text string, generator ai.Generator) (*Prompt, error) {
return define(name, "", frontmatter, fmt.Sprintf("%02x", sha256.Sum256([]byte(text))), text, generator)
func New(name, templateText string, cfg *Config) (*Prompt, error) {
	// A nil cfg means "no configuration": fall back to the zero Config.
	var effective Config
	if cfg != nil {
		effective = *cfg
	}
	// The prompt hash is the hex-encoded SHA-256 of the raw template text.
	sum := sha256.Sum256([]byte(templateText))
	return newPrompt(name, templateText, fmt.Sprintf("%02x", sum), effective)
}

// sortSchemaSlices sorts the slices in a jsonschema to permit
Expand Down
12 changes: 6 additions & 6 deletions go/genkit/dotprompt/dotprompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,23 @@ func TestPrompts(t *testing.T) {
t.Fatal(err)
}

if prompt.Frontmatter.Model != test.model {
t.Errorf("got model %q want %q", prompt.Frontmatter.Model, test.model)
if prompt.Model != test.model {
t.Errorf("got model %q want %q", prompt.Model, test.model)
}
if diff := cmpSchema(t, prompt.Frontmatter.Input.Schema, test.input); diff != "" {
if diff := cmpSchema(t, prompt.InputSchema, test.input); diff != "" {
t.Errorf("input schema mismatch (-want, +got):\n%s", diff)
}

if test.output == "" {
if prompt.Frontmatter.Output != nil && prompt.Frontmatter.Output.Schema != nil {
t.Errorf("unexpected output schema: %v", prompt.Frontmatter.Output.Schema)
if prompt.OutputSchema != nil {
t.Errorf("unexpected output schema: %v", prompt.OutputSchema)
}
} else {
var output map[string]any
if err := json.Unmarshal([]byte(test.output), &output); err != nil {
t.Fatalf("JSON unmarshal of %q failed: %v", test.output, err)
}
if diff := cmp.Diff(output, prompt.Frontmatter.Output.Schema); diff != "" {
if diff := cmp.Diff(output, prompt.OutputSchema); diff != "" {
t.Errorf("output schema mismatch (-want, +got):\n%s", diff)
}
}
Expand Down
24 changes: 8 additions & 16 deletions go/genkit/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,19 @@ func (p *Prompt) buildRequest(input *ActionInput) (*ai.GenerateRequest, error) {
}

req.Candidates = input.Candidates
if req.Candidates == 0 && p.Frontmatter != nil {
req.Candidates = p.Frontmatter.Candidates
if req.Candidates == 0 {
req.Candidates = p.Candidates
}
if req.Candidates == 0 {
req.Candidates = 1
}

if p.Frontmatter != nil {
req.Config = p.Frontmatter.Config
}

if p.Frontmatter != nil {
req.Output = p.Frontmatter.Output
}

if p.Frontmatter != nil {
req.Tools = p.Frontmatter.Tools
req.Config = p.GenerationConfig
req.Output = &ai.GenerateRequestOutput{
Format: p.OutputFormat,
Schema: p.OutputSchema,
}
req.Tools = p.Tools

return req, nil
}
Expand Down Expand Up @@ -174,10 +169,7 @@ func (p *Prompt) Execute(ctx context.Context, input *ActionInput) (*ai.GenerateR

generator := p.generator
if generator == nil {
var model string
if p.Frontmatter != nil {
model = p.Frontmatter.Model
}
model := p.Model
if input.Model != "" {
model = input.Model
}
Expand Down
3 changes: 2 additions & 1 deletion go/genkit/dotprompt/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb g
}

func TestExecute(t *testing.T) {
p, err := Define("TestExecute", &Frontmatter{}, "TestExecute", testGenerator{})
p, err := New("TestExecute", "TestExecute", nil)
if err != nil {
t.Fatal(err)
}
p.generator = testGenerator{}
resp, err := p.Execute(context.Background(), &ActionInput{})
if err != nil {
t.Fatal(err)
Expand Down
4 changes: 2 additions & 2 deletions go/genkit/dotprompt/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ func (p *Prompt) RenderText(variables map[string]any) (string, error) {

// RenderMessages executes the prompt's template and converts it into messages.
func (p *Prompt) RenderMessages(variables map[string]any) ([]*ai.Message, error) {
if p.Frontmatter != nil && p.Frontmatter.Input.Default != nil {
if p.VariableDefaults != nil {
nv := make(map[string]any)
maps.Copy(nv, p.Frontmatter.Input.Default)
maps.Copy(nv, p.VariableDefaults)
maps.Copy(nv, variables)
variables = nv
}
Expand Down
Loading

0 comments on commit 23b246d

Please sign in to comment.