diff --git a/musttag.go b/musttag.go index a58dc1c..121777b 100644 --- a/musttag.go +++ b/musttag.go @@ -57,8 +57,8 @@ var builtin = []Func{ {Name: "github.com/mitchellh/mapstructure.WeakDecodeMetadata", Tag: "mapstructure", ArgPos: 1}, } -// flags creates a flag set for the analyzer. The funcs slice will be filled -// with custom functions passed via CLI flags. +// flags creates a flag set for the analyzer. +// The funcs slice will be filled with custom functions passed via CLI flags. func flags(funcs *[]Func) flag.FlagSet { fs := flag.NewFlagSet("musttag", flag.ContinueOnError) fs.Func("fn", "report custom function (name:tag:argpos)", func(s string) error { @@ -80,8 +80,9 @@ func flags(funcs *[]Func) flag.FlagSet { return *fs } -// New creates a new musttag analyzer. To report a custom function provide its -// description via Func, it will be added to the builtin ones. +// New creates a new musttag analyzer. +// To report a custom function provide its description via Func, +// it will be added to the builtin ones. func New(funcs ...Func) *analysis.Analyzer { var flagFuncs []Func return &analysis.Analyzer{ @@ -112,125 +113,135 @@ var ( // reportf is a wrapper for pass.Reportf (as a variable, so it could be mocked in tests). reportf = func(pass *analysis.Pass, pos token.Pos, fn Func) { + // TODO(junk1tm): print the name of the struct type as well? pass.Reportf(pos, "exported fields should be annotated with the %q tag", fn.Tag) } ) // run starts the analysis. func run(pass *analysis.Pass, funcs map[string]Func) (any, error) { - insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) - - filter := []ast.Node{ - (*ast.CallExpr)(nil), - } - type report struct { - pos token.Pos - tag string + pos token.Pos // the position for report. + tag string // the missing struct tag. } - reported := make(map[report]struct{}) - insp.Preorder(filter, func(n ast.Node) { - call := n.(*ast.CallExpr) + // store previous reports to prevent reporting + // the same struct more than once (if reportOnce is true). + reports := make(map[report]struct{}) + + walk := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + filter := []ast.Node{(*ast.CallExpr)(nil)} + + walk.Preorder(filter, func(n ast.Node) { + call, ok := n.(*ast.CallExpr) + if !ok { + return // not a function call. + } callee := typeutil.StaticCallee(pass.TypesInfo, call) if callee == nil { - return + return // not a static call. } fn, ok := funcs[callee.FullName()] if !ok { - return + return // the function is not supported. + } + + if len(call.Args) <= fn.ArgPos { + return // TODO(junk1tm): return a proper error. + } + + arg := call.Args[fn.ArgPos] + if unary, ok := arg.(*ast.UnaryExpr); ok { + arg = unary.X // e.g. json.Marshal(&foo) + } + + initialPos := token.NoPos + switch arg := arg.(type) { + case *ast.Ident: // e.g. json.Marshal(foo) + initialPos = arg.Obj.Pos() + case *ast.CompositeLit: // e.g. json.Marshal(struct{}{}) + initialPos = arg.Pos() } - s, pos, ok := structAndPos(pass, call.Args[fn.ArgPos]) + t := pass.TypesInfo.TypeOf(arg) + s, ok := parseStruct(t, initialPos) if !ok { - return + return // not a struct argument. } - if ok := checkStruct(s, fn.Tag, &pos); ok { - return + reportPos, ok := checkStruct(s, fn.Tag) + if ok { + return // nothing to report. } - r := report{pos, fn.Tag} - if _, ok := reported[r]; ok && reportOnce { - return + r := report{reportPos, fn.Tag} + if _, ok := reports[r]; ok && reportOnce { + return // already reported. } - reportf(pass, pos, fn) - reported[r] = struct{}{} + reportf(pass, reportPos, fn) + reports[r] = struct{}{} }) return nil, nil } -// structAndPos analyses the given expression and returns the struct to check -// and the position to report if needed. -func structAndPos(pass *analysis.Pass, expr ast.Expr) (*types.Struct, token.Pos, bool) { - t := pass.TypesInfo.TypeOf(expr) - if ptr, ok := t.(*types.Pointer); ok { +// structInfo expands types.Struct with its position in the source code. +// If the struct is anonymous, Pos points to the corresponding identifier. +type structInfo struct { + *types.Struct + Pos token.Pos +} + +// parseStruct parses the given types.Type, returning the underlying struct type. +// If it's a named type, the result will contain the position of its declaration, +// or the given token.Pos otherwise. +func parseStruct(t types.Type, pos token.Pos) (*structInfo, bool) { + for { + // unwrap pointers (if any) first. + ptr, ok := t.(*types.Pointer) + if !ok { + break + } t = ptr.Elem() } switch t := t.(type) { - case *types.Named: // named type - s, ok := t.Underlying().(*types.Struct) - if ok { - return s, t.Obj().Pos(), true - } - - case *types.Struct: // anonymous struct - if unary, ok := expr.(*ast.UnaryExpr); ok { - expr = unary.X // &x - } - //nolint:gocritic // commentedOutCode: these are examples - switch arg := expr.(type) { - case *ast.Ident: // var x struct{}; json.Marshal(x) - return t, arg.Obj.Pos(), true - case *ast.CompositeLit: // json.Marshal(struct{}{}) - return t, arg.Pos(), true + case *types.Named: // a struct of the named type. + if s, ok := t.Underlying().(*types.Struct); ok { + return &structInfo{Struct: s, Pos: t.Obj().Pos()}, true } + case *types.Struct: // an anonymous struct. + return &structInfo{Struct: t, Pos: pos}, true } - return nil, 0, false + return nil, false } -// checkStruct checks that exported fields of the given struct are annotated -// with the tag and updates the position to report in case a nested struct of a -// named type is found. -func checkStruct(s *types.Struct, tag string, pos *token.Pos) (ok bool) { +// checkStruct recursively checks the given struct and returns the position for report, +// in case one of its fields is missing the tag. +func checkStruct(s *structInfo, tag string) (token.Pos, bool) { for i := 0; i < s.NumFields(); i++ { if !s.Field(i).Exported() { continue } st := reflect.StructTag(s.Tag(i)) - if _, ok := st.Lookup(tag); !ok { - // it's ok for embedded types not to be tagged, - // see https://github.com/junk1tm/musttag/issues/12 - if !s.Field(i).Embedded() { - return false - } + if _, ok := st.Lookup(tag); !ok && !s.Field(i).Embedded() { + return s.Pos, false } - // check if the field is a nested struct. t := s.Field(i).Type() - if ptr, ok := t.(*types.Pointer); ok { - t = ptr.Elem() - } - nested, ok := t.Underlying().(*types.Struct) + nested, ok := parseStruct(t, s.Pos) // TODO(junk1tm): or s.Field(i).Pos()? if !ok { continue } - if ok := checkStruct(nested, tag, pos); ok { - continue - } - // update the position to point to the named type. - if named, ok := t.(*types.Named); ok { - *pos = named.Obj().Pos() + if pos, ok := checkStruct(nested, tag); !ok { + return pos, false } - return false } - return true + return token.NoPos, true } diff --git a/musttag_test.go b/musttag_test.go index 08b41a3..f1953a8 100644 --- a/musttag_test.go +++ b/musttag_test.go @@ -54,7 +54,7 @@ func TestAnalyzer(t *testing.T) { func TestFlags(t *testing.T) { analyzer := New() - analyzer.Flags.SetOutput(io.Discard) + analyzer.Flags.SetOutput(io.Discard) // TODO(junk1tm): does not work, the usage is still printed. t.Run("ok", func(t *testing.T) { err := analyzer.Flags.Parse([]string{"-fn=test.Test:test:0"}) diff --git a/testdata/src/tests/tests.go b/testdata/src/tests/tests.go index 7e5b806..007328c 100644 --- a/testdata/src/tests/tests.go +++ b/testdata/src/tests/tests.go @@ -287,81 +287,80 @@ func nestedAnonymousType() { // embedded types should not be reported. func embeddedType() { - type Y struct { /* want - `\Qjson.Marshal` - `\Qjson.MarshalIndent` - `\Qjson.Unmarshal` - `\Qjson.Encoder.Encode` - `\Qjson.Decoder.Decode` - - `\Qxml.Marshal` - `\Qxml.MarshalIndent` - `\Qxml.Unmarshal` - `\Qxml.Encoder.Encode` - `\Qxml.Decoder.Decode` - `\Qxml.Encoder.EncodeElement` - `\Qxml.Decoder.DecodeElement` - - `\Qyaml.v3.Marshal` - `\Qyaml.v3.Unmarshal` - `\Qyaml.v3.Encoder.Encode` - `\Qyaml.v3.Decoder.Decode` - - `\Qtoml.Unmarshal` - `\Qtoml.Decode` - `\Qtoml.DecodeFS` - `\Qtoml.DecodeFile` - `\Qtoml.Encoder.Encode` - `\Qtoml.Decoder.Decode` - - `\Qmapstructure.Decode` - `\Qmapstructure.DecodeMetadata` - `\Qmapstructure.WeakDecode` - `\Qmapstructure.WeakDecodeMetadata` - - `\Qcustom.Marshal` - `\Qcustom.Unmarshal` */ - NoTag int - } - - var x struct { - Y - Z int `json:"z" xml:"z" yaml:"z" toml:"z" mapstructure:"z" custom:"z"` - } - - json.Marshal(x) - json.MarshalIndent(x, "", "") - json.Unmarshal(nil, &x) - json.NewEncoder(nil).Encode(x) - json.NewDecoder(nil).Decode(&x) - - xml.Marshal(x) - xml.MarshalIndent(x, "", "") - xml.Unmarshal(nil, &x) - xml.NewEncoder(nil).Encode(x) - xml.NewDecoder(nil).Decode(&x) - xml.NewEncoder(nil).EncodeElement(x, xmlSE) - xml.NewDecoder(nil).DecodeElement(&x, &xmlSE) - - yaml.Marshal(x) - yaml.Unmarshal(nil, &x) - yaml.NewEncoder(nil).Encode(x) - yaml.NewDecoder(nil).Decode(&x) - - toml.Unmarshal(nil, &x) - toml.Decode("", &x) - toml.DecodeFS(nil, "", &x) - toml.DecodeFile("", &x) - toml.NewEncoder(nil).Encode(x) - toml.NewDecoder(nil).Decode(&x) - - mapstructure.Decode(nil, &x) - mapstructure.DecodeMetadata(nil, &x, nil) - mapstructure.WeakDecode(nil, &x) - mapstructure.WeakDecodeMetadata(nil, &x, nil) - - custom.Marshal(x) - custom.Unmarshal(nil, &x) + type Y struct { /* want + `\Qjson.Marshal` + `\Qjson.MarshalIndent` + `\Qjson.Unmarshal` + `\Qjson.Encoder.Encode` + `\Qjson.Decoder.Decode` + + `\Qxml.Marshal` + `\Qxml.MarshalIndent` + `\Qxml.Unmarshal` + `\Qxml.Encoder.Encode` + `\Qxml.Decoder.Decode` + `\Qxml.Encoder.EncodeElement` + `\Qxml.Decoder.DecodeElement` + + `\Qyaml.v3.Marshal` + `\Qyaml.v3.Unmarshal` + `\Qyaml.v3.Encoder.Encode` + `\Qyaml.v3.Decoder.Decode` + + `\Qtoml.Unmarshal` + `\Qtoml.Decode` + `\Qtoml.DecodeFS` + `\Qtoml.DecodeFile` + `\Qtoml.Encoder.Encode` + `\Qtoml.Decoder.Decode` + + `\Qmapstructure.Decode` + `\Qmapstructure.DecodeMetadata` + `\Qmapstructure.WeakDecode` + `\Qmapstructure.WeakDecodeMetadata` + + `\Qcustom.Marshal` + `\Qcustom.Unmarshal` */ + NoTag int + } + var x struct { + Y + Z int `json:"z" xml:"z" yaml:"z" toml:"z" mapstructure:"z" custom:"z"` + } + + json.Marshal(x) + json.MarshalIndent(x, "", "") + json.Unmarshal(nil, &x) + json.NewEncoder(nil).Encode(x) + json.NewDecoder(nil).Decode(&x) + + xml.Marshal(x) + xml.MarshalIndent(x, "", "") + xml.Unmarshal(nil, &x) + xml.NewEncoder(nil).Encode(x) + xml.NewDecoder(nil).Decode(&x) + xml.NewEncoder(nil).EncodeElement(x, xmlSE) + xml.NewDecoder(nil).DecodeElement(&x, &xmlSE) + + yaml.Marshal(x) + yaml.Unmarshal(nil, &x) + yaml.NewEncoder(nil).Encode(x) + yaml.NewDecoder(nil).Decode(&x) + + toml.Unmarshal(nil, &x) + toml.Decode("", &x) + toml.DecodeFS(nil, "", &x) + toml.DecodeFile("", &x) + toml.NewEncoder(nil).Encode(x) + toml.NewDecoder(nil).Decode(&x) + + mapstructure.Decode(nil, &x) + mapstructure.DecodeMetadata(nil, &x, nil) + mapstructure.WeakDecode(nil, &x) + mapstructure.WeakDecodeMetadata(nil, &x, nil) + + custom.Marshal(x) + custom.Unmarshal(nil, &x) } // all good, nothing to report.