Skip to content

Commit

Permalink
fix: incorrect report position and panic on invalid Func.ArgPos (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmzane authored Feb 9, 2023
1 parent 9229084 commit 27a84af
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 147 deletions.
153 changes: 82 additions & 71 deletions musttag.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ var builtin = []Func{
{Name: "github.com/mitchellh/mapstructure.WeakDecodeMetadata", Tag: "mapstructure", ArgPos: 1},
}

// flags creates a flag set for the analyzer. The funcs slice will be filled
// with custom functions passed via CLI flags.
// flags creates a flag set for the analyzer.
// The funcs slice will be filled with custom functions passed via CLI flags.
func flags(funcs *[]Func) flag.FlagSet {
fs := flag.NewFlagSet("musttag", flag.ContinueOnError)
fs.Func("fn", "report custom function (name:tag:argpos)", func(s string) error {
Expand All @@ -80,8 +80,9 @@ func flags(funcs *[]Func) flag.FlagSet {
return *fs
}

// New creates a new musttag analyzer. To report a custom function provide its
// description via Func, it will be added to the builtin ones.
// New creates a new musttag analyzer.
// To report a custom function provide its description via Func,
// it will be added to the builtin ones.
func New(funcs ...Func) *analysis.Analyzer {
var flagFuncs []Func
return &analysis.Analyzer{
Expand Down Expand Up @@ -112,125 +113,135 @@ var (

// reportf is a wrapper for pass.Reportf (as a variable, so it could be mocked in tests).
reportf = func(pass *analysis.Pass, pos token.Pos, fn Func) {
// TODO(junk1tm): print the name of the struct type as well?
pass.Reportf(pos, "exported fields should be annotated with the %q tag", fn.Tag)
}
)

// run starts the analysis.
func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

filter := []ast.Node{
(*ast.CallExpr)(nil),
}

type report struct {
pos token.Pos
tag string
pos token.Pos // the position for report.
tag string // the missing struct tag.
}
reported := make(map[report]struct{})

insp.Preorder(filter, func(n ast.Node) {
call := n.(*ast.CallExpr)
// store previous reports to prevent reporting
// the same struct more than once (if reportOnce is true).
reports := make(map[report]struct{})

walk := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
filter := []ast.Node{(*ast.CallExpr)(nil)}

walk.Preorder(filter, func(n ast.Node) {
call, ok := n.(*ast.CallExpr)
if !ok {
return // not a function call.
}

callee := typeutil.StaticCallee(pass.TypesInfo, call)
if callee == nil {
return
return // not a static call.
}

fn, ok := funcs[callee.FullName()]
if !ok {
return
return // the function is not supported.
}

if len(call.Args) <= fn.ArgPos {
return // TODO(junk1tm): return a proper error.
}

arg := call.Args[fn.ArgPos]
if unary, ok := arg.(*ast.UnaryExpr); ok {
arg = unary.X // e.g. json.Marshal(&foo)
}

initialPos := token.NoPos
switch arg := arg.(type) {
case *ast.Ident: // e.g. json.Marshal(foo)
initialPos = arg.Obj.Pos()
case *ast.CompositeLit: // e.g. json.Marshal(struct{}{})
initialPos = arg.Pos()
}

s, pos, ok := structAndPos(pass, call.Args[fn.ArgPos])
t := pass.TypesInfo.TypeOf(arg)
s, ok := parseStruct(t, initialPos)
if !ok {
return
return // not a struct argument.
}

if ok := checkStruct(s, fn.Tag, &pos); ok {
return
reportPos, ok := checkStruct(s, fn.Tag)
if ok {
return // nothing to report.
}

r := report{pos, fn.Tag}
if _, ok := reported[r]; ok && reportOnce {
return
r := report{reportPos, fn.Tag}
if _, ok := reports[r]; ok && reportOnce {
return // already reported.
}

reportf(pass, pos, fn)
reported[r] = struct{}{}
reportf(pass, reportPos, fn)
reports[r] = struct{}{}
})

return nil, nil
}

// structAndPos analyses the given expression and returns the struct to check
// and the position to report if needed.
func structAndPos(pass *analysis.Pass, expr ast.Expr) (*types.Struct, token.Pos, bool) {
t := pass.TypesInfo.TypeOf(expr)
if ptr, ok := t.(*types.Pointer); ok {
// structInfo expands types.Struct with its position in the source code.
// If the struct is anonymous, Pos points to the corresponding identifier.
type structInfo struct {
*types.Struct
Pos token.Pos
}

// parseStruct parses the given types.Type, returning the underlying struct type.
// If it's a named type, the result will contain the position of its declaration,
// or the given token.Pos otherwise.
func parseStruct(t types.Type, pos token.Pos) (*structInfo, bool) {
for {
// unwrap pointers (if any) first.
ptr, ok := t.(*types.Pointer)
if !ok {
break
}
t = ptr.Elem()
}

switch t := t.(type) {
case *types.Named: // named type
s, ok := t.Underlying().(*types.Struct)
if ok {
return s, t.Obj().Pos(), true
}

case *types.Struct: // anonymous struct
if unary, ok := expr.(*ast.UnaryExpr); ok {
expr = unary.X // &x
}
//nolint:gocritic // commentedOutCode: these are examples
switch arg := expr.(type) {
case *ast.Ident: // var x struct{}; json.Marshal(x)
return t, arg.Obj.Pos(), true
case *ast.CompositeLit: // json.Marshal(struct{}{})
return t, arg.Pos(), true
case *types.Named: // a struct of the named type.
if s, ok := t.Underlying().(*types.Struct); ok {
return &structInfo{Struct: s, Pos: t.Obj().Pos()}, true
}
case *types.Struct: // an anonymous struct.
return &structInfo{Struct: t, Pos: pos}, true
}

return nil, 0, false
return nil, false
}

// checkStruct checks that exported fields of the given struct are annotated
// with the tag and updates the position to report in case a nested struct of a
// named type is found.
func checkStruct(s *types.Struct, tag string, pos *token.Pos) (ok bool) {
// checkStruct recursively checks the given struct and returns the position for report,
// in case one of its fields is missing the tag.
func checkStruct(s *structInfo, tag string) (token.Pos, bool) {
for i := 0; i < s.NumFields(); i++ {
if !s.Field(i).Exported() {
continue
}

st := reflect.StructTag(s.Tag(i))
if _, ok := st.Lookup(tag); !ok {
// it's ok for embedded types not to be tagged,
// see https://github.com/junk1tm/musttag/issues/12
if !s.Field(i).Embedded() {
return false
}
if _, ok := st.Lookup(tag); !ok && !s.Field(i).Embedded() {
return s.Pos, false
}

// check if the field is a nested struct.
t := s.Field(i).Type()
if ptr, ok := t.(*types.Pointer); ok {
t = ptr.Elem()
}
nested, ok := t.Underlying().(*types.Struct)
nested, ok := parseStruct(t, s.Pos) // TODO(junk1tm): or s.Field(i).Pos()?
if !ok {
continue
}
if ok := checkStruct(nested, tag, pos); ok {
continue
}
// update the position to point to the named type.
if named, ok := t.(*types.Named); ok {
*pos = named.Obj().Pos()
if pos, ok := checkStruct(nested, tag); !ok {
return pos, false
}
return false
}

return true
return token.NoPos, true
}
2 changes: 1 addition & 1 deletion musttag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestAnalyzer(t *testing.T) {

func TestFlags(t *testing.T) {
analyzer := New()
analyzer.Flags.SetOutput(io.Discard)
analyzer.Flags.SetOutput(io.Discard) // TODO(junk1tm): does not work, the usage is still printed.

t.Run("ok", func(t *testing.T) {
err := analyzer.Flags.Parse([]string{"-fn=test.Test:test:0"})
Expand Down
149 changes: 74 additions & 75 deletions testdata/src/tests/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,81 +287,80 @@ func nestedAnonymousType() {

// embedded types should not be reported.
func embeddedType() {
type Y struct { /* want
`\Qjson.Marshal`
`\Qjson.MarshalIndent`
`\Qjson.Unmarshal`
`\Qjson.Encoder.Encode`
`\Qjson.Decoder.Decode`
`\Qxml.Marshal`
`\Qxml.MarshalIndent`
`\Qxml.Unmarshal`
`\Qxml.Encoder.Encode`
`\Qxml.Decoder.Decode`
`\Qxml.Encoder.EncodeElement`
`\Qxml.Decoder.DecodeElement`
`\Qyaml.v3.Marshal`
`\Qyaml.v3.Unmarshal`
`\Qyaml.v3.Encoder.Encode`
`\Qyaml.v3.Decoder.Decode`
`\Qtoml.Unmarshal`
`\Qtoml.Decode`
`\Qtoml.DecodeFS`
`\Qtoml.DecodeFile`
`\Qtoml.Encoder.Encode`
`\Qtoml.Decoder.Decode`
`\Qmapstructure.Decode`
`\Qmapstructure.DecodeMetadata`
`\Qmapstructure.WeakDecode`
`\Qmapstructure.WeakDecodeMetadata`
`\Qcustom.Marshal`
`\Qcustom.Unmarshal` */
NoTag int
}

var x struct {
Y
Z int `json:"z" xml:"z" yaml:"z" toml:"z" mapstructure:"z" custom:"z"`
}

json.Marshal(x)
json.MarshalIndent(x, "", "")
json.Unmarshal(nil, &x)
json.NewEncoder(nil).Encode(x)
json.NewDecoder(nil).Decode(&x)

xml.Marshal(x)
xml.MarshalIndent(x, "", "")
xml.Unmarshal(nil, &x)
xml.NewEncoder(nil).Encode(x)
xml.NewDecoder(nil).Decode(&x)
xml.NewEncoder(nil).EncodeElement(x, xmlSE)
xml.NewDecoder(nil).DecodeElement(&x, &xmlSE)

yaml.Marshal(x)
yaml.Unmarshal(nil, &x)
yaml.NewEncoder(nil).Encode(x)
yaml.NewDecoder(nil).Decode(&x)

toml.Unmarshal(nil, &x)
toml.Decode("", &x)
toml.DecodeFS(nil, "", &x)
toml.DecodeFile("", &x)
toml.NewEncoder(nil).Encode(x)
toml.NewDecoder(nil).Decode(&x)

mapstructure.Decode(nil, &x)
mapstructure.DecodeMetadata(nil, &x, nil)
mapstructure.WeakDecode(nil, &x)
mapstructure.WeakDecodeMetadata(nil, &x, nil)

custom.Marshal(x)
custom.Unmarshal(nil, &x)
type Y struct { /* want
`\Qjson.Marshal`
`\Qjson.MarshalIndent`
`\Qjson.Unmarshal`
`\Qjson.Encoder.Encode`
`\Qjson.Decoder.Decode`
`\Qxml.Marshal`
`\Qxml.MarshalIndent`
`\Qxml.Unmarshal`
`\Qxml.Encoder.Encode`
`\Qxml.Decoder.Decode`
`\Qxml.Encoder.EncodeElement`
`\Qxml.Decoder.DecodeElement`
`\Qyaml.v3.Marshal`
`\Qyaml.v3.Unmarshal`
`\Qyaml.v3.Encoder.Encode`
`\Qyaml.v3.Decoder.Decode`
`\Qtoml.Unmarshal`
`\Qtoml.Decode`
`\Qtoml.DecodeFS`
`\Qtoml.DecodeFile`
`\Qtoml.Encoder.Encode`
`\Qtoml.Decoder.Decode`
`\Qmapstructure.Decode`
`\Qmapstructure.DecodeMetadata`
`\Qmapstructure.WeakDecode`
`\Qmapstructure.WeakDecodeMetadata`
`\Qcustom.Marshal`
`\Qcustom.Unmarshal` */
NoTag int
}
var x struct {
Y
Z int `json:"z" xml:"z" yaml:"z" toml:"z" mapstructure:"z" custom:"z"`
}

json.Marshal(x)
json.MarshalIndent(x, "", "")
json.Unmarshal(nil, &x)
json.NewEncoder(nil).Encode(x)
json.NewDecoder(nil).Decode(&x)

xml.Marshal(x)
xml.MarshalIndent(x, "", "")
xml.Unmarshal(nil, &x)
xml.NewEncoder(nil).Encode(x)
xml.NewDecoder(nil).Decode(&x)
xml.NewEncoder(nil).EncodeElement(x, xmlSE)
xml.NewDecoder(nil).DecodeElement(&x, &xmlSE)

yaml.Marshal(x)
yaml.Unmarshal(nil, &x)
yaml.NewEncoder(nil).Encode(x)
yaml.NewDecoder(nil).Decode(&x)

toml.Unmarshal(nil, &x)
toml.Decode("", &x)
toml.DecodeFS(nil, "", &x)
toml.DecodeFile("", &x)
toml.NewEncoder(nil).Encode(x)
toml.NewDecoder(nil).Decode(&x)

mapstructure.Decode(nil, &x)
mapstructure.DecodeMetadata(nil, &x, nil)
mapstructure.WeakDecode(nil, &x)
mapstructure.WeakDecodeMetadata(nil, &x, nil)

custom.Marshal(x)
custom.Unmarshal(nil, &x)
}

// all good, nothing to report.
Expand Down

0 comments on commit 27a84af

Please sign in to comment.