Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhancement: record token.FileSet for every file #1393

Merged
merged 1 commit into from
Nov 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions packages.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package swag

import (
"fmt"
"go/ast"
goparser "go/parser"
"go/token"
Expand All @@ -19,6 +20,7 @@ type PackagesDefinitions struct {
packages map[string]*PackageDefinitions
uniqueDefinitions map[string]*TypeSpecDef
parseDependency bool
debug Debugger
}

// NewPackagesDefinitions create object PackagesDefinitions.
Expand All @@ -30,8 +32,19 @@ func NewPackagesDefinitions() *PackagesDefinitions {
}
}

// CollectAstFile collect ast.file.
func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astFile *ast.File) error {
// ParseFile parse a source file.
func (pkgDefs *PackagesDefinitions) ParseFile(packageDir, path string, src interface{}) error {
// positions are relative to FileSet
fileSet := token.NewFileSet()
astFile, err := goparser.ParseFile(fileSet, path, src, goparser.ParseComments)
if err != nil {
return fmt.Errorf("failed to parse file %s, error:%+v", path, err)
}
return pkgDefs.collectAstFile(fileSet, packageDir, path, astFile)
}

// collectAstFile collect ast.file.
func (pkgDefs *PackagesDefinitions) collectAstFile(fileSet *token.FileSet, packageDir, path string, astFile *ast.File) error {
if pkgDefs.files == nil {
pkgDefs.files = make(map[*ast.File]*AstFileInfo)
}
Expand Down Expand Up @@ -64,6 +77,7 @@ func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astF
}

pkgDefs.files[astFile] = &AstFileInfo{
FileSet: fileSet,
File: astFile,
Path: path,
PackagePath: packageDir,
Expand All @@ -73,9 +87,9 @@ func (pkgDefs *PackagesDefinitions) CollectAstFile(packageDir, path string, astF
}

// RangeFiles for range the collection of ast.File in alphabetic order.
func rangeFiles(files map[*ast.File]*AstFileInfo, handle func(filename string, file *ast.File) error) error {
sortedFiles := make([]*AstFileInfo, 0, len(files))
for _, info := range files {
func (pkgDefs *PackagesDefinitions) RangeFiles(handle func(filename string, file *ast.File) error) error {
sortedFiles := make([]*AstFileInfo, 0, len(pkgDefs.files))
for _, info := range pkgDefs.files {
// ignore package path prefix with 'vendor' or $GOROOT,
// because the router info of api will not be included these files.
if strings.HasPrefix(info.PackagePath, "vendor") || strings.HasPrefix(info.Path, runtime.GOROOT()) {
Expand Down Expand Up @@ -270,6 +284,14 @@ func (pkgDefs *PackagesDefinitions) evaluateAllConstVariables() {
// EvaluateConstValue evaluate a const variable.
func (pkgDefs *PackagesDefinitions) EvaluateConstValue(pkg *PackageDefinitions, cv *ConstVariable, recursiveStack map[string]struct{}) (interface{}, ast.Expr) {
if expr, ok := cv.Value.(ast.Expr); ok {
defer func() {
if err := recover(); err != nil {
if fi, ok := pkgDefs.files[cv.File]; ok {
pos := fi.FileSet.Position(cv.Name.NamePos)
pkgDefs.debug.Printf("warning: failed to evaluate const %s at %s:%d:%d, %v", cv.Name.Name, fi.Path, pos.Line, pos.Column, err)
}
}
}()
if recursiveStack == nil {
recursiveStack = make(map[string]struct{})
}
Expand Down
29 changes: 20 additions & 9 deletions packages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,45 @@ import (
"github.com/stretchr/testify/assert"
)

func TestPackagesDefinitions_CollectAstFile(t *testing.T) {
func TestPackagesDefinitions_ParseFile(t *testing.T) {
pd := PackagesDefinitions{}
assert.NoError(t, pd.CollectAstFile("", "", nil))
packageDir := "github.com/swaggo/swag/testdata/simple"
assert.NoError(t, pd.ParseFile(packageDir, "testdata/simple/main.go", nil))
assert.Equal(t, 1, len(pd.packages))
assert.Equal(t, 1, len(pd.files))
}

func TestPackagesDefinitions_collectAstFile(t *testing.T) {
pd := PackagesDefinitions{}
fileSet := token.NewFileSet()
assert.NoError(t, pd.collectAstFile(fileSet, "", "", nil))

firstFile := &ast.File{
Name: &ast.Ident{Name: "main.go"},
}

packageDir := "github.com/swaggo/swag/testdata/simple"
assert.NoError(t, pd.CollectAstFile(packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile))
assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile))
assert.NotEmpty(t, pd.packages[packageDir])

absPath, _ := filepath.Abs("testdata/simple/" + firstFile.Name.String())
astFileInfo := &AstFileInfo{
FileSet: fileSet,
File: firstFile,
Path: absPath,
PackagePath: packageDir,
}
assert.Equal(t, pd.files[firstFile], astFileInfo)

// Override
assert.NoError(t, pd.CollectAstFile(packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile))
assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+firstFile.Name.String(), firstFile))
assert.Equal(t, pd.files[firstFile], astFileInfo)

// Another file
secondFile := &ast.File{
Name: &ast.Ident{Name: "api.go"},
}
assert.NoError(t, pd.CollectAstFile(packageDir, "testdata/simple/"+secondFile.Name.String(), secondFile))
assert.NoError(t, pd.collectAstFile(fileSet, packageDir, "testdata/simple/"+secondFile.Name.String(), secondFile))
}

func TestPackagesDefinitions_rangeFiles(t *testing.T) {
Expand All @@ -62,7 +72,7 @@ func TestPackagesDefinitions_rangeFiles(t *testing.T) {
}

i, expect := 0, []string{"testdata/simple/api/api.go", "testdata/simple/main.go"}
_ = rangeFiles(pd.files, func(filename string, file *ast.File) error {
_ = pd.RangeFiles(func(filename string, file *ast.File) error {
assert.Equal(t, expect[i], filename)
i++
return nil
Expand Down Expand Up @@ -182,7 +192,8 @@ func TestPackagesDefinitions_FindTypeSpec(t *testing.T) {
}

func TestPackage_rangeFiles(t *testing.T) {
files := map[*ast.File]*AstFileInfo{
pd := NewPackagesDefinitions()
pd.files = map[*ast.File]*AstFileInfo{
{
Name: &ast.Ident{Name: "main.go"},
}: {
Expand Down Expand Up @@ -218,10 +229,10 @@ func TestPackage_rangeFiles(t *testing.T) {
sorted = append(sorted, filename)
return nil
}
assert.NoError(t, rangeFiles(files, processor))
assert.NoError(t, pd.RangeFiles(processor))
assert.Equal(t, []string{"testdata/simple/api/api.go", "testdata/simple/main.go"}, sorted)

assert.Error(t, rangeFiles(files, func(filename string, file *ast.File) error {
assert.Error(t, pd.RangeFiles(func(filename string, file *ast.File) error {
return ErrFuncTypeField
}))

Expand Down
19 changes: 4 additions & 15 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ func New(options ...func(*Parser)) *Parser {
option(parser)
}

parser.packages.debug = parser.debug

return parser
}

Expand Down Expand Up @@ -276,7 +278,6 @@ func SetDebugger(logger Debugger) func(parser *Parser) {
if logger != nil {
p.debug = logger
}

}
}

Expand Down Expand Up @@ -377,7 +378,7 @@ func (parser *Parser) ParseAPIMultiSearchDir(searchDirs []string, mainAPIFile st
return err
}

err = rangeFiles(parser.packages.files, parser.ParseRouterAPIInfo)
err = parser.packages.RangeFiles(parser.ParseRouterAPIInfo)
if err != nil {
return err
}
Expand Down Expand Up @@ -969,7 +970,6 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) (
if err == ErrRecursiveParseStruct && ref {
return parser.getRefTypeSchema(typeSpecDef, schema), nil
}

return nil, err
}
}
Expand Down Expand Up @@ -1520,18 +1520,7 @@ func (parser *Parser) parseFile(packageDir, path string, src interface{}) error
return nil
}

// positions are relative to FileSet
astFile, err := goparser.ParseFile(token.NewFileSet(), path, src, goparser.ParseComments)
if err != nil {
return fmt.Errorf("ParseFile error:%+v", err)
}

err = parser.packages.CollectAstFile(packageDir, path, astFile)
if err != nil {
return err
}

return nil
return parser.packages.ParseFile(packageDir, path, src)
}

func (parser *Parser) checkOperationIDUniqueness() error {
Expand Down
Loading