Skip to content

Commit

Permalink
feat: actually print file from ast
Browse files Browse the repository at this point in the history
  • Loading branch information
haunt98 committed Nov 26, 2022
1 parent ec9ec20 commit 30f79e9
Showing 1 changed file with 64 additions and 37 deletions.
101 changes: 64 additions & 37 deletions internal/imports/formatter.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package imports

import (
"bytes"
"errors"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"log"
"os"
Expand Down Expand Up @@ -132,14 +134,7 @@ func (ft *Formatter) formatFile(path string) error {
}
ft.log("formatFile: moduleName: %+v\n", moduleName)

// Parse ast
pathASTFile, err := ft.parseAST(path, pathBytes)
if err != nil {
return err
}
ft.log("formatFile: pathASTFile: %+v\n", pathASTFile)

if err := ft.formatASTFile(pathASTFile, moduleName); err != nil {
if err := ft.formatImports(path, pathBytes, moduleName); err != nil {
return err
}

Expand All @@ -150,36 +145,36 @@ func (ft *Formatter) formatFile(path string) error {
return nil
}

func (ft *Formatter) parseAST(path string, pathBytes []byte) (*ast.File, error) {
// Copy from goimports-reviser
func (ft *Formatter) formatImports(
path string,
pathBytes []byte,
moduleName string,
) error {
// Parse ast
fset := token.NewFileSet()

parserMode := parser.Mode(0)
parserMode |= parser.ParseComments

pathASTFile, err := parser.ParseFile(fset, path, pathBytes, parserMode)
astFile, err := parser.ParseFile(fset, path, pathBytes, parserMode)
if err != nil {
return nil, fmt.Errorf("parser: failed to parse file [%s]: %w", path, err)
return fmt.Errorf("parser: failed to parse file [%s]: %w", path, err)
}

// Ignore generated file
if isGoGenerated(pathASTFile) {
return nil, ErrGoGeneratedFile
if isGoGenerated(astFile) {
return ErrGoGeneratedFile
}

return pathASTFile, nil
}

// Copy from goimports-reviser
// If exist multi import, combine them into one
// This func will edit pathASTFile directly to combine
func (ft *Formatter) formatASTFile(
pathASTFile *ast.File,
moduleName string,
) error {
// Extract imports
importSpecs := make([]ast.Spec, 0, len(pathASTFile.Imports))
for _, importSpec := range pathASTFile.Imports {
ft.log("parseImportsAndCombine: importSpec: %+v %+v\n", importSpec.Name.String(), importSpec.Path.Value)
importSpecs := make([]ast.Spec, 0, len(astFile.Imports))
for _, importSpec := range astFile.Imports {
if importSpec.Path.Value == "" {
continue
}

ft.log("formatImports: importSpec: %+v %+v\n", importSpec.Name.String(), importSpec.Path.Value)
importSpecs = append(importSpecs, importSpec)
}

Expand All @@ -195,12 +190,13 @@ func (ft *Formatter) formatASTFile(
if err != nil {
return err
}
ft.mustLogImportSpecs("formatImports: formattedImportSpecs: ", formattedImportSpecs)

// Combine multi import decl
// Combine multi import decl into one
isMultiImportDecl := false
isExistFirstImportDecl := false
decls := make([]ast.Decl, 0, len(pathASTFile.Decls))
for _, decl := range pathASTFile.Decls {
decls := make([]ast.Decl, 0, len(astFile.Decls))
for _, decl := range astFile.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok {
decls = append(decls, decl)
Expand All @@ -215,8 +211,8 @@ func (ft *Formatter) formatASTFile(
if isExistFirstImportDecl {
isMultiImportDecl = true
// TODO: explain this
storedGenDecl := decls[len(decls)-1].(*ast.GenDecl)
if storedGenDecl.Tok == token.IMPORT {
storedGenDecl, ok := decls[len(decls)-1].(*ast.GenDecl)
if ok && storedGenDecl.Tok == token.IMPORT {
storedGenDecl.Rparen = genDecl.End()
}
continue
Expand All @@ -227,12 +223,20 @@ func (ft *Formatter) formatASTFile(
genDecl.Specs = formattedImportSpecs
decls = append(decls, genDecl)
}
ft.log("parseImportsAndCombine: decls: %+v\n", decls)

if isMultiImportDecl {
pathASTFile.Decls = decls
astFile.Decls = decls
}

// Print formatted bytes from formatted ast
var formattedBytes []byte
formattedBuffer := bytes.NewBuffer(formattedBytes)
if err := printer.Fprint(formattedBuffer, fset, astFile); err != nil {
return err
}

fmt.Println(formattedBuffer.String())

return nil
}

Expand Down Expand Up @@ -278,16 +282,17 @@ func (ft *Formatter) groupImportSpecs(
result[thirdPartyImport] = append(result[thirdPartyImport], importSpec)
}

ft.log("groupImports: std: %+v\n", result[stdImport])
ft.log("groupImports: third-party: %+v\n", result[thirdPartyImport])
ft.logImportSpecs("stdImport", result[stdImport])
ft.logImportSpecs("thirdPartyImport", result[thirdPartyImport])
if ft.companyPrefix != "" {
ft.log("groupImports: company: %+v\n", result[companyImport])
ft.logImportSpecs("companyImport", result[companyImport])
}
ft.log("groupImports: local: %+v\n", result[localImport])
ft.logImportSpecs("localImport", result[localImport])

return result, nil
}

// Copy from goimports-reviser
func (ft *Formatter) formatImportSpecs(
importSpecs []ast.Spec,
groupedImportSpecs map[string][]*ast.ImportSpec,
Expand Down Expand Up @@ -350,6 +355,7 @@ func (ft *Formatter) formatImportSpecs(
return result, nil
}

// Copy from goimports-reviser
func (ft *Formatter) moduleName(path string) (string, error) {
ft.muModuleNames.RLock()
if pkgName, ok := ft.moduleNames[path]; ok {
Expand Down Expand Up @@ -413,3 +419,24 @@ func (ft *Formatter) log(format string, v ...any) {
log.Printf(format, v...)
}
}

func (ft *Formatter) logImportSpecs(logPrefix string, importSpecs []*ast.ImportSpec) {
if ft.isVerbose {
for _, importSpec := range importSpecs {
log.Printf("%s: importSpec: %+v %+v\n", logPrefix, importSpec.Name.String(), importSpec.Path.Value)
}
}
}

func (ft *Formatter) mustLogImportSpecs(logPrefix string, importSpecs []ast.Spec) {
if ft.isVerbose {
for _, importSpec := range importSpecs {
importSpec, ok := importSpec.(*ast.ImportSpec)
if !ok {
continue
}

log.Printf("%s: importSpec: %+v %+v\n", logPrefix, importSpec.Name.String(), importSpec.Path.Value)
}
}
}

0 comments on commit 30f79e9

Please sign in to comment.