Skip to content

Commit

Permalink
feat: return an error on bad Func.ArgPos (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmzane authored Feb 25, 2023
1 parent 227228f commit d352735
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 17 deletions.
14 changes: 10 additions & 4 deletions musttag.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package musttag

import (
"flag"
"fmt"
"go/ast"
"go/token"
"go/types"
Expand Down Expand Up @@ -81,7 +82,7 @@ func flags(funcs *[]Func) flag.FlagSet {

// for tests only.
var (
reportf = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) {
report = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) {
const format = "`%s` should be annotated with the `%s` tag as it is passed to `%s` at %s"
pass.Reportf(st.Pos, format, st.Name, fn.Tag, fn.shortName(), fnPos)
}
Expand All @@ -106,6 +107,10 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
filter := []ast.Node{(*ast.CallExpr)(nil)}

walk.Preorder(filter, func(n ast.Node) {
if err != nil {
return // there is already an error.
}

call, ok := n.(*ast.CallExpr)
if !ok {
return // not a function call.
Expand All @@ -122,7 +127,8 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
}

if len(call.Args) <= fn.ArgPos {
return // TODO(junk1tm): return a proper error.
err = fmt.Errorf("Func.ArgPos cannot be %d: %s accepts only %d argument(s)", fn.ArgPos, fn.Name, len(call.Args))
return
}

arg := call.Args[fn.ArgPos]
Expand Down Expand Up @@ -159,10 +165,10 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {

p := pass.Fset.Position(call.Pos())
p.Filename, _ = filepath.Rel(moduleDir, p.Filename)
reportf(pass, result, fn, p)
report(pass, result, fn, p)
})

return nil, nil
return nil, err
}

// structType is an extension for types.Struct.
Expand Down
44 changes: 31 additions & 13 deletions musttag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,36 @@ func TestAnalyzer(t *testing.T) {
prepareTestFiles(t)
testPackages = []string{"tests", "builtins"}

analyzer := New(
Func{Name: "example.com/custom.Marshal", Tag: "custom", ArgPos: 0},
Func{Name: "example.com/custom.Unmarshal", Tag: "custom", ArgPos: 1},
)
testdata := analysistest.TestData()

t.Run("tests", func(t *testing.T) {
r := report
defer func() { report = r }()
report = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) {
pass.Reportf(st.Pos, fn.shortName())
}
analyzer := New()
analysistest.Run(t, testdata, analyzer, "tests")
})

t.Run("builtins", func(t *testing.T) {
testdata := analysistest.TestData()
analyzer := New(
Func{Name: "example.com/custom.Marshal", Tag: "custom", ArgPos: 0},
Func{Name: "example.com/custom.Unmarshal", Tag: "custom", ArgPos: 1},
)
analysistest.Run(t, testdata, analyzer, "builtins")
})

t.Run("tests", func(t *testing.T) {
reportf = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) {
pass.Reportf(st.Pos, fn.shortName())
t.Run("bad Func.ArgPos", func(t *testing.T) {
const want = `Func.ArgPos cannot be 10: encoding/json.Marshal accepts only 1 argument(s)`
analyzer := New(
// override the builtin function.
Func{Name: "encoding/json.Marshal", Tag: "json", ArgPos: 10},
)
result := analysistest.Run(nopT{}, testdata, analyzer, "tests")[0]
if got := result.Err.Error(); got != want {
t.Errorf("\ngot\t%s\nwant\t%s", got, want)
}
testdata := analysistest.TestData()
analysistest.Run(t, testdata, analyzer, "tests")
})
}

Expand All @@ -46,27 +60,31 @@ func TestFlags(t *testing.T) {
t.Run("ok", func(t *testing.T) {
err := analyzer.Flags.Parse([]string{"-fn=test.Test:test:0"})
if err != nil {
t.Errorf("got %v; want no error", err)
t.Errorf("\ngot\t%s\nwant\tno error", err)
}
})

t.Run("invalid format", func(t *testing.T) {
const want = `invalid value "test.Test" for flag -fn: invalid syntax`
err := analyzer.Flags.Parse([]string{"-fn=test.Test"})
if got := err.Error(); got != want {
t.Errorf("got %q; want %q", got, want)
t.Errorf("\ngot\t%s\nwant\t%s", got, want)
}
})

t.Run("non-number argpos", func(t *testing.T) {
const want = `invalid value "test.Test:test:-" for flag -fn: strconv.Atoi: parsing "-": invalid syntax`
err := analyzer.Flags.Parse([]string{"-fn=test.Test:test:-"})
if got := err.Error(); got != want {
t.Errorf("got %q; want %q", got, want)
t.Errorf("\ngot\t%s\nwant\t%s", got, want)
}
})
}

type nopT struct{}

func (nopT) Errorf(string, ...any) {}

func prepareTestFiles(t *testing.T) {
testdata := analysistest.TestData()

Expand Down

0 comments on commit d352735

Please sign in to comment.