diff --git a/_test/install/binary_gopath/expected.txt b/_test/install/binary_gopath/expected.txt index 3aa61cce..286d66e7 100644 --- a/_test/install/binary_gopath/expected.txt +++ b/_test/install/binary_gopath/expected.txt @@ -1,4 +1,5 @@ -/root/target.go:8:10: exprUnparen: the parentheses around 1 are superfluous (rules.go:15) -/root/target.go:9:10: exprUnparen: the parentheses around 2 are superfluous (rules.go:15) +/root/target.go:8:10: exprUnparen: the parentheses around 1 are superfluous (rules.go:21) +/root/target.go:9:10: exprUnparen: the parentheses around 2 are superfluous (rules.go:21) /root/target.go:11:10: boolComparison: omit bool literal in expression (rules1.go:8) /root/target.go:12:10: boolExprSimplify: suggestion: b (rules2.go:6) +/root/target.go:15:10: interfaceAddr: taking address of interface-typed value (rules.go:27) diff --git a/_test/install/binary_gopath/rules.go b/_test/install/binary_gopath/rules.go index 22bba22b..422879ae 100644 --- a/_test/install/binary_gopath/rules.go +++ b/_test/install/binary_gopath/rules.go @@ -4,6 +4,7 @@ package gorules import ( "github.com/quasilyte/go-ruleguard/dsl" + "github.com/quasilyte/go-ruleguard/dsl/types" testrules "github.com/quasilyte/ruleguard-rules-test" ) @@ -11,8 +12,19 @@ func init() { dsl.ImportRules("", testrules.Bundle) } +func isInterface(ctx *dsl.VarFilterContext) bool { + // Could be written as m["x"].Type.Underlying().Is(`interface{$*_}`) in DSL. + return types.AsInterface(ctx.Type.Underlying()) != nil +} + func exprUnparen(m dsl.Matcher) { m.Match(`$f($*_, ($x), $*_)`). Report(`the parentheses around $x are superfluous`). Suggest(`$f($x)`) } + +func interfaceAddr(m dsl.Matcher) { + m.Match(`&$x`). + Where(m["x"].Filter(isInterface)). + Report(`taking address of interface-typed value`) +} diff --git a/_test/install/binary_gopath/target.go b/_test/install/binary_gopath/target.go index e8678368..a2c79744 100644 --- a/_test/install/binary_gopath/target.go +++ b/_test/install/binary_gopath/target.go @@ -10,4 +10,7 @@ func test(b bool) { println(b == true) println(!!b) + + var eface interface{} + println(&eface) } diff --git a/_test/install/binary_nogopath/expected.txt b/_test/install/binary_nogopath/expected.txt index 852ad068..286d66e7 100644 --- a/_test/install/binary_nogopath/expected.txt +++ b/_test/install/binary_nogopath/expected.txt @@ -1,4 +1,5 @@ -/root/target.go:8:10: exprUnparen: the parentheses around 1 are superfluous (rules.go:15) -/root/target.go:9:10: exprUnparen: the parentheses around 2 are superfluous (rules.go:15) -/root/target.go:11:10: testrules/boolComparison: omit bool literal in expression (rules1.go:8) -/root/target.go:12:10: testrules/boolExprSimplify: suggestion: b (rules2.go:6) +/root/target.go:8:10: exprUnparen: the parentheses around 1 are superfluous (rules.go:21) +/root/target.go:9:10: exprUnparen: the parentheses around 2 are superfluous (rules.go:21) +/root/target.go:11:10: boolComparison: omit bool literal in expression (rules1.go:8) +/root/target.go:12:10: boolExprSimplify: suggestion: b (rules2.go:6) +/root/target.go:15:10: interfaceAddr: taking address of interface-typed value (rules.go:27) diff --git a/_test/install/binary_nogopath/rules.go b/_test/install/binary_nogopath/rules.go index 6f942bb8..422879ae 100644 --- a/_test/install/binary_nogopath/rules.go +++ b/_test/install/binary_nogopath/rules.go @@ -4,11 +4,17 @@ package gorules import ( "github.com/quasilyte/go-ruleguard/dsl" + "github.com/quasilyte/go-ruleguard/dsl/types" testrules "github.com/quasilyte/ruleguard-rules-test" ) func init() { - dsl.ImportRules("testrules", testrules.Bundle) + dsl.ImportRules("", testrules.Bundle) +} + +func isInterface(ctx *dsl.VarFilterContext) bool { + // Could be written as m["x"].Type.Underlying().Is(`interface{$*_}`) in DSL. + return types.AsInterface(ctx.Type.Underlying()) != nil } func exprUnparen(m dsl.Matcher) { @@ -16,3 +22,9 @@ func exprUnparen(m dsl.Matcher) { Report(`the parentheses around $x are superfluous`). Suggest(`$f($x)`) } + +func interfaceAddr(m dsl.Matcher) { + m.Match(`&$x`). + Where(m["x"].Filter(isInterface)). + Report(`taking address of interface-typed value`) +} diff --git a/_test/install/binary_nogopath/target.go b/_test/install/binary_nogopath/target.go index e8678368..a2c79744 100644 --- a/_test/install/binary_nogopath/target.go +++ b/_test/install/binary_nogopath/target.go @@ -10,4 +10,7 @@ func test(b bool) { println(b == true) println(!!b) + + var eface interface{} + println(&eface) } diff --git a/_test/install/gitclone/expected.txt b/_test/install/gitclone/expected.txt index 909ede4f..226e07ab 100644 --- a/_test/install/gitclone/expected.txt +++ b/_test/install/gitclone/expected.txt @@ -1,2 +1,2 @@ -/root/target.go:8:10: exprUnparen: the parentheses around 1 are superfluous -/root/target.go:9:10: exprUnparen: the parentheses around 2 are superfluous +/root/target.go:8:10: exprUnparen: the parentheses around 1 are superfluous (rules.go:10) +/root/target.go:9:10: exprUnparen: the parentheses around 2 are superfluous (rules.go:10) diff --git a/analyzer/analyzer.go b/analyzer/analyzer.go index 63c69882..0e55643d 100644 --- a/analyzer/analyzer.go +++ b/analyzer/analyzer.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "strings" + "sync" "github.com/quasilyte/go-ruleguard/ruleguard" "golang.org/x/tools/go/analysis" @@ -21,6 +22,16 @@ var Analyzer = &analysis.Analyzer{ Run: runAnalyzer, } +// ForceNewEngine disables engine cache optimization. +// This should only be useful for analyzer testing. +var ForceNewEngine = false + +var ( + globalEngineMu sync.Mutex + globalEngine *ruleguard.Engine + globalEngineErrored bool +) + var ( flagRules string flagE string @@ -28,6 +39,7 @@ var ( flagDisable string flagDebug string + flagDebugFilter string flagDebugImports bool flagDebugEnableDisable bool ) @@ -35,48 +47,48 @@ var ( func init() { Analyzer.Flags.StringVar(&flagRules, "rules", "", "comma-separated list of gorule file paths") Analyzer.Flags.StringVar(&flagE, "e", "", "execute a single rule from a given string") - Analyzer.Flags.StringVar(&flagDebug, "debug-group", "", "enable debug for the specified function") + Analyzer.Flags.StringVar(&flagDebug, "debug-group", "", "enable debug for the specified matcher function") + Analyzer.Flags.StringVar(&flagDebugFilter, "debug-filter", "", "enable debug for the specified filter function") Analyzer.Flags.StringVar(&flagEnable, "enable", "", "comma-separated list of enabled groups or '' to enable everything") Analyzer.Flags.StringVar(&flagDisable, "disable", "", "comma-separated list of groups to be disabled") Analyzer.Flags.BoolVar(&flagDebugImports, "debug-imports", false, "enable debug for rules compile-time package lookups") Analyzer.Flags.BoolVar(&flagDebugEnableDisable, "debug-enable-disable", false, "enable debug for -enable/-disable related info") } -type parseRulesResult struct { - rset *ruleguard.GoRuleSet - multiFile bool -} - func debugPrint(s string) { fmt.Fprintln(os.Stderr, s) } func runAnalyzer(pass *analysis.Pass) (interface{}, error) { - // TODO(quasilyte): parse config under sync.Once and - // create rule sets from it. - - parseResult, err := readRules() + engine, err := prepareEngine() if err != nil { return nil, fmt.Errorf("load rules: %v", err) } - rset := parseResult.rset - multiFile := parseResult.multiFile - - ctx := &ruleguard.Context{ - Debug: flagDebug, - DebugPrint: debugPrint, - Pkg: pass.Pkg, - Types: pass.TypesInfo, - Sizes: pass.TypesSizes, - Fset: pass.Fset, + // This condition will trigger only if we failed to init + // the engine. Return without an error as other analysis + // pass probably reported init error by this moment. + if engine == nil { + return nil, nil + } + + printRuleLocation := flagE == "" + + ctx := &ruleguard.RunContext{ + Debug: flagDebug, + DebugImports: flagDebugImports, + DebugPrint: debugPrint, + Pkg: pass.Pkg, + Types: pass.TypesInfo, + Sizes: pass.TypesSizes, + Fset: pass.Fset, Report: func(info ruleguard.GoRuleInfo, n ast.Node, msg string, s *ruleguard.Suggestion) { - msg = info.Group + ": " + msg - if multiFile { - msg += fmt.Sprintf(" (%s:%d)", filepath.Base(info.Filename), info.Line) + fullMessage := info.Group + ": " + msg + if printRuleLocation { + fullMessage += fmt.Sprintf(" (%s:%d)", filepath.Base(info.Filename), info.Line) } diag := analysis.Diagnostic{ Pos: n.Pos(), - Message: msg, + Message: fullMessage, } if s != nil { diag.SuggestedFixes = []analysis.SuggestedFix{ @@ -97,7 +109,7 @@ func runAnalyzer(pass *analysis.Pass) (interface{}, error) { } for _, f := range pass.Files { - if err := ruleguard.RunRules(ctx, f, rset); err != nil { + if err := engine.Run(ctx, f); err != nil { return nil, err } } @@ -105,7 +117,33 @@ func runAnalyzer(pass *analysis.Pass) (interface{}, error) { return nil, nil } -func readRules() (*parseRulesResult, error) { +func prepareEngine() (*ruleguard.Engine, error) { + if ForceNewEngine { + return newEngine() + } + + globalEngineMu.Lock() + defer globalEngineMu.Unlock() + + if globalEngine != nil { + return globalEngine, nil + } + // If we already failed once, don't try again to avoid #167. + if globalEngineErrored { + return nil, nil + } + + engine, err := newEngine() + if err != nil { + globalEngineErrored = true + return nil, err + } + globalEngine = engine + return engine, nil +} + +func newEngine() (*ruleguard.Engine, error) { + e := ruleguard.NewEngine() fset := token.NewFileSet() disabledGroups := make(map[string]bool) @@ -123,6 +161,7 @@ func readRules() (*parseRulesResult, error) { ctx := &ruleguard.ParseContext{ Fset: fset, + DebugFilter: flagDebugFilter, DebugImports: flagDebugImports, DebugPrint: debugPrint, GroupFilter: func(g string) bool { @@ -148,28 +187,17 @@ func readRules() (*parseRulesResult, error) { switch { case flagRules != "": filenames := strings.Split(flagRules, ",") - multifile := len(filenames) > 1 - var ruleSets []*ruleguard.GoRuleSet for _, filename := range filenames { filename = strings.TrimSpace(filename) data, err := ioutil.ReadFile(filename) if err != nil { return nil, fmt.Errorf("read rules file: %v", err) } - rset, err := ruleguard.ParseRules(ctx, filename, bytes.NewReader(data)) - if err != nil { + if err := e.Load(ctx, filename, bytes.NewReader(data)); err != nil { return nil, fmt.Errorf("parse rules file: %v", err) } - if len(rset.Imports) != 0 { - multifile = true - } - ruleSets = append(ruleSets, rset) - } - rset, err := ruleguard.MergeRuleSets(ruleSets) - if err != nil { - return nil, fmt.Errorf("merge rule files: %v", err) } - return &parseRulesResult{rset: rset, multiFile: multifile}, nil + return e, nil case flagE != "": ruleText := fmt.Sprintf(` @@ -180,8 +208,11 @@ func readRules() (*parseRulesResult, error) { }`, flagE) r := strings.NewReader(ruleText) - rset, err := ruleguard.ParseRules(ctx, flagRules, r) - return &parseRulesResult{rset: rset}, err + err := e.Load(ctx, "e", r) + if err != nil { + return nil, err + } + return e, nil default: return nil, fmt.Errorf("both -e and -rules flags are empty") diff --git a/analyzer/analyzer_test.go b/analyzer/analyzer_test.go index d7ed049e..aad368fa 100644 --- a/analyzer/analyzer_test.go +++ b/analyzer/analyzer_test.go @@ -20,8 +20,10 @@ func TestAnalyzer(t *testing.T) { "golint", "regression", "testvendored", + "quasigo", } + analyzer.ForceNewEngine = true for _, test := range tests { t.Run(test, func(t *testing.T) { testdata := analysistest.TestData() diff --git a/analyzer/testdata/src/quasigo/rules.go b/analyzer/testdata/src/quasigo/rules.go new file mode 100644 index 00000000..dc7731c2 --- /dev/null +++ b/analyzer/testdata/src/quasigo/rules.go @@ -0,0 +1,110 @@ +// +build ignore + +package gorules + +import ( + "github.com/quasilyte/go-ruleguard/dsl" + "github.com/quasilyte/go-ruleguard/dsl/types" +) + +func stringUnderlying(ctx *dsl.VarFilterContext) bool { + // Test both Type.Underlying() and Type.String() methods. + return ctx.Type.Underlying().String() == `string` +} + +func isZeroSize(ctx *dsl.VarFilterContext) bool { + return ctx.SizeOf(ctx.Type) == 0 +} + +func isPointer(ctx *dsl.VarFilterContext) bool { + // There is no Type.IsT() methods (yet?), but it's possible to + // use nil comparison for that. + ptr := types.AsPointer(ctx.Type) + return ptr != nil +} + +func isInterface(ctx *dsl.VarFilterContext) bool { + // Nil can be used on either side. + return nil != types.AsInterface(ctx.Type.Underlying()) +} + +func isError(ctx *dsl.VarFilterContext) bool { + // Testing Interface.String() method. + iface := types.AsInterface(ctx.Type.Underlying()) + if iface != nil { + return iface.String() == `interface{Error() string}` + } + return false +} + +func isInterfacePtr(ctx *dsl.VarFilterContext) bool { + ptr := types.AsPointer(ctx.Type) + if ptr != nil { + return types.AsInterface(ptr.Elem().Underlying()) != nil + } + return false +} + +func typeNameHasErrorSuffix(ctx *dsl.VarFilterContext) bool { + // Test string operations; this is basically strings.HasSuffix(). + s := ctx.Type.String() + suffix := "Error" + return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix +} + +func implementsStringer(ctx *dsl.VarFilterContext) bool { + // Using a non-constant argument to GetInterface() on purpose. + ifaceName := `fmt.Stringer` + stringer := ctx.GetInterface(ifaceName) + return types.Implements(ctx.Type, stringer) || + types.Implements(types.NewPointer(ctx.Type), stringer) +} + +func ptrElemSmallerThanUintptr(ctx *dsl.VarFilterContext) bool { + ptr := types.AsPointer(ctx.Type) + if ptr == nil { + return false // Not a pointer + } + uintptrSize := ctx.SizeOf(ctx.GetType(`uintptr`)) + elemSize := ctx.SizeOf(ptr.Elem()) + return elemSize < uintptrSize +} + +func testRules(m dsl.Matcher) { + m.Match(`test($x, "underlying type is string")`). + Where(m["x"].Filter(stringUnderlying)). + Report(`true`) + + m.Match(`test($x, "zero sized")`). + Where(m["x"].Filter(isZeroSize)). + Report(`true`) + + m.Match(`test($x, "type is pointer")`). + Where(m["x"].Filter(isPointer)). + Report(`true`) + + m.Match(`test($x, "type is error")`). + Where(m["x"].Filter(isError)). + Report(`true`) + + // Use a custom filter negation. + m.Match(`test($x, "type is not interface")`). + Where(!m["x"].Filter(isInterface)). + Report(`true`) + + m.Match(`test($x, "type name has Error suffix")`). + Where(m["x"].Filter(typeNameHasErrorSuffix)). + Report(`true`) + + m.Match(`test($x, "implements fmt.Stringer")`). + Where(m["x"].Filter(implementsStringer)). + Report(`true`) + + m.Match(`test($x, "pointer to interface")`). + Where(m["x"].Filter(isInterfacePtr)). + Report(`true`) + + m.Match(`test($x, "pointer elem value size is smaller than uintptr")`). + Where(m["x"].Filter(ptrElemSmallerThanUintptr)). + Report(`true`) +} diff --git a/analyzer/testdata/src/quasigo/target.go b/analyzer/testdata/src/quasigo/target.go new file mode 100644 index 00000000..9527eebf --- /dev/null +++ b/analyzer/testdata/src/quasigo/target.go @@ -0,0 +1,89 @@ +package quasigo + +import "fmt" + +func f() { + var i int + var stringer fmt.Stringer + var err error + + test("foo", "underlying type is string") // want `true` + test(myString("123"), "underlying type is string") // want `true` + test(0, "underlying type is string") + test(myEmptyStruct{}, "underlying type is string") + + test(myEmptyStruct{}, "zero sized") // want `true` + test(struct{}{}, "zero sized") // want `true` + test([0]func(){}, "zero sized") // want `true` + test("", "zero sized") + test(10, "zero sized") + test(true, "zero sized") + + test(new(bool), "type is pointer") // want `true` + test((*int)(nil), "type is pointer") // want `true` + test(&myEmptyStruct{}, "type is pointer") // want `true` + test(&i, "type is pointer") // want `true` + test([]int(nil), "type is pointer") + test(interface{}(nil), "type is pointer") + test(10, "type is pointer") + + test(10, "type is not interface") // want `true` + test(&i, "type is not interface") // want `true` + test(true, "type is not interface") // want `true` + test(&stringer, "type is not interface") // want `true` + test(stringer, "type is not interface") + test(interface{}(nil), "type is not interface") + + test(MyError(""), "type name has Error suffix") // want `true` + test(new(MyError), "type name has Error suffix") // want `true` + test(parseError{}, "type name has Error suffix") // want `true` + test(&parseError{}, "type name has Error suffix") // want `true` + test(0, "type name has Error suffix") + test((error)(nil), "type name has Error suffix") + + test((error)(nil), "type is error") // want `true` + test(err, "type is error") // want `true` + test(0, "type is error") + test("", "type is error") + + test(&err, "pointer to interface") // want `true` + test((*error)(nil), "pointer to interface") // want `true` + test(&stringer, "pointer to interface") // want `true` + test(new(fmt.Stringer), "pointer to interface") // want `true` + test(0, "pointer to interface") + test("", "pointer to interface") + test(err, "pointer to interface") + test(i, "pointer to interface") + test(parseError{}, "pointer to interface") + test(&parseError{}, "pointer to interface") + + test(stringer, "implements fmt.Stringer") // want `true` + test(&stringerByValue{}, "implements fmt.Stringer") // want `true` + test(stringerByValue{}, "implements fmt.Stringer") // want `true` + test(&stringerByPtr{}, "implements fmt.Stringer") // want `true` + test(stringerByPtr{}, "implements fmt.Stringer") // want `true` + test(nil, "implements fmt.Stringer") + test("", "implements fmt.Stringer") + + test(new(byte), "pointer elem value size is smaller than uintptr") // want `true` + test(new(int16), "pointer elem value size is smaller than uintptr") // want `true` + test(&stringerByPtr{}, "pointer elem value size is smaller than uintptr") // want `true` + test(new(uintptr), "pointer elem value size is smaller than uintptr") + test(true, "pointer elem value size is smaller than uintptr") +} + +type myString string + +type myEmptyStruct struct{} + +type parseError struct{} + +type MyError myString + +type stringerByValue struct{} +type stringerByPtr struct{} + +func (*stringerByPtr) String() string { return "" } +func (stringerByValue) String() string { return "" } + +func test(args ...interface{}) {} diff --git a/dsl/dsl.go b/dsl/dsl.go index 80135cdd..7b301e92 100644 --- a/dsl/dsl.go +++ b/dsl/dsl.go @@ -87,6 +87,11 @@ type Var struct { Node MatchedNode } +// Filter applies a custom predicate function on a submatch. +// +// The callback function should use VarFilterContext to access the +// information that is usually accessed through Var. +// For example, `VarFilterContext.Type` is mapped to `Var.Type`. func (Var) Filter(pred func(*VarFilterContext) bool) bool { return boolResult } // MatchedNode represents an AST node associated with a named submatch. diff --git a/dsl/filter.go b/dsl/filter.go index 61ccea09..e201408c 100644 --- a/dsl/filter.go +++ b/dsl/filter.go @@ -16,9 +16,12 @@ type VarFilterContext struct { func (*VarFilterContext) SizeOf(x types.Type) int { return 0 } // GetType finds a type value by a given name. +// // A name can be: // - builtin type name, like `error` or `string` // - fully-qualified type name, like `github.com/username/pkgname.TypeName` +// +// If a type can't be found (or a name is malformed), this function panics. func (*VarFilterContext) GetType(name string) types.Type { return nil } // GetInterface finds a type value that represents an interface by a given name. diff --git a/go.mod b/go.mod index 8a5e443e..6e7170a4 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.15 require ( github.com/google/go-cmp v0.5.2 - github.com/quasilyte/go-ruleguard/dsl v0.0.0-20201222100711-8bb37f2dbd8a + github.com/quasilyte/go-ruleguard/dsl v0.0.0-20210106184943-e47d54850b18 github.com/quasilyte/go-ruleguard/rules v0.0.0-20201231183845-9e62ed36efe1 golang.org/x/tools v0.0.0-20200812195022-5ae4c3c160a0 ) diff --git a/go.sum b/go.sum index 26679fa1..64e279f9 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/quasilyte/go-ruleguard/dsl v0.0.0-20201222100711-8bb37f2dbd8a h1:EFwllm9UIMC8jzbcPzTSDsx9gYEpYz6pwJr89Gz3Jzo= github.com/quasilyte/go-ruleguard/dsl v0.0.0-20201222100711-8bb37f2dbd8a/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU= +github.com/quasilyte/go-ruleguard/dsl v0.0.0-20210106184943-e47d54850b18 h1:YJnDZ7AWC93dIJ0nqah6jtIJWTKMhdmYWVhTyY2kpAE= +github.com/quasilyte/go-ruleguard/dsl v0.0.0-20210106184943-e47d54850b18/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU= github.com/quasilyte/go-ruleguard/rules v0.0.0-20201231183845-9e62ed36efe1 h1:PX/E0GYUnSV8vwVfpOUEIBKnPG3KmYunmNOBlL+zDko= github.com/quasilyte/go-ruleguard/rules v0.0.0-20201231183845-9e62ed36efe1/go.mod h1:7JTjp89EGyU1d6XfBiXihJNG37wB2VRkd125Q1u7Plc= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/internal/mvdan.cc/gogrep/kludge.go b/internal/mvdan.cc/gogrep/kludge.go index f366af84..f62c4aaf 100644 --- a/internal/mvdan.cc/gogrep/kludge.go +++ b/internal/mvdan.cc/gogrep/kludge.go @@ -34,6 +34,15 @@ type MatchData struct { Values map[string]ast.Node } +// Clone creates a pattern copy. +func (p *Pattern) Clone() *Pattern { + clone := *p + clone.m = &matcher{} + *clone.m = *p.m + clone.m.values = make(map[string]ast.Node) + return &clone +} + // MatchNode calls cb if n matches a pattern. func (p *Pattern) MatchNode(n ast.Node, cb func(MatchData)) { p.m.values = map[string]ast.Node{} diff --git a/ruleguard/debug_test.go b/ruleguard/debug_test.go index 5dfae900..a1e5e58f 100644 --- a/ruleguard/debug_test.go +++ b/ruleguard/debug_test.go @@ -142,7 +142,7 @@ func TestDebug(t *testing.T) { }, } - exprToRules := func(s string) *GoRuleSet { + loadRulesFromExpr := func(e *Engine, s string) { file := fmt.Sprintf(` package gorules import "github.com/quasilyte/go-ruleguard/dsl" @@ -150,22 +150,24 @@ func TestDebug(t *testing.T) { %s.Report("$$") }`, s) - ctx := &ParseContext{Fset: token.NewFileSet()} - rset, err := ParseRules(ctx, "rules.go", strings.NewReader(file)) + ctx := &ParseContext{ + Fset: token.NewFileSet(), + } + err := e.Load(ctx, "rules.go", strings.NewReader(file)) if err != nil { t.Fatalf("parse %s: %v", s, err) } - return rset } for expr, testCases := range allTests { - rset := exprToRules(expr) + e := NewEngine() + loadRulesFromExpr(e, expr) for input, lines := range testCases { runner, err := newDebugTestRunner(input) if err != nil { t.Fatalf("init %s: %s: %v", expr, input, err) } - if err := runner.Run(t, rset); err != nil { + if err := runner.Run(t, e); err != nil { t.Fatalf("run %s: %s: %v", expr, input, err) } if diff := cmp.Diff(runner.out, lines); diff != "" { @@ -176,13 +178,13 @@ func TestDebug(t *testing.T) { } type debugTestRunner struct { - ctx *Context + ctx *RunContext f *ast.File out []string } -func (r debugTestRunner) Run(t *testing.T, rset *GoRuleSet) error { - if err := RunRules(r.ctx, r.f, rset); err != nil { +func (r debugTestRunner) Run(t *testing.T, e *Engine) error { + if err := e.Run(r.ctx, r.f); err != nil { return err } return nil @@ -211,7 +213,7 @@ func newDebugTestRunner(input string) (*debugTestRunner, error) { return nil, err } runner := &debugTestRunner{f: f} - ctx := &Context{ + ctx := &RunContext{ Debug: "testrule", DebugPrint: func(s string) { runner.out = append(runner.out, s) diff --git a/ruleguard/engine.go b/ruleguard/engine.go new file mode 100644 index 00000000..f8d1e390 --- /dev/null +++ b/ruleguard/engine.go @@ -0,0 +1,174 @@ +package ruleguard + +import ( + "errors" + "fmt" + "go/ast" + "go/types" + "io" + "strings" + "sync" + + "github.com/quasilyte/go-ruleguard/ruleguard/quasigo" + "github.com/quasilyte/go-ruleguard/ruleguard/typematch" +) + +type engine struct { + state *engineState + + ruleSet *goRuleSet +} + +func newEngine() *engine { + return &engine{ + state: newEngineState(), + } +} + +func (e *engine) Load(ctx *ParseContext, filename string, r io.Reader) error { + config := rulesParserConfig{ + state: e.state, + ctx: ctx, + importer: newGoImporter(e.state, goImporterConfig{ + fset: ctx.Fset, + debugImports: ctx.DebugImports, + debugPrint: ctx.DebugPrint, + }), + itab: typematch.NewImportsTab(stdlibPackages), + } + p := newRulesParser(config) + rset, err := p.ParseFile(filename, r) + if err != nil { + return err + } + + if e.ruleSet == nil { + e.ruleSet = rset + } else { + combinedRuleSet, err := mergeRuleSets([]*goRuleSet{e.ruleSet, rset}) + if err != nil { + return err + } + e.ruleSet = combinedRuleSet + } + + return nil +} + +func (e *engine) Run(ctx *RunContext, f *ast.File) error { + if e.ruleSet == nil { + return errors.New("used Run() with an empty rule set; forgot to call Load() first?") + } + rset := cloneRuleSet(e.ruleSet) + return newRulesRunner(ctx, e.state, rset).run(f) +} + +// engineState is a shared state inside the engine. +type engineState struct { + env *quasigo.Env + + typeByFQNMu sync.RWMutex + typeByFQN map[string]types.Type + + pkgCacheMu sync.RWMutex + // pkgCache contains all imported packages, from any importer. + pkgCache map[string]*types.Package +} + +func newEngineState() *engineState { + env := quasigo.NewEnv() + state := &engineState{ + env: env, + pkgCache: make(map[string]*types.Package), + typeByFQN: map[string]types.Type{ + // Predeclared types. + `error`: types.Universe.Lookup("error").Type(), + `bool`: types.Typ[types.Bool], + `int`: types.Typ[types.Int], + `int8`: types.Typ[types.Int8], + `int16`: types.Typ[types.Int16], + `int32`: types.Typ[types.Int32], + `int64`: types.Typ[types.Int64], + `uint`: types.Typ[types.Uint], + `uint8`: types.Typ[types.Uint8], + `uint16`: types.Typ[types.Uint16], + `uint32`: types.Typ[types.Uint32], + `uint64`: types.Typ[types.Uint64], + `uintptr`: types.Typ[types.Uintptr], + `string`: types.Typ[types.String], + `float32`: types.Typ[types.Float32], + `float64`: types.Typ[types.Float64], + `complex64`: types.Typ[types.Complex64], + `complex128`: types.Typ[types.Complex128], + // Predeclared aliases (provided for convenience). + `byte`: types.Typ[types.Uint8], + `rune`: types.Typ[types.Int32], + }, + } + initEnv(state, env) + return state +} + +func (state *engineState) GetCachedPackage(pkgPath string) *types.Package { + state.pkgCacheMu.RLock() + pkg := state.pkgCache[pkgPath] + state.pkgCacheMu.RUnlock() + return pkg +} + +func (state *engineState) AddCachedPackage(pkgPath string, pkg *types.Package) { + state.pkgCacheMu.Lock() + state.pkgCache[pkgPath] = pkg + state.pkgCacheMu.Unlock() +} + +func (state *engineState) FindType(importer *goImporter, currentPkg *types.Package, fqn string) (types.Type, error) { + // TODO(quasilyte): we can pre-populate the cache during the Load() phase. + // If we inspect the AST of a user function, all constant FQN can be preloaded. + // It could be a good thing as Load() is not expected to be executed in + // concurrent environment, so write-locking is not a big deal there. + + state.typeByFQNMu.RLock() + cachedType, ok := state.typeByFQN[fqn] + state.typeByFQNMu.RUnlock() + if ok { + return cachedType, nil + } + + // Code below is under a write critical section. + state.typeByFQNMu.Lock() + defer state.typeByFQNMu.Unlock() + + typ, err := state.findTypeNoCache(importer, currentPkg, fqn) + if err != nil { + return nil, err + } + state.typeByFQN[fqn] = typ + return typ, nil +} + +func (state *engineState) findTypeNoCache(importer *goImporter, currentPkg *types.Package, fqn string) (types.Type, error) { + pos := strings.LastIndexByte(fqn, '.') + if pos == -1 { + return nil, fmt.Errorf("%s is not a valid FQN", fqn) + } + pkgPath := fqn[:pos] + objectName := fqn[pos+1:] + var pkg *types.Package + if directDep := findDependency(currentPkg, pkgPath); directDep != nil { + pkg = directDep + } else { + loadedPkg, err := importer.Import(pkgPath) + if err != nil { + return nil, err + } + pkg = loadedPkg + } + obj := pkg.Scope().Lookup(objectName) + if obj == nil { + return nil, fmt.Errorf("%s is not found in %s", objectName, pkgPath) + } + typ := obj.Type() + state.typeByFQN[fqn] = typ + return typ, nil +} diff --git a/ruleguard/filters.go b/ruleguard/filters.go index a25f340c..1620641e 100644 --- a/ruleguard/filters.go +++ b/ruleguard/filters.go @@ -9,6 +9,7 @@ import ( "regexp" "github.com/quasilyte/go-ruleguard/internal/xtypes" + "github.com/quasilyte/go-ruleguard/ruleguard/quasigo" "github.com/quasilyte/go-ruleguard/ruleguard/typematch" ) @@ -104,6 +105,20 @@ func makeAddressableFilter(src, varname string) filterFunc { } } +func makeCustomVarFilter(src, varname string, fn *quasigo.Func) filterFunc { + return func(params *filterParams) matchFilterResult { + // TODO(quasilyte): what if bytecode function panics due to the programming error? + // We should probably catch the panic here, print trace and return "false" + // from the filter (or even propagate that panic to let it crash). + params.varname = varname + result := quasigo.Call(params.env, fn, params) + if result.(bool) { + return filterSuccess + } + return filterFailure(src) + } +} + func makeTypeImplementsFilter(src, varname string, iface *types.Interface) filterFunc { return func(params *filterParams) matchFilterResult { typ := params.typeofNode(params.subExpr(varname)) diff --git a/ruleguard/gorule.go b/ruleguard/gorule.go index 42dc234d..5357ad67 100644 --- a/ruleguard/gorule.go +++ b/ruleguard/gorule.go @@ -1,12 +1,21 @@ package ruleguard import ( + "fmt" "go/ast" + "go/token" "go/types" "github.com/quasilyte/go-ruleguard/internal/mvdan.cc/gogrep" + "github.com/quasilyte/go-ruleguard/ruleguard/quasigo" ) +type goRuleSet struct { + universal *scopedGoRuleSet + + groups map[string]token.Position // To handle redefinitions +} + type scopedGoRuleSet struct { uncategorized []goRule categorizedNum int @@ -38,13 +47,19 @@ type matchFilter struct { } type filterParams struct { - ctx *Context + ctx *RunContext filename string imports map[string]struct{} + env *quasigo.EvalEnv + + importer *goImporter values map[string]ast.Node nodeText func(n ast.Node) []byte + + // varname is set only for custom filters before bytecode function is called. + varname string } func (params *filterParams) subExpr(name string) ast.Expr { @@ -67,3 +82,51 @@ func (params *filterParams) typeofNode(n ast.Node) types.Type { return types.Typ[types.Invalid] } + +func cloneRuleSet(rset *goRuleSet) *goRuleSet { + out, err := mergeRuleSets([]*goRuleSet{rset}) + if err != nil { + panic(err) // Should never happen + } + return out +} + +func mergeRuleSets(toMerge []*goRuleSet) (*goRuleSet, error) { + out := &goRuleSet{ + universal: &scopedGoRuleSet{}, + groups: make(map[string]token.Position), + } + + for _, x := range toMerge { + out.universal = appendScopedRuleSet(out.universal, x.universal) + for group, pos := range x.groups { + if prevPos, ok := out.groups[group]; ok { + newRef := fmt.Sprintf("%s:%d", pos.Filename, pos.Line) + oldRef := fmt.Sprintf("%s:%d", prevPos.Filename, prevPos.Line) + return nil, fmt.Errorf("%s: redefenition of %s(), previously defined at %s", newRef, group, oldRef) + } + out.groups[group] = pos + } + } + + return out, nil +} + +func appendScopedRuleSet(dst, src *scopedGoRuleSet) *scopedGoRuleSet { + dst.uncategorized = append(dst.uncategorized, cloneRuleSlice(src.uncategorized)...) + for cat, rules := range src.rulesByCategory { + dst.rulesByCategory[cat] = append(dst.rulesByCategory[cat], cloneRuleSlice(rules)...) + dst.categorizedNum += len(rules) + } + return dst +} + +func cloneRuleSlice(slice []goRule) []goRule { + out := make([]goRule, len(slice)) + for i, rule := range slice { + clone := rule + clone.pat = rule.pat.Clone() + out[i] = clone + } + return out +} diff --git a/ruleguard/goutil/goutil.go b/ruleguard/goutil/goutil.go new file mode 100644 index 00000000..6cc4d905 --- /dev/null +++ b/ruleguard/goutil/goutil.go @@ -0,0 +1,21 @@ +package goutil + +import ( + "go/ast" + "go/printer" + "go/token" + "strings" +) + +// SprintNode returns the textual representation of n. +// If fset is nil, freshly created file set will be used. +func SprintNode(fset *token.FileSet, n ast.Node) string { + if fset == nil { + fset = token.NewFileSet() + } + var buf strings.Builder + if err := printer.Fprint(&buf, fset, n); err != nil { + return "" + } + return buf.String() +} diff --git a/ruleguard/goutil/resolve.go b/ruleguard/goutil/resolve.go new file mode 100644 index 00000000..8705707a --- /dev/null +++ b/ruleguard/goutil/resolve.go @@ -0,0 +1,33 @@ +package goutil + +import ( + "go/ast" + "go/types" + + "golang.org/x/tools/go/ast/astutil" +) + +func ResolveFunc(info *types.Info, callable ast.Expr) (ast.Expr, *types.Func) { + switch callable := astutil.Unparen(callable).(type) { + case *ast.Ident: + sig, ok := info.ObjectOf(callable).(*types.Func) + if !ok { + return nil, nil + } + return nil, sig + + case *ast.SelectorExpr: + sig, ok := info.ObjectOf(callable.Sel).(*types.Func) + if !ok { + return nil, nil + } + isMethod := sig.Type().(*types.Signature).Recv() != nil + if _, ok := callable.X.(*ast.Ident); ok && !isMethod { + return nil, sig + } + return callable.X, sig + + default: + return nil, nil + } +} diff --git a/ruleguard/importer.go b/ruleguard/importer.go index bfe6affc..06a0bbf9 100644 --- a/ruleguard/importer.go +++ b/ruleguard/importer.go @@ -5,6 +5,7 @@ import ( "go/ast" "go/importer" "go/parser" + "go/token" "go/types" "path/filepath" "runtime" @@ -17,49 +18,56 @@ import ( type goImporter struct { // TODO(quasilyte): share importers with gogrep? - ctx *ParseContext - - // cache contains all imported packages, from any importer. - // Both default and source importers have their own caches, - // but since we use several importers, it's better to - // have our own, unified cache. - cache map[string]*types.Package + state *engineState defaultImporter types.Importer srcImporter types.Importer + + fset *token.FileSet + + debugImports bool + debugPrint func(string) +} + +type goImporterConfig struct { + fset *token.FileSet + debugImports bool + debugPrint func(string) } -func newGoImporter(ctx *ParseContext) *goImporter { +func newGoImporter(state *engineState, config goImporterConfig) *goImporter { return &goImporter{ - ctx: ctx, - cache: make(map[string]*types.Package), + state: state, + fset: config.fset, + debugImports: config.debugImports, + debugPrint: config.debugPrint, defaultImporter: importer.Default(), - srcImporter: importer.ForCompiler(ctx.Fset, "source", nil), + srcImporter: importer.ForCompiler(config.fset, "source", nil), } } func (imp *goImporter) Import(path string) (*types.Package, error) { - if pkg := imp.cache[path]; pkg != nil { - if imp.ctx.DebugImports { - imp.ctx.DebugPrint(fmt.Sprintf(`imported "%s" from importer cache`, path)) + if pkg := imp.state.GetCachedPackage(path); pkg != nil { + if imp.debugImports { + imp.debugPrint(fmt.Sprintf(`imported "%s" from importer cache`, path)) } return pkg, nil } pkg, err1 := imp.srcImporter.Import(path) if err1 == nil { - imp.cache[path] = pkg - if imp.ctx.DebugImports { - imp.ctx.DebugPrint(fmt.Sprintf(`imported "%s" from source importer`, path)) + imp.state.AddCachedPackage(path, pkg) + if imp.debugImports { + imp.debugPrint(fmt.Sprintf(`imported "%s" from source importer`, path)) } return pkg, nil } pkg, err2 := imp.defaultImporter.Import(path) if err2 == nil { - imp.cache[path] = pkg - if imp.ctx.DebugImports { - imp.ctx.DebugPrint(fmt.Sprintf(`imported "%s" from %s importer`, path, runtime.Compiler)) + imp.state.AddCachedPackage(path, pkg) + if imp.debugImports { + imp.debugPrint(fmt.Sprintf(`imported "%s" from %s importer`, path, runtime.Compiler)) } return pkg, nil } @@ -67,18 +75,18 @@ func (imp *goImporter) Import(path string) (*types.Package, error) { // Fallback to `go list` as a last resort. pkg, err3 := imp.golistImport(path) if err3 == nil { - imp.cache[path] = pkg - if imp.ctx.DebugImports { - imp.ctx.DebugPrint(fmt.Sprintf(`imported "%s" from golist importer`, path)) + imp.state.AddCachedPackage(path, pkg) + if imp.debugImports { + imp.debugPrint(fmt.Sprintf(`imported "%s" from golist importer`, path)) } return pkg, nil } - if imp.ctx.DebugImports { - imp.ctx.DebugPrint(fmt.Sprintf(`failed to import "%s":`, path)) - imp.ctx.DebugPrint(fmt.Sprintf(" source importer: %v", err1)) - imp.ctx.DebugPrint(fmt.Sprintf(" %s importer: %v", runtime.Compiler, err2)) - imp.ctx.DebugPrint(fmt.Sprintf(" golist importer: %v", err3)) + if imp.debugImports { + imp.debugPrint(fmt.Sprintf(`failed to import "%s":`, path)) + imp.debugPrint(fmt.Sprintf(" source importer: %v", err1)) + imp.debugPrint(fmt.Sprintf(" %s importer: %v", runtime.Compiler, err2)) + imp.debugPrint(fmt.Sprintf(" golist importer: %v", err3)) } return nil, err2 @@ -93,7 +101,7 @@ func (imp *goImporter) golistImport(path string) (*types.Package, error) { files := make([]*ast.File, 0, len(golistPkg.GoFiles)) for _, filename := range golistPkg.GoFiles { fullname := filepath.Join(golistPkg.Dir, filename) - f, err := parser.ParseFile(imp.ctx.Fset, fullname, nil, 0) + f, err := parser.ParseFile(imp.fset, fullname, nil, 0) if err != nil { return nil, err } @@ -104,5 +112,5 @@ func (imp *goImporter) golistImport(path string) (*types.Package, error) { // Otherwise it won't be able to resolve imports. var typecheker types.Config var info types.Info - return typecheker.Check(path, imp.ctx.Fset, files, &info) + return typecheker.Check(path, imp.fset, files, &info) } diff --git a/ruleguard/libdsl.go b/ruleguard/libdsl.go new file mode 100644 index 00000000..c0f4d85f --- /dev/null +++ b/ruleguard/libdsl.go @@ -0,0 +1,200 @@ +package ruleguard + +import ( + "go/types" + + "github.com/quasilyte/go-ruleguard/internal/xtypes" + "github.com/quasilyte/go-ruleguard/ruleguard/quasigo" +) + +// This file implements `dsl/*` packages as native functions in quasigo. +// +// Every function and method defined in any `dsl/*` package should have +// associated Go function that implements it. +// +// In quasigo, it's impossible to have a pointer to an interface and +// non-pointer struct type. All interface type methods have FQN without `*` prefix +// while all struct type methods always begin with `*`. +// +// Fields are readonly. +// Field access is compiled into a method call that have a name identical to the field. +// For example, `foo.Bar` field access will be compiled as `foo.Bar()`. +// This may change in the future; benchmarks are needed to figure out +// what is more efficient: reflect-based field access or a function call. +// +// To keep this code organized, every type and package functions are represented +// as structs with methods. Then we bind a method value to quasigo symbol. +// The naming scheme is `dsl{$name}Package` for packages and `dsl{$pkg}{$name}` for types. + +func initEnv(state *engineState, env *quasigo.Env) { + nativeTypes := map[string]quasigoNative{ + `*github.com/quasilyte/go-ruleguard/dsl.VarFilterContext`: dslVarFilterContext{state: state}, + `github.com/quasilyte/go-ruleguard/dsl/types.Type`: dslTypesType{}, + `*github.com/quasilyte/go-ruleguard/dsl/types.Interface`: dslTypesInterface{}, + `*github.com/quasilyte/go-ruleguard/dsl/types.Pointer`: dslTypesPointer{}, + } + + for qualifier, typ := range nativeTypes { + for methodName, fn := range typ.funcs() { + env.AddNativeMethod(qualifier, methodName, fn) + } + } + + nativePackages := map[string]quasigoNative{ + `github.com/quasilyte/go-ruleguard/dsl/types`: dslTypesPackage{}, + } + + for qualifier, pkg := range nativePackages { + for funcName, fn := range pkg.funcs() { + env.AddNativeMethod(qualifier, funcName, fn) + } + } +} + +type quasigoNative interface { + funcs() map[string]func(*quasigo.ValueStack) +} + +type dslTypesType struct{} + +func (native dslTypesType) funcs() map[string]func(*quasigo.ValueStack) { + return map[string]func(*quasigo.ValueStack){ + "Underlying": native.Underlying, + "String": native.String, + } +} + +func (dslTypesType) Underlying(stack *quasigo.ValueStack) { + stack.Push(stack.Pop().(types.Type).Underlying()) +} + +func (dslTypesType) String(stack *quasigo.ValueStack) { + stack.Push(stack.Pop().(types.Type).String()) +} + +type dslTypesInterface struct{} + +func (native dslTypesInterface) funcs() map[string]func(*quasigo.ValueStack) { + return map[string]func(*quasigo.ValueStack){ + "Underlying": native.Underlying, + "String": native.String, + } +} + +func (dslTypesInterface) Underlying(stack *quasigo.ValueStack) { + stack.Push(stack.Pop().(*types.Interface).Underlying()) +} + +func (dslTypesInterface) String(stack *quasigo.ValueStack) { + stack.Push(stack.Pop().(*types.Interface).String()) +} + +type dslTypesPointer struct{} + +func (native dslTypesPointer) funcs() map[string]func(*quasigo.ValueStack) { + return map[string]func(*quasigo.ValueStack){ + "Underlying": native.Underlying, + "String": native.String, + "Elem": native.Elem, + } +} + +func (dslTypesPointer) Underlying(stack *quasigo.ValueStack) { + stack.Push(stack.Pop().(*types.Pointer).Underlying()) +} + +func (dslTypesPointer) String(stack *quasigo.ValueStack) { + stack.Push(stack.Pop().(*types.Pointer).String()) +} + +func (dslTypesPointer) Elem(stack *quasigo.ValueStack) { + stack.Push(stack.Pop().(*types.Pointer).Elem()) +} + +type dslTypesPackage struct{} + +func (native dslTypesPackage) funcs() map[string]func(*quasigo.ValueStack) { + return map[string]func(*quasigo.ValueStack){ + "Implements": native.Implements, + "Identical": native.Identical, + "NewPointer": native.NewPointer, + "AsPointer": native.AsPointer, + "AsInterface": native.AsInterface, + } +} + +func (dslTypesPackage) Implements(stack *quasigo.ValueStack) { + iface := stack.Pop().(*types.Interface) + typ := stack.Pop().(types.Type) + stack.Push(xtypes.Implements(typ, iface)) +} + +func (dslTypesPackage) Identical(stack *quasigo.ValueStack) { + y := stack.Pop().(types.Type) + x := stack.Pop().(types.Type) + stack.Push(xtypes.Identical(x, y)) +} + +func (dslTypesPackage) NewPointer(stack *quasigo.ValueStack) { + typ := stack.Pop().(types.Type) + stack.Push(types.NewPointer(typ)) +} + +func (dslTypesPackage) AsPointer(stack *quasigo.ValueStack) { + typ, _ := stack.Pop().(types.Type).(*types.Pointer) + stack.Push(typ) +} + +func (dslTypesPackage) AsInterface(stack *quasigo.ValueStack) { + typ, _ := stack.Pop().(types.Type).(*types.Interface) + stack.Push(typ) +} + +type dslVarFilterContext struct { + state *engineState +} + +func (native dslVarFilterContext) funcs() map[string]func(*quasigo.ValueStack) { + return map[string]func(*quasigo.ValueStack){ + "Type": native.Type, + "SizeOf": native.SizeOf, + "GetType": native.GetType, + "GetInterface": native.GetInterface, + } +} + +func (dslVarFilterContext) Type(stack *quasigo.ValueStack) { + params := stack.Pop().(*filterParams) + typ := params.typeofNode(params.subExpr(params.varname)) + stack.Push(typ) +} + +func (native dslVarFilterContext) SizeOf(stack *quasigo.ValueStack) { + typ := stack.Pop().(types.Type) + params := stack.Pop().(*filterParams) + stack.Push(int(params.ctx.Sizes.Sizeof(typ))) +} + +func (native dslVarFilterContext) GetType(stack *quasigo.ValueStack) { + fqn := stack.Pop().(string) + params := stack.Pop().(*filterParams) + typ, err := native.state.FindType(params.importer, params.ctx.Pkg, fqn) + if err != nil { + panic(err) + } + stack.Push(typ) +} + +func (native dslVarFilterContext) GetInterface(stack *quasigo.ValueStack) { + fqn := stack.Pop().(string) + params := stack.Pop().(*filterParams) + typ, err := native.state.FindType(params.importer, params.ctx.Pkg, fqn) + if err != nil { + panic(err) + } + if ifaceType, ok := typ.Underlying().(*types.Interface); ok { + stack.Push(ifaceType) + return + } + stack.Push((*types.Interface)(nil)) // Not found or not an interface +} diff --git a/ruleguard/merge.go b/ruleguard/merge.go deleted file mode 100644 index 76e12be6..00000000 --- a/ruleguard/merge.go +++ /dev/null @@ -1,42 +0,0 @@ -package ruleguard - -import ( - "fmt" - "go/token" -) - -func mergeRuleSets(toMerge []*GoRuleSet) (*GoRuleSet, error) { - out := &GoRuleSet{ - local: &scopedGoRuleSet{}, - universal: &scopedGoRuleSet{}, - groups: make(map[string]token.Position), - Imports: make(map[string]struct{}), - } - - for _, x := range toMerge { - out.local = appendScopedRuleSet(out.local, x.local) - out.universal = appendScopedRuleSet(out.universal, x.universal) - for pkgPath := range x.Imports { - out.Imports[pkgPath] = struct{}{} - } - for group, pos := range x.groups { - if prevPos, ok := out.groups[group]; ok { - newRef := fmt.Sprintf("%s:%d", pos.Filename, pos.Line) - oldRef := fmt.Sprintf("%s:%d", prevPos.Filename, prevPos.Line) - return nil, fmt.Errorf("%s: redefenition of %s(), previously defined at %s", newRef, group, oldRef) - } - out.groups[group] = pos - } - } - - return out, nil -} - -func appendScopedRuleSet(dst, src *scopedGoRuleSet) *scopedGoRuleSet { - dst.uncategorized = append(dst.uncategorized, src.uncategorized...) - for cat, rules := range src.rulesByCategory { - dst.rulesByCategory[cat] = append(dst.rulesByCategory[cat], rules...) - dst.categorizedNum += len(rules) - } - return dst -} diff --git a/ruleguard/parser.go b/ruleguard/parser.go index dde6db57..89d2dc43 100644 --- a/ruleguard/parser.go +++ b/ruleguard/parser.go @@ -14,62 +14,69 @@ import ( "strconv" "github.com/quasilyte/go-ruleguard/internal/mvdan.cc/gogrep" + "github.com/quasilyte/go-ruleguard/ruleguard/goutil" + "github.com/quasilyte/go-ruleguard/ruleguard/quasigo" "github.com/quasilyte/go-ruleguard/ruleguard/typematch" ) +// TODO(quasilyte): use source code byte slicing instead of SprintNode? + type parseError string func (e parseError) Error() string { return string(e) } type rulesParser struct { - ctx *ParseContext + state *engineState + ctx *ParseContext prefix string // For imported packages, a prefix that is added to a rule group name importedPkg string // Package path; only for imported packages filename string group string - fset *token.FileSet - res *GoRuleSet + res *goRuleSet + pkg *types.Package types *types.Info - itab *typematch.ImportsTab importer *goImporter - imported []*GoRuleSet + itab *typematch.ImportsTab + + imported []*goRuleSet dslPkgname string // The local name of the "ruleguard/dsl" package (usually its just "dsl") } type rulesParserConfig struct { + state *engineState + ctx *ParseContext + importer *goImporter + prefix string importedPkg string - itab *typematch.ImportsTab - importer *goImporter + itab *typematch.ImportsTab } func newRulesParser(config rulesParserConfig) *rulesParser { return &rulesParser{ + state: config.state, ctx: config.ctx, + importer: config.importer, prefix: config.prefix, importedPkg: config.importedPkg, itab: config.itab, - importer: config.importer, } } -func (p *rulesParser) ParseFile(filename string, r io.Reader) (*GoRuleSet, error) { +func (p *rulesParser) ParseFile(filename string, r io.Reader) (*goRuleSet, error) { p.dslPkgname = "dsl" p.filename = filename - p.fset = p.ctx.Fset - p.res = &GoRuleSet{ - local: &scopedGoRuleSet{}, + p.res = &goRuleSet{ universal: &scopedGoRuleSet{}, groups: make(map[string]token.Position), - Imports: make(map[string]struct{}), } parserFlags := parser.Mode(0) @@ -98,12 +105,16 @@ func (p *rulesParser) ParseFile(filename string, r io.Reader) (*GoRuleSet, error p.types = &types.Info{ Types: map[ast.Expr]types.TypeAndValue{}, Uses: map[*ast.Ident]types.Object{}, + Defs: map[*ast.Ident]types.Object{}, } - _, err = typechecker.Check("gorules", p.ctx.Fset, []*ast.File{f}, p.types) + pkg, err := typechecker.Check("gorules", p.ctx.Fset, []*ast.File{f}, p.types) if err != nil { return nil, fmt.Errorf("typechecker error: %v", err) } + p.pkg = pkg + var matcherFuncs []*ast.FuncDecl + var userFuncs []*ast.FuncDecl for _, decl := range f.Decls { decl, ok := decl.(*ast.FuncDecl) if !ok { @@ -115,15 +126,29 @@ func (p *rulesParser) ParseFile(filename string, r io.Reader) (*GoRuleSet, error } continue } + + if p.isMatcherFunc(decl) { + matcherFuncs = append(matcherFuncs, decl) + } else { + userFuncs = append(userFuncs, decl) + } + } + + for _, decl := range userFuncs { + if err := p.parseUserFunc(decl); err != nil { + return nil, err + } + } + for _, decl := range matcherFuncs { if err := p.parseRuleGroup(decl); err != nil { return nil, err } } if len(p.imported) != 0 { - toMerge := []*GoRuleSet{p.res} + toMerge := []*goRuleSet{p.res} toMerge = append(toMerge, p.imported...) - merged, err := MergeRuleSets(toMerge) + merged, err := mergeRuleSets(toMerge) if err != nil { return nil, err } @@ -133,6 +158,23 @@ func (p *rulesParser) ParseFile(filename string, r io.Reader) (*GoRuleSet, error return p.res, nil } +func (p *rulesParser) parseUserFunc(f *ast.FuncDecl) error { + ctx := &quasigo.CompileContext{ + Env: p.state.env, + Types: p.types, + Fset: p.ctx.Fset, + } + compiled, err := quasigo.Compile(ctx, f) + if err != nil { + return err + } + if p.ctx.DebugFilter == f.Name.String() { + p.ctx.DebugPrint(quasigo.Disasm(p.state.env, compiled)) + } + ctx.Env.AddFunc(p.pkg.Path(), f.Name.String(), compiled) + return nil +} + func (p *rulesParser) parseInitFunc(f *ast.FuncDecl) error { type bundleImport struct { node ast.Node @@ -193,24 +235,23 @@ func (p *rulesParser) parseInitFunc(f *ast.FuncDecl) error { return p.errorf(imp.node, "import parsing error: %v", err) } p.imported = append(p.imported, rset) - p.res.Imports[imp.pkgPath] = struct{}{} } } return nil } -func (p *rulesParser) importRules(prefix, pkgPath, filename string) (*GoRuleSet, error) { +func (p *rulesParser) importRules(prefix, pkgPath, filename string) (*goRuleSet, error) { data, err := ioutil.ReadFile(filename) if err != nil { return nil, err } config := rulesParserConfig{ ctx: p.ctx, + importer: p.importer, prefix: prefix, importedPkg: pkgPath, itab: p.itab, - importer: p.importer, } rset, err := newRulesParser(config).ParseFile(filename, bytes.NewReader(data)) if err != nil { @@ -219,6 +260,13 @@ func (p *rulesParser) importRules(prefix, pkgPath, filename string) (*GoRuleSet, return rset, nil } +func (p *rulesParser) isMatcherFunc(f *ast.FuncDecl) bool { + typ := p.types.ObjectOf(f.Name).Type().(*types.Signature) + return typ.Results().Len() == 0 && + typ.Params().Len() == 1 && + typ.Params().At(0).Type().String() == "github.com/quasilyte/go-ruleguard/dsl.Matcher" +} + func (p *rulesParser) parseRuleGroup(f *ast.FuncDecl) (err error) { defer func() { rv := recover() @@ -238,14 +286,7 @@ func (p *rulesParser) parseRuleGroup(f *ast.FuncDecl) (err error) { if f.Body == nil { return p.errorf(f, "unexpected empty function body") } - if f.Type.Results != nil { - return p.errorf(f.Type.Results, "rule group function should not return anything") - } params := f.Type.Params.List - if len(params) != 1 || len(params[0].Names) != 1 { - return p.errorf(f.Type.Params, "rule group function should accept exactly 1 Matcher param") - } - // TODO(quasilyte): do an actual matcher param type check? matcher := params[0].Names[0].Name p.group = f.Name.Name @@ -261,7 +302,7 @@ func (p *rulesParser) parseRuleGroup(f *ast.FuncDecl) (err error) { } p.res.groups[p.group] = token.Position{ Filename: p.filename, - Line: p.fset.Position(f.Name.Pos()).Line, + Line: p.ctx.Fset.Position(f.Name.Pos()).Line, } p.itab.EnterScope() @@ -273,11 +314,11 @@ func (p *rulesParser) parseRuleGroup(f *ast.FuncDecl) (err error) { } stmtExpr, ok := stmt.(*ast.ExprStmt) if !ok { - return p.errorf(stmt, "expected a %s method call, found %s", matcher, sprintNode(p.fset, stmt)) + return p.errorf(stmt, "expected a %s method call, found %s", matcher, goutil.SprintNode(p.ctx.Fset, stmt)) } call, ok := stmtExpr.X.(*ast.CallExpr) if !ok { - return p.errorf(stmt, "expected a %s method call, found %s", matcher, sprintNode(p.fset, stmt)) + return p.errorf(stmt, "expected a %s method call, found %s", matcher, goutil.SprintNode(p.ctx.Fset, stmt)) } if err := p.parseCall(matcher, call); err != nil { return err @@ -365,7 +406,7 @@ func (p *rulesParser) parseRule(matcher string, call *ast.CallExpr) error { dst := p.res.universal proto := goRule{ filename: p.filename, - line: p.fset.Position(origCall.Pos()).Line, + line: p.ctx.Fset.Position(origCall.Pos()).Line, group: p.group, } var alternatives []string @@ -408,7 +449,7 @@ func (p *rulesParser) parseRule(matcher string, call *ast.CallExpr) error { for i, alt := range alternatives { rule := proto - pat, err := gogrep.Parse(p.fset, alt) + pat, err := gogrep.Parse(p.ctx.Fset, alt) if err != nil { return p.errorf((*matchArgs)[i], "parse match pattern: %v", err) } @@ -430,7 +471,7 @@ func (p *rulesParser) parseFilter(root ast.Expr) matchFilter { } func (p *rulesParser) errorf(n ast.Node, format string, args ...interface{}) parseError { - loc := p.fset.Position(n.Pos()) + loc := p.ctx.Fset.Position(n.Pos()) message := fmt.Sprintf("%s:%d: %s", loc.Filename, loc.Line, fmt.Sprintf(format, args...)) return parseError(message) } @@ -471,7 +512,7 @@ func (p *rulesParser) parseTypeStringArg(e ast.Expr) types.Type { } func (p *rulesParser) parseFilterExpr(e ast.Expr) matchFilter { - result := matchFilter{src: sprintNode(p.fset, e)} + result := matchFilter{src: goutil.SprintNode(p.ctx.Fset, e)} switch e := e.(type) { case *ast.ParenExpr: @@ -548,6 +589,18 @@ func (p *rulesParser) parseFilterExpr(e ast.Expr) matchFilter { case "Addressable": result.fn = makeAddressableFilter(result.src, operand.varName) + case "Filter": + expr, fn := goutil.ResolveFunc(p.types, args[0]) + if expr != nil { + panic(p.errorf(expr, "expected a simple function name, found expression")) + } + sig := fn.Type().(*types.Signature) + userFn := p.state.env.GetFunc(fn.Pkg().Path(), fn.Name()) + if userFn == nil { + panic(p.errorf(args[0], "can't find a compiled version of %s", sig.String())) + } + result.fn = makeCustomVarFilter(result.src, operand.varName, userFn) + case "Type.Is", "Type.Underlying.Is": typeString, ok := p.toStringValue(args[0]) if !ok { diff --git a/ruleguard/quasigo/compile.go b/ruleguard/quasigo/compile.go new file mode 100644 index 00000000..bb9639c8 --- /dev/null +++ b/ruleguard/quasigo/compile.go @@ -0,0 +1,598 @@ +package quasigo + +import ( + "fmt" + "go/ast" + "go/constant" + "go/token" + "go/types" + + "github.com/quasilyte/go-ruleguard/ruleguard/goutil" + "golang.org/x/tools/go/ast/astutil" +) + +func compile(ctx *CompileContext, fn *ast.FuncDecl) (compiled *Func, err error) { + defer func() { + rv := recover() + if rv == nil { + return + } + if compileErr, ok := rv.(compileError); ok { + err = compileErr + return + } + panic(rv) // not our panic + }() + + return compileFunc(ctx, fn), nil +} + +func compileFunc(ctx *CompileContext, fn *ast.FuncDecl) *Func { + fnType := ctx.Types.ObjectOf(fn.Name).Type().(*types.Signature) + if fnType.Results().Len() != 1 { + panic(compileError("only functions with a single non-void results are supported")) + } + + cl := compiler{ + ctx: ctx, + retType: fnType.Results().At(0).Type(), + constantsPool: make(map[interface{}]int), + locals: make(map[string]int), + } + return cl.compileFunc(fnType, fn) +} + +type compiler struct { + ctx *CompileContext + + retType types.Type + + lastOp opcode + + locals map[string]int + constantsPool map[interface{}]int + params map[string]int + + code []byte + constants []interface{} +} + +type label struct { + used bool + jumpPos int +} + +type compileError string + +func (e compileError) Error() string { return string(e) } + +func (cl *compiler) compileFunc(fnType *types.Signature, fn *ast.FuncDecl) *Func { + if !cl.isSupportedType(cl.retType) { + panic(cl.errorUnsupportedType(fn.Name, cl.retType, "function result")) + } + + dbg := funcDebugInfo{ + paramNames: make([]string, fnType.Params().Len()), + } + + cl.params = make(map[string]int, fnType.Params().Len()) + for i := 0; i < fnType.Params().Len(); i++ { + p := fnType.Params().At(i) + paramName := p.Name() + paramType := p.Type() + cl.params[paramName] = i + dbg.paramNames[i] = paramName + if !cl.isSupportedType(paramType) { + panic(cl.errorUnsupportedType(fn.Name, paramType, paramName+" param")) + } + } + + cl.compileStmt(fn.Body) + compiled := &Func{ + code: cl.code, + constants: cl.constants, + } + if len(cl.locals) != 0 { + dbg.localNames = make([]string, len(cl.locals)) + for localName, localIndex := range cl.locals { + dbg.localNames[localIndex] = localName + } + } + cl.ctx.Env.debug.funcs[compiled] = dbg + return compiled +} + +func (cl *compiler) compileStmt(stmt ast.Stmt) { + switch stmt := stmt.(type) { + case *ast.ReturnStmt: + cl.compileReturnStmt(stmt) + + case *ast.AssignStmt: + cl.compileAssignStmt(stmt) + + case *ast.IfStmt: + cl.compileIfStmt(stmt) + + case *ast.BlockStmt: + for i := range stmt.List { + cl.compileStmt(stmt.List[i]) + } + + default: + panic(cl.errorf(stmt, "can't compile %T yet", stmt)) + } +} + +func (cl *compiler) compileIfStmt(stmt *ast.IfStmt) { + if stmt.Else == nil { + var labelEnd label + cl.compileExpr(stmt.Cond) + cl.emitJump(opJumpFalse, &labelEnd) + cl.emit(opPop) + cl.compileStmt(stmt.Body) + cl.bindLabel(labelEnd) + return + } + + var labelEnd label + var labelElse label + cl.compileExpr(stmt.Cond) + cl.emitJump(opJumpFalse, &labelElse) + cl.emit(opPop) + cl.compileStmt(stmt.Body) + if !cl.isUncondJump(cl.lastOp) { + cl.emitJump(opJump, &labelEnd) + } + cl.bindLabel(labelElse) + cl.compileStmt(stmt.Else) + cl.bindLabel(labelEnd) +} + +func (cl *compiler) compileAssignStmt(assign *ast.AssignStmt) { + if len(assign.Lhs) != 1 { + panic(cl.errorf(assign, "only single left operand is allowed in assignments")) + } + if len(assign.Rhs) != 1 { + panic(cl.errorf(assign, "only single right operand is allowed in assignments")) + } + lhs := assign.Lhs[0] + rhs := assign.Rhs[0] + varname, ok := lhs.(*ast.Ident) + if !ok { + panic(cl.errorf(lhs, "can assign only to simple variables")) + } + + cl.compileExpr(rhs) + + if assign.Tok == token.DEFINE { + if _, ok := cl.locals[varname.String()]; ok { + panic(cl.errorf(lhs, "%s variable shadowing is not allowed", varname)) + } + typ := cl.ctx.Types.TypeOf(varname) + if !cl.isSupportedType(typ) { + panic(cl.errorUnsupportedType(varname, typ, varname.String()+" local variable")) + } + if len(cl.locals) == maxFuncLocals { + panic(cl.errorf(lhs, "can't define %s: too many locals", varname)) + } + id := len(cl.locals) + cl.locals[varname.String()] = id + cl.emit8(opSetLocal, id) + } else { + id, ok := cl.locals[varname.String()] + if !ok { + if _, ok := cl.params[varname.String()]; ok { + panic(cl.errorf(lhs, "can't assign to %s, params are readonly", varname.String())) + } else { + panic(cl.errorf(lhs, "%s is not a writeable local variable", varname.String())) + } + } + cl.emit8(opSetLocal, id) + } +} + +func (cl *compiler) compileReturnStmt(ret *ast.ReturnStmt) { + if ret.Results == nil { + panic(cl.errorf(ret, "'naked' return statements are not allowed")) + } + + switch { + case identName(ret.Results[0]) == "true": + cl.emit(opReturnTrue) + case identName(ret.Results[0]) == "false": + cl.emit(opReturnFalse) + default: + cl.compileExpr(ret.Results[0]) + cl.emit(opReturnTop) + } +} + +func (cl *compiler) compileExpr(e ast.Expr) { + cv := cl.ctx.Types.Types[e].Value + if cv != nil { + cl.compileConstantValue(e, cv) + return + } + + switch e := e.(type) { + case *ast.ParenExpr: + cl.compileExpr(e.X) + + case *ast.Ident: + cl.compileIdent(e) + + case *ast.SelectorExpr: + cl.compileSelectorExpr(e) + + case *ast.UnaryExpr: + switch e.Op { + case token.NOT: + cl.compileUnaryOp(opNot, e) + default: + panic(cl.errorf(e, "can't compile unary %s yet", e.Op)) + } + + case *ast.SliceExpr: + cl.compileSliceExpr(e) + + case *ast.BinaryExpr: + cl.compileBinaryExpr(e) + + case *ast.CallExpr: + cl.compileCallExpr(e) + + default: + panic(cl.errorf(e, "can't compile %T yet", e)) + } +} + +func (cl *compiler) compileSelectorExpr(e *ast.SelectorExpr) { + typ := cl.ctx.Types.TypeOf(e.X) + key := funcKey{ + name: e.Sel.String(), + qualifier: typ.String(), + } + + builtinID, ok := cl.ctx.Env.nameToBuiltinFuncID[key] + if ok { + cl.compileExpr(e.X) + cl.emit16(opCallBuiltin, int(builtinID)) + return + } + + panic(cl.errorf(e, "can't compile %s field access", e.Sel)) +} + +func (cl *compiler) compileBinaryExpr(e *ast.BinaryExpr) { + typ := cl.ctx.Types.TypeOf(e.X) + + switch e.Op { + case token.LOR: + cl.compileOr(e) + case token.LAND: + cl.compileAnd(e) + + case token.NEQ: + switch { + case identName(e.X) == "nil": + cl.compileExpr(e.Y) + cl.emit(opIsNotNil) + case identName(e.Y) == "nil": + cl.compileExpr(e.X) + cl.emit(opIsNotNil) + case typeIsString(typ): + cl.compileBinaryOp(opNotEqString, e) + case typeIsInt(typ): + cl.compileBinaryOp(opNotEqInt, e) + default: + panic(cl.errorf(e, "!= is not implemented for %s operands", typ)) + } + case token.EQL: + switch { + case identName(e.X) == "nil": + cl.compileExpr(e.Y) + cl.emit(opIsNil) + case identName(e.Y) == "nil": + cl.compileExpr(e.X) + cl.emit(opIsNil) + case typeIsString(cl.ctx.Types.TypeOf(e.X)): + cl.compileBinaryOp(opEqString, e) + case typeIsInt(cl.ctx.Types.TypeOf(e.X)): + cl.compileBinaryOp(opEqInt, e) + default: + panic(cl.errorf(e, "== is not implemented for %s operands", typ)) + } + + case token.GTR: + cl.compileIntBinaryOp(e, opGtInt, typ) + case token.GEQ: + cl.compileIntBinaryOp(e, opGtEqInt, typ) + case token.LSS: + cl.compileIntBinaryOp(e, opLtInt, typ) + case token.LEQ: + cl.compileIntBinaryOp(e, opLtEqInt, typ) + + case token.ADD: + switch { + case typeIsString(typ): + cl.compileBinaryOp(opConcat, e) + case typeIsInt(typ): + cl.compileBinaryOp(opAdd, e) + default: + panic(cl.errorf(e, "+ is not implemented for %s operands", typ)) + } + + case token.SUB: + cl.compileIntBinaryOp(e, opSub, typ) + + default: + panic(cl.errorf(e, "can't compile binary %s yet", e.Op)) + } +} + +func (cl *compiler) compileIntBinaryOp(e *ast.BinaryExpr, op opcode, typ types.Type) { + switch { + case typeIsInt(typ): + cl.compileBinaryOp(op, e) + default: + panic(cl.errorf(e, "%s is not implemented for %s operands", e.Op, typ)) + } +} + +func (cl *compiler) compileSliceExpr(slice *ast.SliceExpr) { + if slice.Slice3 { + panic(cl.errorf(slice, "can't compile 3-index slicing")) + } + + // No need to do slicing, its no-op `s[:]`. + if slice.Low == nil && slice.High == nil { + cl.compileExpr(slice.X) + return + } + + sliceOp := opStringSlice + sliceFromOp := opStringSliceFrom + sliceToOp := opStringSliceTo + + if !typeIsString(cl.ctx.Types.TypeOf(slice.X)) { + panic(cl.errorf(slice.X, "can't compile slicing of something that is not a string")) + } + + switch { + case slice.Low == nil && slice.High != nil: + cl.compileExpr(slice.X) + cl.compileExpr(slice.High) + cl.emit(sliceToOp) + case slice.Low != nil && slice.High == nil: + cl.compileExpr(slice.X) + cl.compileExpr(slice.Low) + cl.emit(sliceFromOp) + default: + cl.compileExpr(slice.X) + cl.compileExpr(slice.Low) + cl.compileExpr(slice.High) + cl.emit(sliceOp) + } +} + +func (cl *compiler) compileBuiltinCall(fn *ast.Ident, call *ast.CallExpr) { + switch fn.Name { + case `len`: + s := call.Args[0] + cl.compileExpr(s) + if !typeIsString(cl.ctx.Types.TypeOf(s)) { + panic(cl.errorf(s, "can't compile len() with non-string argument yet")) + } + cl.emit(opStringLen) + default: + panic(cl.errorf(fn, "can't compile %s() builtin function call yet", fn)) + } +} + +func (cl *compiler) compileCallExpr(call *ast.CallExpr) { + if id, ok := astutil.Unparen(call.Fun).(*ast.Ident); ok { + _, isBuiltin := cl.ctx.Types.ObjectOf(id).(*types.Builtin) + if isBuiltin { + cl.compileBuiltinCall(id, call) + return + } + } + + expr, fn := goutil.ResolveFunc(cl.ctx.Types, call.Fun) + if fn == nil { + panic(cl.errorf(call.Fun, "can't resolve the called function")) + } + + // TODO: just use Func.FullName as a key? + key := funcKey{name: fn.Name()} + sig := fn.Type().(*types.Signature) + if sig.Recv() != nil { + key.qualifier = sig.Recv().Type().String() + } else { + key.qualifier = fn.Pkg().Path() + } + + builtinID, ok := cl.ctx.Env.nameToBuiltinFuncID[key] + if ok { + if expr != nil { + cl.compileExpr(expr) + } + for _, arg := range call.Args { + cl.compileExpr(arg) + } + cl.emit16(opCallBuiltin, int(builtinID)) + return + } + + panic(cl.errorf(call.Fun, "can't compile a call to %s func", key)) +} + +func (cl *compiler) compileUnaryOp(op opcode, e *ast.UnaryExpr) { + cl.compileExpr(e.X) + cl.emit(op) +} + +func (cl *compiler) compileBinaryOp(op opcode, e *ast.BinaryExpr) { + cl.compileExpr(e.X) + cl.compileExpr(e.Y) + cl.emit(op) +} + +func (cl *compiler) compileOr(e *ast.BinaryExpr) { + var labelEnd label + cl.compileExpr(e.X) + cl.emitJump(opJumpTrue, &labelEnd) + cl.emit(opPop) + cl.compileExpr(e.Y) + cl.bindLabel(labelEnd) +} + +func (cl *compiler) compileAnd(e *ast.BinaryExpr) { + var labelEnd label + cl.compileExpr(e.X) + cl.emitJump(opJumpFalse, &labelEnd) + cl.emit(opPop) + cl.compileExpr(e.Y) + cl.bindLabel(labelEnd) +} + +func (cl *compiler) compileIdent(ident *ast.Ident) { + tv := cl.ctx.Types.Types[ident] + cv := tv.Value + if cv != nil { + cl.compileConstantValue(ident, cv) + return + } + if paramIndex, ok := cl.params[ident.String()]; ok { + cl.emit8(opPushParam, paramIndex) + return + } + if localIndex, ok := cl.locals[ident.String()]; ok { + cl.emit8(opPushLocal, localIndex) + return + } + + panic(cl.errorf(ident, "can't compile a %s (type %s) variable read", ident.String(), tv.Type)) +} + +func (cl *compiler) compileConstantValue(source ast.Expr, cv constant.Value) { + switch cv.Kind() { + case constant.Bool: + v := constant.BoolVal(cv) + if v { + cl.emit(opPushTrue) + } else { + cl.emit(opPushFalse) + } + + case constant.String: + v := constant.StringVal(cv) + id := cl.internConstant(v) + cl.emit8(opPushConst, id) + + case constant.Int: + v, exact := constant.Int64Val(cv) + if !exact { + panic(cl.errorf(source, "non-exact int value")) + } + id := cl.internConstant(int(v)) + cl.emit8(opPushConst, id) + + case constant.Complex: + panic(cl.errorf(source, "can't compile complex number constants yet")) + + case constant.Float: + panic(cl.errorf(source, "can't compile float constants yet")) + + default: + panic(cl.errorf(source, "unexpected constant %v", cv)) + } +} + +func (cl *compiler) internConstant(v interface{}) int { + if id, ok := cl.constantsPool[v]; ok { + return id + } + id := len(cl.constants) + cl.constants = append(cl.constants, v) + cl.constantsPool[v] = id + return id +} + +func (cl *compiler) bindLabel(l label) { + if !l.used { + return + } + offset := len(cl.code) - l.jumpPos + patchPos := l.jumpPos + 1 + put16(cl.code, patchPos, offset) +} + +func (cl *compiler) emit(op opcode) { + cl.lastOp = op + cl.code = append(cl.code, byte(op)) +} + +func (cl *compiler) emitJump(op opcode, l *label) { + l.jumpPos = len(cl.code) + l.used = true + cl.emit(op) + cl.code = append(cl.code, 0, 0) +} + +func (cl *compiler) emit8(op opcode, arg8 int) { + cl.emit(op) + cl.code = append(cl.code, byte(arg8)) +} + +func (cl *compiler) emit16(op opcode, arg16 int) { + cl.emit(op) + buf := make([]byte, 2) + put16(buf, 0, arg16) + cl.code = append(cl.code, buf...) +} + +func (cl *compiler) errorUnsupportedType(e ast.Node, typ types.Type, where string) compileError { + return cl.errorf(e, "%s type: %s is not supported, try something simpler", where, typ) +} + +func (cl *compiler) errorf(n ast.Node, format string, args ...interface{}) compileError { + loc := cl.ctx.Fset.Position(n.Pos()) + message := fmt.Sprintf("%s:%d: %s", loc.Filename, loc.Line, fmt.Sprintf(format, args...)) + return compileError(message) +} + +func (cl *compiler) isUncondJump(op opcode) bool { + switch op { + case opJump, opReturnFalse, opReturnTrue, opReturnTop: + return true + default: + return false + } +} + +func (cl *compiler) isSupportedType(typ types.Type) bool { + switch typ := typ.Underlying().(type) { + case *types.Pointer: + // 1. Pointers to structs are supported. + _, isStruct := typ.Elem().Underlying().(*types.Struct) + return isStruct + + case *types.Basic: + // 2. Some of the basic types are supported. + // TODO: support byte/uint8 and maybe float64. + switch typ.Kind() { + case types.Bool, types.Int, types.String: + return true + default: + return false + } + + case *types.Interface: + // 3. Interfaces are supported. + return true + + default: + return false + } +} diff --git a/ruleguard/quasigo/compile_test.go b/ruleguard/quasigo/compile_test.go new file mode 100644 index 00000000..e1ef1892 --- /dev/null +++ b/ruleguard/quasigo/compile_test.go @@ -0,0 +1,266 @@ +package quasigo + +import ( + "fmt" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestCompile(t *testing.T) { + tests := map[string][]string{ + `return 1`: { + ` PushConst 0 # value=1`, + ` ReturnTop`, + }, + + `return false`: { + ` ReturnFalse`, + }, + + `return true`: { + ` ReturnTrue`, + }, + + `return b`: { + ` PushParam 2 # b`, + ` ReturnTop`, + }, + + `return i == 2`: { + ` PushParam 0 # i`, + ` PushConst 0 # value=2`, + ` EqInt`, + ` ReturnTop`, + }, + + `return i == 10 || i == 2`: { + ` PushParam 0 # i`, + ` PushConst 0 # value=10`, + ` EqInt`, + ` JumpTrue 9 # L0`, + ` Pop`, + ` PushParam 0 # i`, + ` PushConst 1 # value=2`, + ` EqInt`, + `L0:`, + ` ReturnTop`, + }, + + `return i == 10 && s == "foo"`: { + ` PushParam 0 # i`, + ` PushConst 0 # value=10`, + ` EqInt`, + ` JumpFalse 9 # L0`, + ` Pop`, + ` PushParam 1 # s`, + ` PushConst 1 # value="foo"`, + ` EqString`, + `L0:`, + ` ReturnTop`, + }, + + `return imul(i, 5) == 10`: { + ` PushParam 0 # i`, + ` PushConst 0 # value=5`, + ` CallBuiltin 0 # testpkg.imul`, + ` PushConst 1 # value=10`, + ` EqInt`, + ` ReturnTop`, + }, + + `x := 10; y := x; return y`: { + ` PushConst 0 # value=10`, + ` SetLocal 0 # x`, + ` PushLocal 0 # x`, + ` SetLocal 1 # y`, + ` PushLocal 1 # y`, + ` ReturnTop`, + }, + + `if b { return 1 }; return 0`: { + ` PushParam 2 # b`, + ` JumpFalse 7 # L0`, + ` Pop`, + ` PushConst 0 # value=1`, + ` ReturnTop`, + `L0:`, + ` PushConst 1 # value=0`, + ` ReturnTop`, + }, + + `if b { return 1 } else { return 0 }`: { + ` PushParam 2 # b`, + ` JumpFalse 7 # L0`, + ` Pop`, + ` PushConst 0 # value=1`, + ` ReturnTop`, + `L0:`, + ` PushConst 1 # value=0`, + ` ReturnTop`, + }, + + `x := 0; if b { x = 5 } else { x = 50 }; return x`: { + ` PushConst 0 # value=0`, + ` SetLocal 0 # x`, + ` PushParam 2 # b`, + ` JumpFalse 11 # L0`, + ` Pop`, + ` PushConst 1 # value=5`, + ` SetLocal 0 # x`, + ` Jump 7 # L1`, + `L0:`, + ` PushConst 2 # value=50`, + ` SetLocal 0 # x`, + `L1:`, + ` PushLocal 0 # x`, + ` ReturnTop`, + }, + + `if i != 2 { return "a" } else if b { return "b" }; return "c"`: { + ` PushParam 0 # i`, + ` PushConst 0 # value=2`, + ` NotEqInt`, + ` JumpFalse 7 # L0`, + ` Pop`, + ` PushConst 1 # value="a"`, + ` ReturnTop`, + `L0:`, + ` PushParam 2 # b`, + ` JumpFalse 7 # L1`, + ` Pop`, + ` PushConst 2 # value="b"`, + ` ReturnTop`, + `L1:`, + ` PushConst 3 # value="c"`, + ` ReturnTop`, + }, + + `return eface == nil`: { + ` PushParam 3 # eface`, + ` IsNil`, + ` ReturnTop`, + }, + + `return nil == eface`: { + ` PushParam 3 # eface`, + ` IsNil`, + ` ReturnTop`, + }, + + `return eface != nil`: { + ` PushParam 3 # eface`, + ` IsNotNil`, + ` ReturnTop`, + }, + + `return nil != eface`: { + ` PushParam 3 # eface`, + ` IsNotNil`, + ` ReturnTop`, + }, + + `return s[:]`: { + ` PushParam 1 # s`, + ` ReturnTop`, + }, + + `return s[1:]`: { + ` PushParam 1 # s`, + ` PushConst 0 # value=1`, + ` StringSliceFrom`, + ` ReturnTop`, + }, + + `return s[:1]`: { + ` PushParam 1 # s`, + ` PushConst 0 # value=1`, + ` StringSliceTo`, + ` ReturnTop`, + }, + + `return s[1:2]`: { + ` PushParam 1 # s`, + ` PushConst 0 # value=1`, + ` PushConst 1 # value=2`, + ` StringSlice`, + ` ReturnTop`, + }, + + `return len(s) >= 0`: { + ` PushParam 1 # s`, + ` StringLen`, + ` PushConst 0 # value=0`, + ` GtEqInt`, + ` ReturnTop`, + }, + + `return i > 0`: { + ` PushParam 0 # i`, + ` PushConst 0 # value=0`, + ` GtInt`, + ` ReturnTop`, + }, + + `return i < 0`: { + ` PushParam 0 # i`, + ` PushConst 0 # value=0`, + ` LtInt`, + ` ReturnTop`, + }, + + `return i <= 0`: { + ` PushParam 0 # i`, + ` PushConst 0 # value=0`, + ` LtEqInt`, + ` ReturnTop`, + }, + } + + makePackageSource := func(body string) string { + return ` + package test + func f(i int, s string, b bool, eface interface{}) interface{} { + ` + body + ` + } + func imul(x, y int) int + func idiv(x, y int) int + ` + } + + env := NewEnv() + env.AddNativeFunc(testPackage, "imul", func(stack *ValueStack) { + x, y := stack.Pop2() + stack.Push(x.(int) * y.(int)) + }) + env.AddNativeFunc(testPackage, "idiv", func(stack *ValueStack) { + x, y := stack.Pop2() + stack.Push(x.(int) / y.(int)) + }) + + for testSrc, disasmLines := range tests { + src := makePackageSource(testSrc) + parsed, err := parseGoFile(src) + if err != nil { + t.Errorf("parse %s: %v", testSrc, err) + continue + } + compiled, err := compileTestFunc(env, "f", parsed) + if err != nil { + t.Errorf("compile %s: %v", testSrc, err) + continue + } + want := disasmLines + have := strings.Split(Disasm(env, compiled), "\n") + have = have[:len(have)-1] // Drop an empty line + if diff := cmp.Diff(have, want); diff != "" { + t.Errorf("compile %s (-have +want):\n%s", testSrc, diff) + fmt.Println("For copy/paste:") + for _, l := range have { + fmt.Printf(" `%s`,\n", l) + } + continue + } + } +} diff --git a/ruleguard/quasigo/debug_info.go b/ruleguard/quasigo/debug_info.go new file mode 100644 index 00000000..e42bbb76 --- /dev/null +++ b/ruleguard/quasigo/debug_info.go @@ -0,0 +1,16 @@ +package quasigo + +type debugInfo struct { + funcs map[*Func]funcDebugInfo +} + +type funcDebugInfo struct { + paramNames []string + localNames []string +} + +func newDebugInfo() *debugInfo { + return &debugInfo{ + funcs: make(map[*Func]funcDebugInfo), + } +} diff --git a/ruleguard/quasigo/disasm.go b/ruleguard/quasigo/disasm.go new file mode 100644 index 00000000..df033cfb --- /dev/null +++ b/ruleguard/quasigo/disasm.go @@ -0,0 +1,71 @@ +package quasigo + +import ( + "fmt" + "strings" +) + +// TODO(quasilyte): generate extra opcode info so we can simplify disasm function? + +func disasm(env *Env, fn *Func) string { + var out strings.Builder + + dbg, ok := env.debug.funcs[fn] + if !ok { + return "\n" + } + + code := fn.code + labels := map[int]string{} + walkBytecode(code, func(pc int, op opcode) { + switch op { + case opJumpTrue, opJumpFalse, opJump: + offset := decode16(code, pc+1) + targetPC := pc + offset + if _, ok := labels[targetPC]; !ok { + labels[targetPC] = fmt.Sprintf("L%d", len(labels)) + } + } + }) + + walkBytecode(code, func(pc int, op opcode) { + if l := labels[pc]; l != "" { + fmt.Fprintf(&out, "%s:\n", l) + } + var arg interface{} + var comment string + switch op { + case opCallBuiltin: + id := decode16(code, pc+1) + arg = id + comment = env.nativeFuncs[id].name + case opPushParam: + index := int(code[pc+1]) + arg = index + comment = dbg.paramNames[index] + case opSetLocal, opPushLocal: + index := int(code[pc+1]) + arg = index + comment = dbg.localNames[index] + case opPushConst: + arg = int(code[pc+1]) + comment = fmt.Sprintf("value=%#v", fn.constants[code[pc+1]]) + case opJumpTrue, opJumpFalse, opJump: + offset := decode16(code, pc+1) + targetPC := pc + offset + arg = offset + comment = labels[targetPC] + } + + if comment != "" { + comment = " # " + comment + } + if arg == nil { + fmt.Fprintf(&out, " %s%s\n", op, comment) + } else { + fmt.Fprintf(&out, " %s %#v%s\n", op, arg, comment) + } + }) + + return out.String() +} diff --git a/ruleguard/quasigo/env.go b/ruleguard/quasigo/env.go new file mode 100644 index 00000000..b3beebe8 --- /dev/null +++ b/ruleguard/quasigo/env.go @@ -0,0 +1,42 @@ +package quasigo + +type funcKey struct { + qualifier string + name string +} + +func (k funcKey) String() string { + if k.qualifier != "" { + return k.qualifier + "." + k.name + } + return k.name +} + +type nativeFunc struct { + mappedFunc func(*ValueStack) + name string // Needed for the readable disasm +} + +func newEnv() *Env { + return &Env{ + nameToBuiltinFuncID: make(map[funcKey]uint16), + nameToFuncID: make(map[funcKey]uint16), + + debug: newDebugInfo(), + } +} + +func (env *Env) addNativeFunc(key funcKey, f func(*ValueStack)) { + id := len(env.nativeFuncs) + env.nativeFuncs = append(env.nativeFuncs, nativeFunc{ + mappedFunc: f, + name: key.String(), + }) + env.nameToBuiltinFuncID[key] = uint16(id) +} + +func (env *Env) addFunc(key funcKey, f *Func) { + id := len(env.userFuncs) + env.userFuncs = append(env.userFuncs, f) + env.nameToFuncID[key] = uint16(id) +} diff --git a/ruleguard/quasigo/eval.go b/ruleguard/quasigo/eval.go new file mode 100644 index 00000000..aed76d7a --- /dev/null +++ b/ruleguard/quasigo/eval.go @@ -0,0 +1,180 @@ +package quasigo + +import ( + "fmt" + "reflect" +) + +const maxFuncLocals = 8 + +func eval(env *EvalEnv, fn *Func, args []interface{}) interface{} { + pc := 0 + code := fn.code + stack := env.stack + var locals [maxFuncLocals]interface{} + + for { + switch op := opcode(code[pc]); op { + case opPushParam: + index := code[pc+1] + stack.Push(args[index]) + pc += 2 + + case opPushLocal: + index := code[pc+1] + stack.Push(locals[index]) + pc += 2 + + case opSetLocal: + index := code[pc+1] + locals[index] = stack.Pop() + pc += 2 + + case opPop: + stack.Discard() + pc++ + case opDup: + stack.Dup() + pc++ + + case opPushConst: + id := code[pc+1] + stack.Push(fn.constants[id]) + pc += 2 + + case opPushTrue: + stack.Push(true) + pc++ + case opPushFalse: + stack.Push(false) + pc++ + + case opReturnTrue: + return true + case opReturnFalse: + return false + case opReturnTop: + return stack.Top() + + case opCallBuiltin: + id := decode16(code, pc+1) + fn := env.nativeFuncs[id].mappedFunc + fn(&stack) + pc += 3 + + case opJump: + offset := decode16(code, pc+1) + pc += offset + + case opJumpFalse: + if !stack.Top().(bool) { + offset := decode16(code, pc+1) + pc += offset + } else { + pc += 3 + } + case opJumpTrue: + if stack.Top().(bool) { + offset := decode16(code, pc+1) + pc += offset + } else { + pc += 3 + } + + case opNot: + stack.Push(!stack.Pop().(bool)) + pc++ + + case opConcat: + x, y := stack.Pop2() + stack.Push(x.(string) + y.(string)) + pc++ + + case opAdd: + x, y := stack.Pop2() + stack.Push(x.(int) + y.(int)) + pc++ + + case opSub: + x, y := stack.Pop2() + stack.Push(x.(int) - y.(int)) + pc++ + + case opEqInt: + x, y := stack.Pop2() + stack.Push(x.(int) == y.(int)) + pc++ + + case opNotEqInt: + x, y := stack.Pop2() + stack.Push(x.(int) != y.(int)) + pc++ + + case opGtInt: + x, y := stack.Pop2() + stack.Push(x.(int) > y.(int)) + pc++ + + case opGtEqInt: + x, y := stack.Pop2() + stack.Push(x.(int) >= y.(int)) + pc++ + + case opLtInt: + x, y := stack.Pop2() + stack.Push(x.(int) < y.(int)) + pc++ + + case opLtEqInt: + x, y := stack.Pop2() + stack.Push(x.(int) <= y.(int)) + pc++ + + case opEqString: + x, y := stack.Pop2() + stack.Push(x.(string) == y.(string)) + pc++ + + case opNotEqString: + x, y := stack.Pop2() + stack.Push(x.(string) != y.(string)) + pc++ + + case opIsNil: + x := stack.Pop() + stack.Push(x == nil || reflect.ValueOf(x).IsNil()) + pc++ + + case opIsNotNil: + x := stack.Pop() + stack.Push(x != nil && !reflect.ValueOf(x).IsNil()) + pc++ + + case opStringSlice: + to := stack.Pop().(int) + from := stack.Pop().(int) + s := stack.Pop().(string) + stack.Push(s[from:to]) + pc++ + + case opStringSliceFrom: + from := stack.Pop().(int) + s := stack.Pop().(string) + stack.Push(s[from:]) + pc++ + + case opStringSliceTo: + to := stack.Pop().(int) + s := stack.Pop().(string) + stack.Push(s[:to]) + pc++ + + case opStringLen: + stack.Push(len(stack.Pop().(string))) + pc++ + + default: + panic(fmt.Sprintf("malformed bytecode: unexpected %s found", op)) + } + } +} diff --git a/ruleguard/quasigo/eval_bench_test.go b/ruleguard/quasigo/eval_bench_test.go new file mode 100644 index 00000000..4bbe3fc1 --- /dev/null +++ b/ruleguard/quasigo/eval_bench_test.go @@ -0,0 +1,73 @@ +package quasigo + +import ( + "testing" +) + +func BenchmarkEval(b *testing.B) { + type testCase struct { + name string + src string + } + tests := []*testCase{ + { + `ReturnFalse`, + `return false`, + }, + + { + `LocalVars`, + `x := 1; y := x; return y`, + }, + + { + `IfStmt`, + `x := 100; if x == 1 { x = 10 } else if x == 2 { x = 20 } else { x = 30 }; return x`, + }, + + { + `CallNative`, + `return imul(1, 5) + imul(2, 2)`, + }, + } + + runBench := func(b *testing.B, env *EvalEnv, fn *Func) { + for i := 0; i < b.N; i++ { + _ = Call(env, fn) + } + } + + makePackageSource := func(body string) string { + return ` + package test + func f() interface{} { + ` + body + ` + } + func imul(x, y int) int + ` + } + + for _, test := range tests { + test := test + b.Run(test.name, func(b *testing.B) { + env := NewEnv() + env.AddNativeFunc(testPackage, "imul", func(stack *ValueStack) { + x, y := stack.Pop2() + stack.Push(x.(int) * y.(int)) + }) + src := makePackageSource(test.src) + parsed, err := parseGoFile(src) + if err != nil { + b.Fatalf("parse %s: %v", test.src, err) + } + compiled, err := compileTestFunc(env, "f", parsed) + if err != nil { + b.Fatalf("compile %s: %v", test.src, err) + } + + b.ResetTimer() + runBench(b, env.GetEvalEnv(), compiled) + }) + } + +} diff --git a/ruleguard/quasigo/eval_test.go b/ruleguard/quasigo/eval_test.go new file mode 100644 index 00000000..da523e08 --- /dev/null +++ b/ruleguard/quasigo/eval_test.go @@ -0,0 +1,199 @@ +package quasigo + +import ( + "fmt" + "testing" + + "github.com/quasilyte/go-ruleguard/ruleguard/quasigo/internal/evaltest" +) + +func TestEval(t *testing.T) { + type testCase struct { + src string + result interface{} + } + + exprTests := []testCase{ + // Const literals. + {`1`, 1}, + {`"foo"`, "foo"}, + {`true`, true}, + {`false`, false}, + + // Function args. + {`b`, true}, + {`i`, 10}, + + // Arith operators. + {`5 + 5`, 10}, + {`i + i`, 20}, + {`i - 5`, 5}, + {`5 - i`, -5}, + + // String operators. + {`s + s`, "foofoo"}, + + // Bool operators. + {`!b`, false}, + {`!!b`, true}, + {`i == 2`, false}, + {`i == 10`, true}, + {`i >= 10`, true}, + {`i >= 9`, true}, + {`i >= 11`, false}, + {`i > 10`, false}, + {`i > 9`, true}, + {`i > -1`, true}, + {`i < 10`, false}, + {`i < 11`, true}, + {`i <= 10`, true}, + {`i <= 11`, true}, + {`i != 2`, true}, + {`i != 10`, false}, + {`s != "foo"`, false}, + {`s != "bar"`, true}, + + // || operator. + {`i == 2 || i == 10`, true}, + {`i == 10 || i == 2`, true}, + {`i == 2 || i == 3 || i == 10`, true}, + {`i == 2 || i == 10 || i == 3`, true}, + {`i == 10 || i == 2 || i == 3`, true}, + {`!(i == 10 || i == 2 || i == 3)`, false}, + + // && operator. + {`i == 10 && s == "foo"`, true}, + {`i == 10 && s == "foo" && true`, true}, + {`i == 20 && s == "foo"`, false}, + {`i == 10 && s == "bar"`, false}, + {`i == 10 && s == "foo" && false`, false}, + + // Builtin func call. + {`imul(2, 3)`, 6}, + {`idiv(9, 3)`, 3}, + {`idiv(imul(2, 3), 1 + 1)`, 3}, + + // Method call. + {`foo.Method1(40)`, "Hello40"}, + {`newFoo("x").Method1(11)`, "x11"}, + + // Accesing the fields. + {`foo.Prefix`, "Hello"}, + + // Nil checks. + {`nilfoo == nil`, true}, + {`nileface == nil`, true}, + {`nil == nilfoo`, true}, + {`nil == nileface`, true}, + {`nilfoo != nil`, false}, + {`nileface != nil`, false}, + {`nil != nilfoo`, false}, + {`nil != nileface`, false}, + {`foo == nil`, false}, + {`foo != nil`, true}, + + // String slicing. + {`s[:]`, "foo"}, + {`s[0:]`, "foo"}, + {`s[1:]`, "oo"}, + {`s[:1]`, "f"}, + {`s[:0]`, ""}, + {`s[1:2]`, "o"}, + {`s[1:3]`, "oo"}, + + // Builtin len(). + {`len(s)`, 3}, + {`len(s) == 3`, true}, + {`len(s[1:])`, 2}, + } + + tests := []testCase{ + {`if b { return 1 }; return 0`, 1}, + {`if !b { return 1 }; return 0`, 0}, + {`if b { return 1 } else { return 0 }`, 1}, + {`if !b { return 1 } else { return 0 }`, 0}, + {`x := 2; if x == 2 { return "a" } else if x == 0 { return "b" }; return "c"`, "a"}, + {`x := 2; if x == 0 { return "a" } else if x == 2 { return "b" }; return "c"`, "b"}, + {`x := 2; if x == 0 { return "a" } else if x == 1 { return "b" }; return "c"`, "c"}, + {`x := 2; if x == 2 { return "a" } else if x == 0 { return "b" } else { return "c" }`, "a"}, + {`x := 2; if x == 0 { return "a" } else if x == 2 { return "b" } else { return "c" }`, "b"}, + {`x := 2; if x == 0 { return "a" } else if x == 1 { return "b" } else { return "c" }`, "c"}, + {`x := 0; if b { x = 5 } else { x = 50 }; return x`, 5}, + {`x := 0; if !b { x = 5 } else { x = 50 }; return x`, 50}, + {`x := 0; if b { x = 1 } else if x == 0 { x = 2 } else { x = 3 }; return x`, 1}, + {`x := 0; if !b { x = 1 } else if x == 0 { x = 2 } else { x = 3 }; return x`, 2}, + {`x := 0; if !b { x = 1 } else if x == 1 { x = 2 } else { x = 3 }; return x`, 3}, + } + + for _, test := range exprTests { + test.src = `return ` + test.src + tests = append(tests, test) + } + + makePackageSource := func(body string, result interface{}) string { + var returnType string + switch result.(type) { + case int: + returnType = "int" + case string: + returnType = "string" + case bool: + returnType = "bool" + } + return ` + package test + import "github.com/quasilyte/go-ruleguard/ruleguard/quasigo/internal/evaltest" + func target(i int, s string, b bool, foo, nilfoo *evaltest.Foo, nileface interface{}) ` + returnType + ` { + ` + body + ` + } + func imul(x, y int) int + func idiv(x, y int) int + func newFoo(prefix string) * evaltest.Foo + ` + } + + env := NewEnv() + env.AddNativeFunc(testPackage, "imul", func(stack *ValueStack) { + x, y := stack.Pop2() + stack.Push(x.(int) * y.(int)) + }) + env.AddNativeFunc(testPackage, "idiv", func(stack *ValueStack) { + x, y := stack.Pop2() + stack.Push(x.(int) / y.(int)) + }) + env.AddNativeFunc(testPackage, "newFoo", func(stack *ValueStack) { + prefix := stack.Pop().(string) + stack.Push(&evaltest.Foo{Prefix: prefix}) + }) + + const evaltestPkgPath = `github.com/quasilyte/go-ruleguard/ruleguard/quasigo/internal/evaltest` + const evaltestFoo = `*` + evaltestPkgPath + `.Foo` + env.AddNativeMethod(evaltestFoo, "Method1", func(stack *ValueStack) { + obj, x := stack.Pop2() + foo := obj.(*evaltest.Foo) + stack.Push(foo.Prefix + fmt.Sprint(x.(int))) + }) + env.AddNativeMethod(evaltestFoo, "Prefix", func(stack *ValueStack) { + foo := stack.Pop().(*evaltest.Foo) + stack.Push(foo.Prefix) + }) + + for _, test := range tests { + src := makePackageSource(test.src, test.result) + parsed, err := parseGoFile(src) + if err != nil { + t.Errorf("parse %s: %v", test.src, err) + continue + } + compiled, err := compileTestFunc(env, "target", parsed) + if err != nil { + t.Errorf("compile %s: %v", test.src, err) + continue + } + result := Call(env.GetEvalEnv(), compiled, + 10, "foo", true, &evaltest.Foo{Prefix: "Hello"}, (*evaltest.Foo)(nil), nil) + if result != test.result { + t.Errorf("eval %s:\nhave: %#v\nwant: %#v", test.src, result, test.result) + } + } +} diff --git a/ruleguard/quasigo/gen_opcodes.go b/ruleguard/quasigo/gen_opcodes.go new file mode 100644 index 00000000..01d517c8 --- /dev/null +++ b/ruleguard/quasigo/gen_opcodes.go @@ -0,0 +1,177 @@ +// +build main + +package main + +import ( + "bytes" + "fmt" + "go/format" + "io/ioutil" + "log" + "strings" + "text/template" +) + +var opcodePrototypes = []opcodeProto{ + {"Pop", "op", "(value) -> ()"}, + {"Dup", "op", "(x) -> (x x)"}, + + {"PushParam", "op index:u8", "() -> (value)"}, + {"PushLocal", "op index:u8", "() -> (value)"}, + {"PushFalse", "op", "() -> (false)"}, + {"PushTrue", "op", "() -> (true)"}, + {"PushConst", "op constid:u8", "() -> (const)"}, + + {"SetLocal", "op index:u8", "(value) -> ()"}, + + {"ReturnTop", "op", "(value) -> (value)"}, + {"ReturnFalse", "op", stackUnchanged}, + {"ReturnTrue", "op", stackUnchanged}, + + {"Jump", "op offset:i16", stackUnchanged}, + {"JumpFalse", "op offset:i16", "(cond:bool) -> (cond:bool)"}, + {"JumpTrue", "op offset:i16", "(cond:bool) -> (cond:bool)"}, + + {"CallBuiltin", "op funcid:u16", "(args...) -> (results...)"}, + + {"IsNil", "op", "(value) -> (result:bool)"}, + {"IsNotNil", "op", "(value) -> (result:bool)"}, + + {"Not", "op", "(value:bool) -> (result:bool)"}, + + {"EqInt", "op", "(x:int y:int) -> (result:bool)"}, + {"NotEqInt", "op", "(x:int y:int) -> (result:bool)"}, + {"GtInt", "op", "(x:int y:int) -> (result:bool)"}, + {"GtEqInt", "op", "(x:int y:int) -> (result:bool)"}, + {"LtInt", "op", "(x:int y:int) -> (result:bool)"}, + {"LtEqInt", "op", "(x:int y:int) -> (result:bool)"}, + + {"EqString", "op", "(x:string y:string) -> (result:bool)"}, + {"NotEqString", "op", "(x:string y:string) -> (result:bool)"}, + + {"Concat", "op", "(x:string y:string) -> (result:string)"}, + {"Add", "op", "(x:int y:int) -> (result:int)"}, + {"Sub", "op", "(x:int y:int) -> (result:int)"}, + + {"StringSlice", "op", "(s:string from:int to:int) -> (result:string)"}, + {"StringSliceFrom", "op", "(s:string from:int) -> (result:string)"}, + {"StringSliceTo", "op", "(s:string to:int) -> (result:string)"}, + {"StringLen", "op", "(s:string) -> (result:int)"}, +} + +type opcodeProto struct { + name string + enc string + stack string +} + +type encodingInfo struct { + width int + parts int +} + +type opcodeInfo struct { + Opcode byte + Name string + Enc string + EncString string + Stack string + Width int +} + +const stackUnchanged = "" + +var fileTemplate = template.Must(template.New("opcodes.go").Parse(`// Code generated "gen_opcodes.go"; DO NOT EDIT. + +package quasigo + +//go:generate stringer -type=opcode -trimprefix=op +type opcode byte + +const ( + opInvalid opcode = 0 +{{ range .Opcodes }} + // Encoding: {{.EncString}} + // Stack effect: {{ if .Stack}}{{.Stack}}{{else}}unchanged{{end}} + op{{ .Name }} opcode = {{.Opcode}} +{{ end -}} +) + +type opcodeInfo struct { + width int +} + +var opcodeInfoTable = [256]opcodeInfo{ + opInvalid: {width: 1}, + +{{ range .Opcodes -}} + op{{.Name}}: {width: {{.Width}}}, +{{ end }} +} +`)) + +func main() { + opcodes := make([]opcodeInfo, len(opcodePrototypes)) + for i, proto := range opcodePrototypes { + opcode := byte(i + 1) + encInfo := decodeEnc(proto.enc) + var encString string + if encInfo.parts == 1 { + encString = fmt.Sprintf("0x%02x (width=%d)", opcode, encInfo.width) + } else { + encString = fmt.Sprintf("0x%02x %s (width=%d)", + opcode, strings.TrimPrefix(proto.enc, "op "), encInfo.width) + } + + opcodes[i] = opcodeInfo{ + Opcode: opcode, + Name: proto.name, + Enc: proto.enc, + EncString: encString, + Stack: proto.stack, + Width: encInfo.width, + } + } + + var buf bytes.Buffer + err := fileTemplate.Execute(&buf, map[string]interface{}{ + "Opcodes": opcodes, + }) + if err != nil { + log.Panicf("execute template: %v", err) + } + writeFile("opcodes.gen.go", buf.Bytes()) +} + +func decodeEnc(enc string) encodingInfo { + fields := strings.Fields(enc) + width := 0 + for _, f := range fields { + parts := strings.Split(f, ":") + var typ string + if len(parts) == 2 { + typ = parts[1] + } else { + typ = "u8" + } + switch typ { + case "i8", "u8": + width++ + case "i16", "u16": + width += 2 + default: + panic(fmt.Sprintf("unknown op argument type: %s", typ)) + } + } + return encodingInfo{width: width, parts: len(fields)} +} + +func writeFile(filename string, data []byte) { + pretty, err := format.Source(data) + if err != nil { + log.Panicf("gofmt: %v", err) + } + if err := ioutil.WriteFile(filename, pretty, 0666); err != nil { + log.Panicf("write %s: %v", filename, err) + } +} diff --git a/ruleguard/quasigo/internal/evaltest/evaltest.go b/ruleguard/quasigo/internal/evaltest/evaltest.go new file mode 100644 index 00000000..b94010a7 --- /dev/null +++ b/ruleguard/quasigo/internal/evaltest/evaltest.go @@ -0,0 +1,9 @@ +package evaltest + +// This package is used for quasigo testing. + +type Foo struct { + Prefix string +} + +func (*Foo) Method1(x int) string { return "" } diff --git a/ruleguard/quasigo/opcode_string.go b/ruleguard/quasigo/opcode_string.go new file mode 100644 index 00000000..191e210f --- /dev/null +++ b/ruleguard/quasigo/opcode_string.go @@ -0,0 +1,56 @@ +// Code generated by "stringer -type=opcode -trimprefix=op"; DO NOT EDIT. + +package quasigo + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[opInvalid-0] + _ = x[opPop-1] + _ = x[opDup-2] + _ = x[opPushParam-3] + _ = x[opPushLocal-4] + _ = x[opPushFalse-5] + _ = x[opPushTrue-6] + _ = x[opPushConst-7] + _ = x[opSetLocal-8] + _ = x[opReturnTop-9] + _ = x[opReturnFalse-10] + _ = x[opReturnTrue-11] + _ = x[opJump-12] + _ = x[opJumpFalse-13] + _ = x[opJumpTrue-14] + _ = x[opCallBuiltin-15] + _ = x[opIsNil-16] + _ = x[opIsNotNil-17] + _ = x[opNot-18] + _ = x[opEqInt-19] + _ = x[opNotEqInt-20] + _ = x[opGtInt-21] + _ = x[opGtEqInt-22] + _ = x[opLtInt-23] + _ = x[opLtEqInt-24] + _ = x[opEqString-25] + _ = x[opNotEqString-26] + _ = x[opConcat-27] + _ = x[opAdd-28] + _ = x[opSub-29] + _ = x[opStringSlice-30] + _ = x[opStringSliceFrom-31] + _ = x[opStringSliceTo-32] + _ = x[opStringLen-33] +} + +const _opcode_name = "InvalidPopDupPushParamPushLocalPushFalsePushTruePushConstSetLocalReturnTopReturnFalseReturnTrueJumpJumpFalseJumpTrueCallBuiltinIsNilIsNotNilNotEqIntNotEqIntGtIntGtEqIntLtIntLtEqIntEqStringNotEqStringConcatAddSubStringSliceStringSliceFromStringSliceToStringLen" + +var _opcode_index = [...]uint16{0, 7, 10, 13, 22, 31, 40, 48, 57, 65, 74, 85, 95, 99, 108, 116, 127, 132, 140, 143, 148, 156, 161, 168, 173, 180, 188, 199, 205, 208, 211, 222, 237, 250, 259} + +func (i opcode) String() string { + if i >= opcode(len(_opcode_index)-1) { + return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _opcode_name[_opcode_index[i]:_opcode_index[i+1]] +} diff --git a/ruleguard/quasigo/opcodes.gen.go b/ruleguard/quasigo/opcodes.gen.go new file mode 100644 index 00000000..99b7bc39 --- /dev/null +++ b/ruleguard/quasigo/opcodes.gen.go @@ -0,0 +1,184 @@ +// Code generated "gen_opcodes.go"; DO NOT EDIT. + +package quasigo + +//go:generate stringer -type=opcode -trimprefix=op +type opcode byte + +const ( + opInvalid opcode = 0 + + // Encoding: 0x01 (width=1) + // Stack effect: (value) -> () + opPop opcode = 1 + + // Encoding: 0x02 (width=1) + // Stack effect: (x) -> (x x) + opDup opcode = 2 + + // Encoding: 0x03 index:u8 (width=2) + // Stack effect: () -> (value) + opPushParam opcode = 3 + + // Encoding: 0x04 index:u8 (width=2) + // Stack effect: () -> (value) + opPushLocal opcode = 4 + + // Encoding: 0x05 (width=1) + // Stack effect: () -> (false) + opPushFalse opcode = 5 + + // Encoding: 0x06 (width=1) + // Stack effect: () -> (true) + opPushTrue opcode = 6 + + // Encoding: 0x07 constid:u8 (width=2) + // Stack effect: () -> (const) + opPushConst opcode = 7 + + // Encoding: 0x08 index:u8 (width=2) + // Stack effect: (value) -> () + opSetLocal opcode = 8 + + // Encoding: 0x09 (width=1) + // Stack effect: (value) -> (value) + opReturnTop opcode = 9 + + // Encoding: 0x0a (width=1) + // Stack effect: unchanged + opReturnFalse opcode = 10 + + // Encoding: 0x0b (width=1) + // Stack effect: unchanged + opReturnTrue opcode = 11 + + // Encoding: 0x0c offset:i16 (width=3) + // Stack effect: unchanged + opJump opcode = 12 + + // Encoding: 0x0d offset:i16 (width=3) + // Stack effect: (cond:bool) -> (cond:bool) + opJumpFalse opcode = 13 + + // Encoding: 0x0e offset:i16 (width=3) + // Stack effect: (cond:bool) -> (cond:bool) + opJumpTrue opcode = 14 + + // Encoding: 0x0f funcid:u16 (width=3) + // Stack effect: (args...) -> (results...) + opCallBuiltin opcode = 15 + + // Encoding: 0x10 (width=1) + // Stack effect: (value) -> (result:bool) + opIsNil opcode = 16 + + // Encoding: 0x11 (width=1) + // Stack effect: (value) -> (result:bool) + opIsNotNil opcode = 17 + + // Encoding: 0x12 (width=1) + // Stack effect: (value:bool) -> (result:bool) + opNot opcode = 18 + + // Encoding: 0x13 (width=1) + // Stack effect: (x:int y:int) -> (result:bool) + opEqInt opcode = 19 + + // Encoding: 0x14 (width=1) + // Stack effect: (x:int y:int) -> (result:bool) + opNotEqInt opcode = 20 + + // Encoding: 0x15 (width=1) + // Stack effect: (x:int y:int) -> (result:bool) + opGtInt opcode = 21 + + // Encoding: 0x16 (width=1) + // Stack effect: (x:int y:int) -> (result:bool) + opGtEqInt opcode = 22 + + // Encoding: 0x17 (width=1) + // Stack effect: (x:int y:int) -> (result:bool) + opLtInt opcode = 23 + + // Encoding: 0x18 (width=1) + // Stack effect: (x:int y:int) -> (result:bool) + opLtEqInt opcode = 24 + + // Encoding: 0x19 (width=1) + // Stack effect: (x:string y:string) -> (result:bool) + opEqString opcode = 25 + + // Encoding: 0x1a (width=1) + // Stack effect: (x:string y:string) -> (result:bool) + opNotEqString opcode = 26 + + // Encoding: 0x1b (width=1) + // Stack effect: (x:string y:string) -> (result:string) + opConcat opcode = 27 + + // Encoding: 0x1c (width=1) + // Stack effect: (x:int y:int) -> (result:int) + opAdd opcode = 28 + + // Encoding: 0x1d (width=1) + // Stack effect: (x:int y:int) -> (result:int) + opSub opcode = 29 + + // Encoding: 0x1e (width=1) + // Stack effect: (s:string from:int to:int) -> (result:string) + opStringSlice opcode = 30 + + // Encoding: 0x1f (width=1) + // Stack effect: (s:string from:int) -> (result:string) + opStringSliceFrom opcode = 31 + + // Encoding: 0x20 (width=1) + // Stack effect: (s:string to:int) -> (result:string) + opStringSliceTo opcode = 32 + + // Encoding: 0x21 (width=1) + // Stack effect: (s:string) -> (result:int) + opStringLen opcode = 33 +) + +type opcodeInfo struct { + width int +} + +var opcodeInfoTable = [256]opcodeInfo{ + opInvalid: {width: 1}, + + opPop: {width: 1}, + opDup: {width: 1}, + opPushParam: {width: 2}, + opPushLocal: {width: 2}, + opPushFalse: {width: 1}, + opPushTrue: {width: 1}, + opPushConst: {width: 2}, + opSetLocal: {width: 2}, + opReturnTop: {width: 1}, + opReturnFalse: {width: 1}, + opReturnTrue: {width: 1}, + opJump: {width: 3}, + opJumpFalse: {width: 3}, + opJumpTrue: {width: 3}, + opCallBuiltin: {width: 3}, + opIsNil: {width: 1}, + opIsNotNil: {width: 1}, + opNot: {width: 1}, + opEqInt: {width: 1}, + opNotEqInt: {width: 1}, + opGtInt: {width: 1}, + opGtEqInt: {width: 1}, + opLtInt: {width: 1}, + opLtEqInt: {width: 1}, + opEqString: {width: 1}, + opNotEqString: {width: 1}, + opConcat: {width: 1}, + opAdd: {width: 1}, + opSub: {width: 1}, + opStringSlice: {width: 1}, + opStringSliceFrom: {width: 1}, + opStringSliceTo: {width: 1}, + opStringLen: {width: 1}, +} diff --git a/ruleguard/quasigo/quasigo.go b/ruleguard/quasigo/quasigo.go new file mode 100644 index 00000000..f3ed0690 --- /dev/null +++ b/ruleguard/quasigo/quasigo.go @@ -0,0 +1,148 @@ +// Package quasigo implements a Go subset compiler and interpreter. +// +// The implementation details are not part of the contract of this package. +package quasigo + +import ( + "go/ast" + "go/token" + "go/types" +) + +// TODO(quasilyte): document what is thread-safe and what not. +// TODO(quasilyte): add a readme. + +// Env is used to hold both compilation and evaluation data. +type Env struct { + // TODO(quasilyte): store both builtin and user func ids in one map? + + nativeFuncs []nativeFunc + nameToBuiltinFuncID map[funcKey]uint16 + + userFuncs []*Func + nameToFuncID map[funcKey]uint16 + + // debug contains all information that is only needed + // for better debugging and compiled code introspection. + // Right now it's always enabled, but we may allow stripping it later. + debug *debugInfo +} + +// EvalEnv is a goroutine-local handle for Env. +// To get one, use Env.GetEvalEnv() method. +type EvalEnv struct { + nativeFuncs []nativeFunc + userFuncs []*Func + + stack ValueStack +} + +// NewEnv creates a new empty environment. +func NewEnv() *Env { + return newEnv() +} + +// GetEvalEnv creates a new goroutine-local handle of env. +func (env *Env) GetEvalEnv() *EvalEnv { + return &EvalEnv{ + nativeFuncs: env.nativeFuncs, + userFuncs: env.userFuncs, + stack: make([]interface{}, 0, 32), + } +} + +// AddNativeMethod binds `$typeName.$methodName` symbol with f. +// A typeName should be fully qualified, like `github.com/user/pkgname.TypeName`. +// It method is defined only on pointer type, the typeName should start with `*`. +func (env *Env) AddNativeMethod(typeName, methodName string, f func(*ValueStack)) { + env.addNativeFunc(funcKey{qualifier: typeName, name: methodName}, f) +} + +// AddNativeFunc binds `$pkgPath.$funcName` symbol with f. +// A pkgPath should be a full package path in which funcName is defined. +func (env *Env) AddNativeFunc(pkgPath, funcName string, f func(*ValueStack)) { + env.addNativeFunc(funcKey{qualifier: pkgPath, name: funcName}, f) +} + +// AddFunc binds `$pkgPath.$funcName` symbol with f. +func (env *Env) AddFunc(pkgPath, funcName string, f *Func) { + env.addFunc(funcKey{qualifier: pkgPath, name: funcName}, f) +} + +// GetFunc finds previously bound function searching for the `$pkgPath.$funcName` symbol. +func (env *Env) GetFunc(pkgPath, funcName string) *Func { + id := env.nameToFuncID[funcKey{qualifier: pkgPath, name: funcName}] + return env.userFuncs[id] +} + +// CompileContext is used to provide necessary data to the compiler. +type CompileContext struct { + // Env is shared environment that should be used for all functions + // being compiled; then it should be used to execute these functions. + Env *Env + + Types *types.Info + Fset *token.FileSet +} + +// Compile prepares an executable version of fn. +func Compile(ctx *CompileContext, fn *ast.FuncDecl) (compiled *Func, err error) { + return compile(ctx, fn) +} + +// Call invokes a given function with provided arguments. +func Call(env *EvalEnv, fn *Func, args ...interface{}) interface{} { + env.stack = env.stack[:0] + return eval(env, fn, args) +} + +// Disasm returns the compiled function disassembly text. +// This output is not guaranteed to be stable between versions +// and should be used only for debugging purposes. +func Disasm(env *Env, fn *Func) string { + return disasm(env, fn) +} + +// Func is a compiled function that is ready to be executed. +type Func struct { + code []byte + + constants []interface{} +} + +// ValueStack is used to manipulate runtime values during the evaluation. +// Function arguments are pushed to the stack. +// Function results are returned via stack as well. +type ValueStack []interface{} + +// Pop removes the top stack element and returns it. +func (s *ValueStack) Pop() interface{} { + x := (*s)[len(*s)-1] + *s = (*s)[:len(*s)-1] + return x +} + +// Pop2 removes the two top stack elements and returns them. +// +// Note that it returns the popped elements in the reverse order +// to make it easier to map the order in which they were pushed. +func (s *ValueStack) Pop2() (second interface{}, top interface{}) { + x := (*s)[len(*s)-2] + y := (*s)[len(*s)-1] + *s = (*s)[:len(*s)-2] + return x, y +} + +// Push adds x to the stack. +func (s *ValueStack) Push(x interface{}) { *s = append(*s, x) } + +// Top returns top of the stack without popping it. +func (s *ValueStack) Top() interface{} { return (*s)[len(*s)-1] } + +// Dup copies the top stack element. +// Identical to s.Push(s.Top()), but more concise. +func (s *ValueStack) Dup() { *s = append(*s, (*s)[len(*s)-1]) } + +// Discard drops the top stack element. +// Identical to s.Pop() without using the result. +func (s *ValueStack) Discard() { *s = (*s)[:len(*s)-1] } diff --git a/ruleguard/quasigo/quasigo_test.go b/ruleguard/quasigo/quasigo_test.go new file mode 100644 index 00000000..12f679fe --- /dev/null +++ b/ruleguard/quasigo/quasigo_test.go @@ -0,0 +1,65 @@ +package quasigo + +import ( + "fmt" + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" +) + +const testPackage = "testpkg" + +type parsedTestFile struct { + ast *ast.File + types *types.Info + fset *token.FileSet +} + +func parseGoFile(src string) (*parsedTestFile, error) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "test.go", src, 0) + if err != nil { + return nil, err + } + typechecker := &types.Config{ + Importer: importer.ForCompiler(fset, "source", nil), + } + info := &types.Info{ + Types: map[ast.Expr]types.TypeAndValue{}, + Uses: map[*ast.Ident]types.Object{}, + Defs: map[*ast.Ident]types.Object{}, + } + _, err = typechecker.Check(testPackage, fset, []*ast.File{file}, info) + result := &parsedTestFile{ + ast: file, + types: info, + fset: fset, + } + return result, err +} + +func compileTestFunc(env *Env, fn string, parsed *parsedTestFile) (*Func, error) { + var target *ast.FuncDecl + for _, decl := range parsed.ast.Decls { + decl, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + if decl.Name.String() == fn { + target = decl + break + } + } + if target == nil { + return nil, fmt.Errorf("test function %s not found", fn) + } + + ctx := &CompileContext{ + Env: env, + Types: parsed.types, + Fset: parsed.fset, + } + return Compile(ctx, target) +} diff --git a/ruleguard/quasigo/utils.go b/ruleguard/quasigo/utils.go new file mode 100644 index 00000000..ddabe05f --- /dev/null +++ b/ruleguard/quasigo/utils.go @@ -0,0 +1,53 @@ +package quasigo + +import ( + "encoding/binary" + "go/ast" + "go/types" +) + +func put16(code []byte, pos, value int) { + binary.LittleEndian.PutUint16(code[pos:], uint16(value)) +} + +func decode16(code []byte, pos int) int { + return int(binary.LittleEndian.Uint16(code[pos:])) +} + +func typeIsInt(typ types.Type) bool { + basic, ok := typ.Underlying().(*types.Basic) + if !ok { + return false + } + switch basic.Kind() { + case types.Int, types.UntypedInt: + return true + default: + return false + } +} + +func typeIsString(typ types.Type) bool { + basic, ok := typ.Underlying().(*types.Basic) + if !ok { + return false + } + return basic.Info()&types.IsString != 0 +} + +func walkBytecode(code []byte, fn func(pc int, op opcode)) { + pc := 0 + for pc < len(code) { + op := opcode(code[pc]) + fn(pc, op) + pc += opcodeInfoTable[op].width + } +} + +func identName(n ast.Expr) string { + id, ok := n.(*ast.Ident) + if ok { + return id.Name + } + return "" +} diff --git a/ruleguard/ruleguard.go b/ruleguard/ruleguard.go index 8dc8a192..ba23861a 100644 --- a/ruleguard/ruleguard.go +++ b/ruleguard/ruleguard.go @@ -5,11 +5,46 @@ import ( "go/token" "go/types" "io" - - "github.com/quasilyte/go-ruleguard/ruleguard/typematch" ) +// Engine is the main ruleguard package API object. +// +// First, load some ruleguard files with Load() to build a rule set. +// Then use Run() to execute the rules. +// +// It's advised to have only 1 engine per application as it does a lot of caching. +// The Run() method is synchronized, so it can be used concurrently. +// +// An Engine must be created with NewEngine() function. +type Engine struct { + impl *engine +} + +// NewEngine creates an engine with empty rule set. +func NewEngine() *Engine { + return &Engine{impl: newEngine()} +} + +// Load reads a ruleguard file from r and adds it to the engine rule set. +// +// Load() is not thread-safe, especially if used concurrently with Run() method. +// It's advised to Load() all ruleguard files under a critical section (like sync.Once) +// and then use Run() to execute all of them. +func (e *Engine) Load(ctx *ParseContext, filename string, r io.Reader) error { + return e.impl.Load(ctx, filename, r) +} + +// Run executes all loaded rules on a given file. +// Matched rules invoke `RunContext.Report()` method. +// +// Run() is thread-safe, unless used in parallel with Load(), +// which modifies the engine state. +func (e *Engine) Run(ctx *RunContext, f *ast.File) error { + return e.impl.Run(ctx, f) +} + type ParseContext struct { + DebugFilter string DebugImports bool DebugPrint func(string) @@ -22,9 +57,10 @@ type ParseContext struct { Fset *token.FileSet } -type Context struct { - Debug string - DebugPrint func(string) +type RunContext struct { + Debug string + DebugImports bool + DebugPrint func(string) Types *types.Info Sizes types.Sizes @@ -39,20 +75,6 @@ type Suggestion struct { Replacement []byte } -func ParseRules(ctx *ParseContext, filename string, r io.Reader) (*GoRuleSet, error) { - config := rulesParserConfig{ - ctx: ctx, - itab: typematch.NewImportsTab(stdlibPackages), - importer: newGoImporter(ctx), - } - p := newRulesParser(config) - return p.ParseFile(filename, r) -} - -func RunRules(ctx *Context, f *ast.File, rules *GoRuleSet) error { - return newRulesRunner(ctx, rules).run(f) -} - type GoRuleInfo struct { // Filename is a file that defined this rule. Filename string @@ -63,17 +85,3 @@ type GoRuleInfo struct { // Group is a function name that contained this rule. Group string } - -type GoRuleSet struct { - universal *scopedGoRuleSet - local *scopedGoRuleSet - - groups map[string]token.Position // To handle redefinitions - - // Imports is a set of rule bundles that were imported. - Imports map[string]struct{} -} - -func MergeRuleSets(toMerge []*GoRuleSet) (*GoRuleSet, error) { - return mergeRuleSets(toMerge) -} diff --git a/ruleguard/ruleguard_error_test.go b/ruleguard/ruleguard_error_test.go new file mode 100644 index 00000000..c722ccb1 --- /dev/null +++ b/ruleguard/ruleguard_error_test.go @@ -0,0 +1,304 @@ +package ruleguard + +import ( + "fmt" + "go/token" + "strings" + "testing" +) + +func TestParseFilterFuncError(t *testing.T) { + type testCase struct { + src string + err string + } + + simpleTests := []testCase{ + // Unsupported features. + // Some of them might be implemented later, but for now + // we want to ensure that the user gets understandable error messages. + { + `b := true; switch true {}; return b`, + `can't compile *ast.SwitchStmt yet`, + }, + { + `b := 0; return &b != nil`, + `can't compile unary & yet`, + }, + { + `b := 0; return (b << 1) != 0`, + `can't compile binary << yet`, + }, + { + `return g(ctx)`, + `can't compile a call to gorules.g func`, + }, + { + `return new(int) != nil`, + `can't compile new() builtin function call yet`, + }, + { + `x := 5.6; return x != 0`, + `can't compile float constants yet`, + }, + { + `s := ""; return s >= "a"`, + `>= is not implemented for string operands`, + }, + { + `s := "foo"; b := s[0]; return b == 0`, + `can't compile *ast.IndexExpr yet`, + }, + { + `s := Foo{}; return s.X == 0`, + `can't compile *ast.CompositeLit yet`, + }, + + // Assignment errors. + { + `x, y := 1, 2; return x == y`, + `only single left operand is allowed in assignments`, + }, + { + `x := 0; { x := 1; return x == 1 }; return x == 0`, + `x variable shadowing is not allowed`, + }, + { + `ctx = ctx; return true`, + `can't assign to ctx, params are readonly`, + }, + { + `ctx.Type = nil; return true`, + `can assign only to simple variables`, + }, + + // Unsupported type errors. + { + `x := int32(0); return x == 0`, + `x local variable type: int32 is not supported, try something simpler`, + }, + + // Implementation limits. + { + `x1:=1; x2:=x1; x3:=x2; x4:=x3; x5:=x4; x6:=x5; x7:=x6; x8:=x7; x9:=x8; return x9 == 1`, + `can't define x9: too many locals`, + }, + } + + tests := []testCase{ + { + `func f() int32 { return 0 }`, + `function result type: int32 is not supported, try something simpler`, + }, + { + `func f() []int { return nil }`, + `function result type: []int is not supported, try something simpler`, + }, + { + `func f(s *string) int { return 0 }`, + `s param type: *string is not supported, try something simpler`, + }, + + { + `func f(foo *Foo) int { return foo.X }`, + `can't compile X field access`, + }, + { + `func f(foo *Foo) string { return foo.String() }`, + `can't compile a call to *gorules.Foo.String func`, + }, + + { + `func f() {}`, + `only functions with a single non-void results are supported`, + }, + { + `func f() (int, int) { return 0, 0 }`, + `only functions with a single non-void results are supported`, + }, + + { + `func f() (b bool) { return }`, + `'naked' return statements are not allowed`, + }, + } + + for _, test := range simpleTests { + test.src = `func f(ctx *dsl.VarFilterContext) bool { ` + test.src + ` }` + tests = append(tests, test) + } + + for _, test := range tests { + file := fmt.Sprintf(` + package gorules + import "github.com/quasilyte/go-ruleguard/dsl" + type Foo struct { X int } + func (foo *Foo) String() string { return "" } + func g(ctx *dsl.VarFilterContext) bool { return false } + ` + test.src) + e := NewEngine() + ctx := &ParseContext{ + Fset: token.NewFileSet(), + } + err := e.Load(ctx, "rules.go", strings.NewReader(file)) + if err == nil { + t.Errorf("parse %s: expected %s error, got none", test.src, test.err) + continue + } + have := err.Error() + want := test.err + if !strings.Contains(have, want) { + t.Errorf("parse %s: errors mismatch:\nhave: %s\nwant: %s", test.src, have, want) + continue + } + } +} + +func TestParseRuleError(t *testing.T) { + tests := []struct { + expr string + err string + }{ + { + `m.Where(m.File().Imports("strings")).Report("no match call")`, + `missing Match() call`, + }, + + { + `m.Match("$x").Where(m["x"].Pure)`, + `missing Report() or Suggest() call`, + }, + + { + `m.Match("$x").Match("$x")`, + `Match() can't be repeated`, + }, + + { + `m.Match().Report("$$")`, + `too few arguments in call to m.Match`, + }, + + { + `m.Match("func[]").Report("$$")`, + `parse match pattern: cannot parse expr: 1:5: expected '(', found '['`, + }, + } + + for _, test := range tests { + file := fmt.Sprintf(` + package gorules + import "github.com/quasilyte/go-ruleguard/dsl" + func testrule(m dsl.Matcher) { + %s + }`, + test.expr) + e := NewEngine() + ctx := &ParseContext{ + Fset: token.NewFileSet(), + } + err := e.Load(ctx, "rules.go", strings.NewReader(file)) + if err == nil { + t.Errorf("parse %s: expected %s error, got none", test.expr, test.err) + continue + } + have := err.Error() + want := test.err + if !strings.Contains(have, want) { + t.Errorf("parse %s: errors mismatch:\nhave: %s\nwant: %s", test.expr, have, want) + continue + } + } +} + +func TestParseFilterError(t *testing.T) { + tests := []struct { + expr string + err string + }{ + { + `true`, + `unsupported expr: true`, + }, + + { + `m["x"].Text == 5`, + `cannot convert 5 (untyped int constant) to string`, + }, + + { + `m["x"].Text.Matches("(12")`, + `error parsing regexp: missing closing )`, + }, + + { + `m["x"].Type.Is("%illegal")`, + `parse type expr: 1:1: expected operand, found '%'`, + }, + + { + `m["x"].Type.Is("interface{String() string}")`, + `parse type expr: can't convert interface{String() string} type expression`, + }, + + { + `m["x"].Type.ConvertibleTo("interface{String() string}")`, + `can't convert interface{String() string} into a type constraint yet`, + }, + + { + `m["x"].Type.AssignableTo("interface{String() string}")`, + `can't convert interface{String() string} into a type constraint yet`, + }, + + { + `m["x"].Type.Implements("foo")`, + "only `error` unqualified type is recognized", + }, + + { + `m["x"].Type.Implements("func()")`, + "only qualified names (and `error`) are supported", + }, + + { + `m["x"].Type.Implements("foo.Bar")`, + `package foo is not imported`, + }, + + { + `m["x"].Type.Implements("strings.Replacer3")`, + `Replacer3 is not found in strings`, + }, + + { + `m["x"].Node.Is("abc")`, + `abc is not a valid go/ast type name`, + }, + } + + for _, test := range tests { + file := fmt.Sprintf(` + package gorules + import "github.com/quasilyte/go-ruleguard/dsl" + func testrule(m dsl.Matcher) { + m.Match("$x + $y[$key]").Where(%s).Report("$$") + }`, + test.expr) + e := NewEngine() + ctx := &ParseContext{ + Fset: token.NewFileSet(), + } + err := e.Load(ctx, "rules.go", strings.NewReader(file)) + if err == nil { + t.Errorf("parse %s: expected %s error, got none", test.expr, test.err) + continue + } + have := err.Error() + want := test.err + if !strings.Contains(have, want) { + t.Errorf("parse %s: errors mismatch:\nhave: %s\nwant: %s", test.expr, have, want) + continue + } + } +} diff --git a/ruleguard/ruleguard_test.go b/ruleguard/ruleguard_test.go index b309b6ea..a01be33e 100644 --- a/ruleguard/ruleguard_test.go +++ b/ruleguard/ruleguard_test.go @@ -1,158 +1,13 @@ package ruleguard import ( - "fmt" "go/ast" "go/token" - "strings" "testing" "github.com/google/go-cmp/cmp" ) -func TestParseRuleError(t *testing.T) { - tests := []struct { - expr string - err string - }{ - { - `m.Where(m.File().Imports("strings")).Report("no match call")`, - `missing Match() call`, - }, - - { - `m.Match("$x").Where(m["x"].Pure)`, - `missing Report() or Suggest() call`, - }, - - { - `m.Match("$x").Match("$x")`, - `Match() can't be repeated`, - }, - - { - `m.Match().Report("$$")`, - `too few arguments in call to m.Match`, - }, - - { - `m.Match("func[]").Report("$$")`, - `parse match pattern: cannot parse expr: 1:5: expected '(', found '['`, - }, - } - - for _, test := range tests { - file := fmt.Sprintf(` - package gorules - import "github.com/quasilyte/go-ruleguard/dsl" - func testrule(m dsl.Matcher) { - %s - }`, - test.expr) - ctx := &ParseContext{Fset: token.NewFileSet()} - _, err := ParseRules(ctx, "rules.go", strings.NewReader(file)) - if err == nil { - t.Errorf("parse %s: expected %s error, got none", test.expr, test.err) - continue - } - have := err.Error() - want := test.err - if !strings.Contains(have, want) { - t.Errorf("parse %s: errors mismatch:\nhave: %s\nwant: %s", test.expr, have, want) - continue - } - } -} - -func TestParseFilterError(t *testing.T) { - tests := []struct { - expr string - err string - }{ - { - `true`, - `unsupported expr: true`, - }, - - { - `m["x"].Text == 5`, - `cannot convert 5 (untyped int constant) to string`, - }, - - { - `m["x"].Text.Matches("(12")`, - `error parsing regexp: missing closing )`, - }, - - { - `m["x"].Type.Is("%illegal")`, - `parse type expr: 1:1: expected operand, found '%'`, - }, - - { - `m["x"].Type.Is("interface{String() string}")`, - `parse type expr: can't convert interface{String() string} type expression`, - }, - - { - `m["x"].Type.ConvertibleTo("interface{String() string}")`, - `can't convert interface{String() string} into a type constraint yet`, - }, - - { - `m["x"].Type.AssignableTo("interface{String() string}")`, - `can't convert interface{String() string} into a type constraint yet`, - }, - - { - `m["x"].Type.Implements("foo")`, - "only `error` unqualified type is recognized", - }, - - { - `m["x"].Type.Implements("func()")`, - "only qualified names (and `error`) are supported", - }, - - { - `m["x"].Type.Implements("foo.Bar")`, - `package foo is not imported`, - }, - - { - `m["x"].Type.Implements("strings.Replacer3")`, - `Replacer3 is not found in strings`, - }, - - { - `m["x"].Node.Is("abc")`, - `abc is not a valid go/ast type name`, - }, - } - - for _, test := range tests { - file := fmt.Sprintf(` - package gorules - import "github.com/quasilyte/go-ruleguard/dsl" - func testrule(m dsl.Matcher) { - m.Match("$x + $y[$key]").Where(%s).Report("$$") - }`, - test.expr) - ctx := &ParseContext{Fset: token.NewFileSet()} - _, err := ParseRules(ctx, "rules.go", strings.NewReader(file)) - if err == nil { - t.Errorf("parse %s: expected %s error, got none", test.expr, test.err) - continue - } - have := err.Error() - want := test.err - if !strings.Contains(have, want) { - t.Errorf("parse %s: errors mismatch:\nhave: %s\nwant: %s", test.expr, have, want) - continue - } - } -} - func TestRenderMessage(t *testing.T) { tests := []struct { msg string @@ -227,8 +82,10 @@ func TestRenderMessage(t *testing.T) { }, } + e := NewEngine() var rr rulesRunner - rr.ctx = &Context{ + rr.state = e.impl.state + rr.ctx = &RunContext{ Fset: token.NewFileSet(), } for _, test := range tests { diff --git a/ruleguard/runner.go b/ruleguard/runner.go index 966c3e86..2048ce3e 100644 --- a/ruleguard/runner.go +++ b/ruleguard/runner.go @@ -12,11 +12,16 @@ import ( "strings" "github.com/quasilyte/go-ruleguard/internal/mvdan.cc/gogrep" + "github.com/quasilyte/go-ruleguard/ruleguard/goutil" ) type rulesRunner struct { - ctx *Context - rules *GoRuleSet + state *engineState + + ctx *RunContext + rules *goRuleSet + + importer *goImporter filename string src []byte @@ -27,12 +32,20 @@ type rulesRunner struct { filterParams filterParams } -func newRulesRunner(ctx *Context, rules *GoRuleSet) *rulesRunner { +func newRulesRunner(ctx *RunContext, state *engineState, rules *goRuleSet) *rulesRunner { + importer := newGoImporter(state, goImporterConfig{ + fset: ctx.Fset, + debugImports: ctx.DebugImports, + debugPrint: ctx.DebugPrint, + }) rr := &rulesRunner{ - ctx: ctx, - rules: rules, + ctx: ctx, + importer: importer, + rules: rules, filterParams: filterParams{ - ctx: ctx, + env: state.env.GetEvalEnv(), + importer: importer, + ctx: ctx, }, sortScratch: make([]string, 0, 8), } @@ -143,7 +156,7 @@ func (rr *rulesRunner) reject(rule goRule, reason string, m gogrep.MatchData) { if typ != nil { typeString = typ.String() } - s := strings.ReplaceAll(sprintNode(rr.ctx.Fset, expr), "\n", `\n`) + s := strings.ReplaceAll(goutil.SprintNode(rr.ctx.Fset, expr), "\n", `\n`) rr.ctx.DebugPrint(fmt.Sprintf(" $%s %s: %s", name, typeString, s)) } } diff --git a/ruleguard/utils.go b/ruleguard/utils.go index 797f2151..16fd7d68 100644 --- a/ruleguard/utils.go +++ b/ruleguard/utils.go @@ -4,22 +4,24 @@ import ( "go/ast" "go/constant" "go/parser" - "go/printer" "go/token" "go/types" "strconv" "strings" ) -func sprintNode(fset *token.FileSet, n ast.Node) string { - if fset == nil { - fset = token.NewFileSet() +func findDependency(pkg *types.Package, path string) *types.Package { + if pkg.Path() == path { + return pkg } - var buf strings.Builder - if err := printer.Fprint(&buf, fset, n); err != nil { - return "" + // It looks like indirect dependencies are always incomplete? + // If it's true, then we don't have to recurse here. + for _, imported := range pkg.Imports() { + if dep := findDependency(imported, path); dep != nil && dep.Complete() { + return dep + } } - return buf.String() + return nil } var basicTypeByName = map[string]types.Type{