Skip to content

Commit

Permalink
feat: Embed() option and Context.Call()
Browse files Browse the repository at this point in the history
The former allows arbitrary structs to be embedded in the root of the
CLI, with optional tags.

The latter allows an arbitrary function to be called using Kong's
binding functionality.
  • Loading branch information
alecthomas committed Nov 22, 2022
1 parent d974d72 commit bf0cbf5
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 29 deletions.
39 changes: 19 additions & 20 deletions build.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func build(k *Kong, ast interface{}) (app *Application, err error) {
seenFlags[flag.Name] = true
}

node, err := buildNode(k, iv, ApplicationNode, seenFlags)
node, err := buildNode(k, iv, ApplicationNode, newEmptyTag(), seenFlags)
if err != nil {
return nil, err
}
Expand All @@ -49,7 +49,7 @@ type flattenedField struct {
tag *Tag
}

func flattenedFields(v reflect.Value) (out []flattenedField, err error) {
func flattenedFields(v reflect.Value, ptag *Tag) (out []flattenedField, err error) {
v = reflect.Indirect(v)
for i := 0; i < v.NumField(); i++ {
ft := v.Type().Field(i)
Expand All @@ -61,14 +61,24 @@ func flattenedFields(v reflect.Value) (out []flattenedField, err error) {
if tag.Ignored {
continue
}
// Assign group if it's not already set.
if tag.Group == "" {
tag.Group = ptag.Group
}
// Accumulate prefixes.
tag.Prefix = ptag.Prefix + tag.Prefix
tag.EnvPrefix = ptag.EnvPrefix + tag.EnvPrefix
// Combine parent vars.
tag.Vars = ptag.Vars.CloneWith(tag.Vars)
// Command and embedded structs can be pointers, so we hydrate them now.
if (tag.Cmd || tag.Embed) && ft.Type.Kind() == reflect.Ptr {
fv = reflect.New(ft.Type.Elem()).Elem()
v.FieldByIndex(ft.Index).Set(fv.Addr())
}
if !ft.Anonymous && !tag.Embed {
if fv.CanSet() {
out = append(out, flattenedField{field: ft, value: fv, tag: tag})
field := flattenedField{field: ft, value: fv, tag: tag}
out = append(out, field)
}
continue
}
Expand All @@ -78,29 +88,18 @@ func flattenedFields(v reflect.Value) (out []flattenedField, err error) {
fv = fv.Elem()
} else if fv.Type() == reflect.TypeOf(Plugins{}) {
for i := 0; i < fv.Len(); i++ {
fields, ferr := flattenedFields(fv.Index(i).Elem())
fields, ferr := flattenedFields(fv.Index(i).Elem(), tag)
if ferr != nil {
return nil, ferr
}
out = append(out, fields...)
}
continue
}
sub, err := flattenedFields(fv)
sub, err := flattenedFields(fv, tag)
if err != nil {
return nil, err
}
for _, subf := range sub {
// Assign parent if it's not already set.
if subf.tag.Group == "" {
subf.tag.Group = tag.Group
}
// Accumulate prefixes.
subf.tag.Prefix = tag.Prefix + subf.tag.Prefix
subf.tag.EnvPrefix = tag.EnvPrefix + subf.tag.EnvPrefix
// Combine parent vars.
subf.tag.Vars = tag.Vars.CloneWith(subf.tag.Vars)
}
out = append(out, sub...)
}
return out, nil
Expand All @@ -109,13 +108,13 @@ func flattenedFields(v reflect.Value) (out []flattenedField, err error) {
// Build a Node in the Kong data model.
//
// "v" is the value to create the node from, "typ" is the output Node type.
func buildNode(k *Kong, v reflect.Value, typ NodeType, seenFlags map[string]bool) (*Node, error) {
func buildNode(k *Kong, v reflect.Value, typ NodeType, tag *Tag, seenFlags map[string]bool) (*Node, error) {
node := &Node{
Type: typ,
Target: v,
Tag: newEmptyTag(),
Tag: tag,
}
fields, err := flattenedFields(v)
fields, err := flattenedFields(v, tag)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -201,7 +200,7 @@ func validatePositionalArguments(node *Node) error {
}

func buildChild(k *Kong, node *Node, typ NodeType, v reflect.Value, ft reflect.StructField, fv reflect.Value, tag *Tag, name string, seenFlags map[string]bool) error {
child, err := buildNode(k, fv, typ, seenFlags)
child, err := buildNode(k, fv, typ, newEmptyTag(), seenFlags)
if err != nil {
return err
}
Expand Down
43 changes: 40 additions & 3 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,14 @@ func getMethod(value reflect.Value, name string) reflect.Value {
return method
}

func callMethod(name string, v, f reflect.Value, bindings bindings) error {
func callFunction(f reflect.Value, bindings bindings) error {
if f.Kind() != reflect.Func {
return fmt.Errorf("expected function, got %s", f.Type())
}
in := []reflect.Value{}
t := f.Type()
if t.NumOut() != 1 || !t.Out(0).Implements(callbackReturnSignature) {
return fmt.Errorf("return value of %T.%s() must implement \"error\"", v.Type(), name)
return fmt.Errorf("return value of %s must implement \"error\"", t)
}
for i := 0; i < t.NumIn(); i++ {
pt := t.In(i)
Expand All @@ -89,7 +92,7 @@ func callMethod(name string, v, f reflect.Value, bindings bindings) error {
}
in = append(in, argv)
} else {
return fmt.Errorf("couldn't find binding of type %s for parameter %d of %s.%s(), use kong.Bind(%s)", pt, i, v.Type(), name, pt)
return fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt)
}
}
out := f.Call(in)
Expand All @@ -98,3 +101,37 @@ func callMethod(name string, v, f reflect.Value, bindings bindings) error {
}
return out[0].Interface().(error) // nolint
}

func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error) {
if f.Kind() != reflect.Func {
return nil, fmt.Errorf("expected function, got %s", f.Type())
}
in := []reflect.Value{}
t := f.Type()
for i := 0; i < t.NumIn(); i++ {
pt := t.In(i)
if argf, ok := bindings[pt]; ok {
argv, err := argf()
if err != nil {
return nil, err
}
in = append(in, argv)
} else {
return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt)
}
}
outv := f.Call(in)
out = make([]any, len(outv))
for i, v := range outv {
out[i] = v.Interface()
}
return out, nil
}

func callMethod(name string, v, f reflect.Value, bindings bindings) error {
err := callFunction(f, bindings)
if err != nil {
return fmt.Errorf("%s.%s(): %w", v.Type(), name, err)
}
return nil
}
9 changes: 8 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (c *Context) Bind(args ...interface{}) {
//
// This will typically have to be called like so:
//
// BindTo(impl, (*MyInterface)(nil))
// BindTo(impl, (*MyInterface)(nil))
func (c *Context) BindTo(impl, iface interface{}) {
c.bindings.addTo(impl, iface)
}
Expand Down Expand Up @@ -719,6 +719,13 @@ func (c *Context) parseFlag(flags []*Flag, match string) (err error) {
return findPotentialCandidates(match, candidates, "unknown flag %s", match)
}

// Call an arbitrary function filling arguments with bound values.
func (c *Context) Call(fn any, binds ...interface{}) (out []interface{}, err error) {
fv := reflect.ValueOf(fn)
bindings := c.Kong.bindings.clone().add(binds...).add(c).merge(c.bindings) //nolint:govet
return callAnyFunction(fv, bindings)
}

// RunNode calls the Run() method on an arbitrary node.
//
// This is useful in conjunction with Visit(), for dynamically running commands.
Expand Down
24 changes: 24 additions & 0 deletions kong.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type Kong struct {

// Set temporarily by Options. These are applied after build().
postBuildOptions []Option
embedded []embedded
dynamicCommands []*dynamicCommand
}

Expand Down Expand Up @@ -110,6 +111,25 @@ func New(grammar interface{}, options ...Option) (*Kong, error) {
k.Model = model
k.Model.HelpFlag = k.helpFlag

// Embed any embedded structs.
for _, embed := range k.embedded {
tag, err := parseTagString(strings.Join(embed.tags, " ")) //nolint:govet
if err != nil {
return nil, err
}
tag.Embed = true
v := reflect.Indirect(reflect.ValueOf(embed.strct))
node, err := buildNode(k, v, CommandNode, tag, map[string]bool{})
if err != nil {
return nil, err
}
for _, child := range node.Children {
child.Parent = k.Model.Node
k.Model.Children = append(k.Model.Children, child)
}
k.Model.Flags = append(k.Model.Flags, node.Flags...)
}

// Synthesise command nodes.
for _, dcmd := range k.dynamicCommands {
tag, terr := parseTagString(strings.Join(dcmd.tags, " "))
Expand Down Expand Up @@ -188,6 +208,10 @@ func (k *Kong) interpolateValue(value *Value, vars Vars) (err error) {
vars = vars.CloneWith(varsContributor.Vars(value))
}

if value.Enum, err = interpolate(value.Enum, vars, nil); err != nil {
return fmt.Errorf("enum for %s: %s", value.Summary(), err)
}

updatedVars := map[string]string{
"default": value.Default,
"enum": value.Enum,
Expand Down
28 changes: 24 additions & 4 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@ func Exit(exit func(int)) Option {
})
}

type embedded struct {
strct any
tags []string
}

// Embed a struct into the root of the CLI.
//
// "strct" must be a pointer to a structure.
func Embed(strct any, tags ...string) Option {
t := reflect.TypeOf(strct)
if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
panic("kong: Embed() must be called with a pointer to a struct")
}
return OptionFunc(func(k *Kong) error {
k.embedded = append(k.embedded, embedded{strct, tags})
return nil
})
}

type dynamicCommand struct {
name string
help string
Expand Down Expand Up @@ -164,8 +183,8 @@ func Writers(stdout, stderr io.Writer) Option {
//
// There are two hook points:
//
// BeforeApply(...) error
// AfterApply(...) error
// BeforeApply(...) error
// AfterApply(...) error
//
// Called before validation/assignment, and immediately after validation/assignment, respectively.
func Bind(args ...interface{}) Option {
Expand All @@ -177,7 +196,7 @@ func Bind(args ...interface{}) Option {

// BindTo allows binding of implementations to interfaces.
//
// BindTo(impl, (*iface)(nil))
// BindTo(impl, (*iface)(nil))
func BindTo(impl, iface interface{}) Option {
return OptionFunc(func(k *Kong) error {
k.bindings.addTo(impl, iface)
Expand Down Expand Up @@ -428,7 +447,8 @@ func siftStrings(ss []string, filter func(s string) bool) []string {
// Predefined environment variables are skipped.
//
// For example:
// --some.value -> PREFIX_SOME_VALUE
//
// --some.value -> PREFIX_SOME_VALUE
func DefaultEnvars(prefix string) Option {
processFlag := func(flag *Flag) {
switch env := flag.Env; {
Expand Down
2 changes: 1 addition & 1 deletion options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestInvalidCallback(t *testing.T) {
p, err := New(&cli, BindTo(impl("foo"), (*iface)(nil)))
assert.NoError(t, err)
err = callMethod("method", reflect.ValueOf(impl("??")), reflect.ValueOf(method), p.bindings)
assert.EqualError(t, err, `return value of *reflect.rtype.method() must implement "error"`)
assert.EqualError(t, err, `kong.impl.method(): return value of func(kong.iface) string must implement "error"`)
}

type zrror struct{}
Expand Down
10 changes: 10 additions & 0 deletions tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ type Tag struct {
items map[string][]string
}

func (t *Tag) String() string {
out := []string{}
for key, list := range t.items {
for _, value := range list {
out = append(out, fmt.Sprintf("%s:%q", key, value))
}
}
return strings.Join(out, " ")
}

type tagChars struct {
sep, quote, assign rune
}
Expand Down

0 comments on commit bf0cbf5

Please sign in to comment.