diff --git a/pkg/mdformatter/mdgen/mdgen.go b/pkg/mdformatter/mdgen/mdgen.go index 73bbb57..9f2fe35 100644 --- a/pkg/mdformatter/mdgen/mdgen.go +++ b/pkg/mdformatter/mdgen/mdgen.go @@ -5,7 +5,6 @@ package mdgen import ( "bytes" - "io/ioutil" "os/exec" "strconv" "strings" @@ -95,22 +94,26 @@ func (t *genCodeBlockTransformer) TransformCodeBlock(ctx mdformatter.SourceConte return b.Bytes(), nil } - if fileWithStruct, ok := infoStringAttr[infoStringKeyGoStruct]; ok { + if structLocation, ok := infoStringAttr[infoStringKeyGoStruct]; ok { // This is like mdox-gen-go-struct=<filename>:structname for now. - fs := strings.Split(fileWithStruct, ":") - src, err := ioutil.ReadFile(fs[0]) + sn := strings.Split(structLocation, ":") + + // Get source code of struct. + src, err := yamlgen.GetSource(ctx, structLocation) if err != nil { - return nil, errors.Wrapf(err, "read file for yaml gen %v", fs[0]) + return nil, errors.Wrapf(err, "get source code for yaml gen %v", structLocation) } + // Generate YAML gen code from source. generatedCode, err := yamlgen.GenGoCode(src) if err != nil { - return nil, errors.Wrapf(err, "generate code for yaml gen %v", fs[0]) + return nil, errors.Wrapf(err, "generate code for yaml gen %v", sn[0]) } + // Execute and fetch output of generated code. b, err := yamlgen.ExecGoCode(ctx, generatedCode) if err != nil { - return nil, errors.Wrapf(err, "execute generated code for yaml gen %v", fs[0]) + return nil, errors.Wrapf(err, "execute generated code for yaml gen %v", sn[0]) } // TODO(saswatamcode): This feels sort of hacky, need better way of printing. @@ -119,7 +122,7 @@ func (t *genCodeBlockTransformer) TransformCodeBlock(ctx mdformatter.SourceConte for _, yaml := range yamls { lines := bytes.Split(yaml, []byte("\n")) if len(lines) > 1 { - if string(lines[1]) == fs[1] { + if string(lines[1]) == sn[1] { ret := bytes.Join(lines[2:len(lines)-1], []byte("\n")) ret = append(ret, []byte("\n")...) return ret, nil diff --git a/pkg/yamlgen/yamlgen.go b/pkg/yamlgen/yamlgen.go index 0787ae0..31f79be 100644 --- a/pkg/yamlgen/yamlgen.go +++ b/pkg/yamlgen/yamlgen.go @@ -8,7 +8,6 @@ import ( "context" "fmt" "go/ast" - "go/importer" "go/parser" "go/token" "go/types" @@ -16,6 +15,8 @@ import ( "os" "os/exec" "path/filepath" + "regexp" + "strings" "github.com/dave/jennifer/jen" "github.com/pkg/errors" @@ -24,7 +25,85 @@ import ( // TODO(saswatamcode): Add tests. // TODO(saswatamcode): Check jennifer code for some safety. // TODO(saswatamcode): Add mechanism for caching output from generated code. -// TODO(saswatamcode): Currently takes file names, need to make it module based(something such as https://golang.org/pkg/cmd/go/internal/list/). + +// getSourceFromMod fetches source code file from $GOPATH/pkg/mod. +func getSourceFromMod(root string, structName string) ([]byte, error) { + var src []byte + stopWalk := errors.New("stop walking") + + // Walk source dir. + err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + // Check if file is Go code. + if !info.IsDir() && filepath.Ext(path) == ".go" && err == nil { + src, err = ioutil.ReadFile(path) + if err != nil { + return errors.Wrapf(err, "read file for yaml gen %v", path) + } + // Check if file has struct. + if bytes.Contains(src, []byte("type "+structName+" struct")) { + return stopWalk + } + } + return nil + }) + if err == stopWalk { + err = nil + } + + return src, err +} + +// GetSource get source code of file containing the struct. +func GetSource(ctx context.Context, structLocation string) ([]byte, error) { + // Get struct name. + loc := strings.Split(structLocation, ":") + + // Check if it is a local file. + _, err := os.Stat(loc[0]) + if err == nil { + src, err := ioutil.ReadFile(loc[0]) + if err != nil { + return nil, errors.Wrapf(err, "read file for yaml gen %v", loc[0]) + } + + // As it is a local file, return source directly. + return src, nil + } + + // Not local file so must be module. Will be of form `github.com/bwplotka/mdox@v0.2.2-0.20210712170635-f49414cc6b5a/pkg/mdformatter/linktransformer:Config`. + // Split using version number (if provided). + getModule := loc[0] + moduleName := strings.SplitN(loc[0], "@", 2) + if len(moduleName) == 2 { + // Split package dir (if provided). + pkg := strings.SplitN(moduleName[1], "/", 2) + if len(pkg) == 2 { + getModule = moduleName[0] + "@" + pkg[0] + } + } + //TODO(saswatamcode): Handle case where version number not present but package name is. + + // Fetch module. + cmd := exec.CommandContext(ctx, "go", "get", "-u", getModule) + err = cmd.Run() + if err != nil { + return nil, errors.Wrapf(err, "run %v", cmd) + } + + // Get GOPATH. + goPath, ok := os.LookupEnv("GOPATH") + if !ok { + return nil, errors.New("GOPATH not set") + } + + // Get source file of struct. + file, err := getSourceFromMod(filepath.Join(goPath, "pkg/mod", loc[0]), loc[1]) + if err != nil { + return nil, err + } + + return file, nil +} // GenGoCode generates Go code for yaml gen from structs in src file. func GenGoCode(src []byte) (string, error) { @@ -56,7 +135,11 @@ func GenGoCode(src []byte) (string, error) { if typeDecl, ok := genericDecl.Specs[0].(*ast.TypeSpec); ok { var structFields []jen.Code // Cast to `type struct`. - structDecl := typeDecl.Type.(*ast.StructType) + structDecl, ok := typeDecl.Type.(*ast.StructType) + if !ok { + generatedCode.Type().Id(typeDecl.Name.Name).Id(string(src[typeDecl.Type.Pos()-1 : typeDecl.Type.End()-1])) + continue + } fields := structDecl.Fields.List arrayInit := make(jen.Dict) @@ -68,20 +151,12 @@ func GenGoCode(src []byte) (string, error) { if n.IsExported() { pos := n.Obj.Decl.(*ast.Field) - // Make type map to check if field is array. - info := types.Info{Types: make(map[ast.Expr]types.TypeAndValue)} - _, err = (&types.Config{Importer: importer.ForCompiler(fset, "source", nil)}).Check("mypkg", fset, []*ast.File{f}, &info) - if err != nil { - return "", err - } - typ := info.Types[field.Type].Type - - switch typ.(type) { - case *types.Slice: - // Field is of type array so initialize it using code like `[]Type{Type{}}`. + // Check if field is a slice type. + sliceRe := regexp.MustCompile(`.*\[.*\].*`) + if sliceRe.MatchString(types.ExprString(field.Type)) { arrayInit[jen.Id(n.Name)] = jen.Id(string(src[pos.Type.Pos()-1 : pos.Type.End()-1])).Values(jen.Id(string(src[pos.Type.Pos()+1 : pos.Type.End()-1])).Values()) - default: } + // Copy struct field to generated code. if pos.Tag != nil { structFields = append(structFields, jen.Id(n.Name).Id(string(src[pos.Type.Pos()-1:pos.Type.End()-1])).Id(pos.Tag.Value))