Skip to content

Commit

Permalink
Add support for modules
Browse files Browse the repository at this point in the history
Signed-off-by: Saswata Mukherjee <[email protected]>
  • Loading branch information
saswatamcode committed Jul 22, 2021
1 parent dc8cc39 commit 34b021c
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 23 deletions.
19 changes: 11 additions & 8 deletions pkg/mdformatter/mdgen/mdgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package mdgen

import (
"bytes"
"io/ioutil"
"os/exec"
"strconv"
"strings"
Expand Down Expand Up @@ -95,22 +94,26 @@ func (t *genCodeBlockTransformer) TransformCodeBlock(ctx mdformatter.SourceConte
return b.Bytes(), nil
}

if fileWithStruct, ok := infoStringAttr[infoStringKeyGoStruct]; ok {
if structLocation, ok := infoStringAttr[infoStringKeyGoStruct]; ok {
// This is like mdox-gen-go-struct=<filename>:structname for now.
fs := strings.Split(fileWithStruct, ":")
src, err := ioutil.ReadFile(fs[0])
sn := strings.Split(structLocation, ":")

// Get source code of struct.
src, err := yamlgen.GetSource(ctx, structLocation)
if err != nil {
return nil, errors.Wrapf(err, "read file for yaml gen %v", fs[0])
return nil, errors.Wrapf(err, "get source code for yaml gen %v", structLocation)
}

// Generate YAML gen code from source.
generatedCode, err := yamlgen.GenGoCode(src)
if err != nil {
return nil, errors.Wrapf(err, "generate code for yaml gen %v", fs[0])
return nil, errors.Wrapf(err, "generate code for yaml gen %v", sn[0])
}

// Execute and fetch output of generated code.
b, err := yamlgen.ExecGoCode(ctx, generatedCode)
if err != nil {
return nil, errors.Wrapf(err, "execute generated code for yaml gen %v", fs[0])
return nil, errors.Wrapf(err, "execute generated code for yaml gen %v", sn[0])
}

// TODO(saswatamcode): This feels sort of hacky, need better way of printing.
Expand All @@ -119,7 +122,7 @@ func (t *genCodeBlockTransformer) TransformCodeBlock(ctx mdformatter.SourceConte
for _, yaml := range yamls {
lines := bytes.Split(yaml, []byte("\n"))
if len(lines) > 1 {
if string(lines[1]) == fs[1] {
if string(lines[1]) == sn[1] {
ret := bytes.Join(lines[2:len(lines)-1], []byte("\n"))
ret = append(ret, []byte("\n")...)
return ret, nil
Expand Down
105 changes: 90 additions & 15 deletions pkg/yamlgen/yamlgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (
"context"
"fmt"
"go/ast"
"go/importer"
"go/parser"
"go/token"
"go/types"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"

"github.com/dave/jennifer/jen"
"github.com/pkg/errors"
Expand All @@ -24,7 +25,85 @@ import (
// TODO(saswatamcode): Add tests.
// TODO(saswatamcode): Check jennifer code for some safety.
// TODO(saswatamcode): Add mechanism for caching output from generated code.
// TODO(saswatamcode): Currently takes file names, need to make it module based(something such as https://golang.org/pkg/cmd/go/internal/list/).

// getSourceFromMod fetches source code file from $GOPATH/pkg/mod.
func getSourceFromMod(root string, structName string) ([]byte, error) {
var src []byte
stopWalk := errors.New("stop walking")

// Walk source dir.
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
// Check if file is Go code.
if !info.IsDir() && filepath.Ext(path) == ".go" && err == nil {
src, err = ioutil.ReadFile(path)
if err != nil {
return errors.Wrapf(err, "read file for yaml gen %v", path)
}
// Check if file has struct.
if bytes.Contains(src, []byte("type "+structName+" struct")) {
return stopWalk
}
}
return nil
})
if err == stopWalk {
err = nil
}

return src, err
}

// GetSource get source code of file containing the struct.
func GetSource(ctx context.Context, structLocation string) ([]byte, error) {
// Get struct name.
loc := strings.Split(structLocation, ":")

// Check if it is a local file.
_, err := os.Stat(loc[0])
if err == nil {
src, err := ioutil.ReadFile(loc[0])
if err != nil {
return nil, errors.Wrapf(err, "read file for yaml gen %v", loc[0])
}

// As it is a local file, return source directly.
return src, nil
}

// Not local file so must be module. Will be of form `github.com/bwplotka/[email protected]/pkg/mdformatter/linktransformer:Config`.
// Split using version number (if provided).
getModule := loc[0]
moduleName := strings.SplitN(loc[0], "@", 2)
if len(moduleName) == 2 {
// Split package dir (if provided).
pkg := strings.SplitN(moduleName[1], "/", 2)
if len(pkg) == 2 {
getModule = moduleName[0] + "@" + pkg[0]
}
}
//TODO(saswatamcode): Handle case where version number not present but package name is.

// Fetch module.
cmd := exec.CommandContext(ctx, "go", "get", "-u", getModule)
err = cmd.Run()
if err != nil {
return nil, errors.Wrapf(err, "run %v", cmd)
}

// Get GOPATH.
goPath, ok := os.LookupEnv("GOPATH")
if !ok {
return nil, errors.New("GOPATH not set")
}

// Get source file of struct.
file, err := getSourceFromMod(filepath.Join(goPath, "pkg/mod", loc[0]), loc[1])
if err != nil {
return nil, err
}

return file, nil
}

// GenGoCode generates Go code for yaml gen from structs in src file.
func GenGoCode(src []byte) (string, error) {
Expand Down Expand Up @@ -56,7 +135,11 @@ func GenGoCode(src []byte) (string, error) {
if typeDecl, ok := genericDecl.Specs[0].(*ast.TypeSpec); ok {
var structFields []jen.Code
// Cast to `type struct`.
structDecl := typeDecl.Type.(*ast.StructType)
structDecl, ok := typeDecl.Type.(*ast.StructType)
if !ok {
generatedCode.Type().Id(typeDecl.Name.Name).Id(string(src[typeDecl.Type.Pos()-1 : typeDecl.Type.End()-1]))
continue
}
fields := structDecl.Fields.List
arrayInit := make(jen.Dict)

Expand All @@ -68,20 +151,12 @@ func GenGoCode(src []byte) (string, error) {
if n.IsExported() {
pos := n.Obj.Decl.(*ast.Field)

// Make type map to check if field is array.
info := types.Info{Types: make(map[ast.Expr]types.TypeAndValue)}
_, err = (&types.Config{Importer: importer.ForCompiler(fset, "source", nil)}).Check("mypkg", fset, []*ast.File{f}, &info)
if err != nil {
return "", err
}
typ := info.Types[field.Type].Type

switch typ.(type) {
case *types.Slice:
// Field is of type array so initialize it using code like `[]Type{Type{}}`.
// Check if field is a slice type.
sliceRe := regexp.MustCompile(`.*\[.*\].*`)
if sliceRe.MatchString(types.ExprString(field.Type)) {
arrayInit[jen.Id(n.Name)] = jen.Id(string(src[pos.Type.Pos()-1 : pos.Type.End()-1])).Values(jen.Id(string(src[pos.Type.Pos()+1 : pos.Type.End()-1])).Values())
default:
}

// Copy struct field to generated code.
if pos.Tag != nil {
structFields = append(structFields, jen.Id(n.Name).Id(string(src[pos.Type.Pos()-1:pos.Type.End()-1])).Id(pos.Tag.Value))
Expand Down

0 comments on commit 34b021c

Please sign in to comment.