Skip to content

Commit

Permalink
ruleguard/quasigo: implement void funcs (#307)
Browse files Browse the repository at this point in the history
Added EvalTest files to simplify the quasigo testing.
  • Loading branch information
quasilyte authored Nov 17, 2021
1 parent f98e474 commit 2d7358b
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 79 deletions.
77 changes: 64 additions & 13 deletions ruleguard/quasigo/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"golang.org/x/tools/go/ast/astutil"
)

var voidType = &types.Tuple{}

func compile(ctx *CompileContext, fn *ast.FuncDecl) (compiled *Func, err error) {
defer func() {
if err != nil {
Expand Down Expand Up @@ -74,10 +76,14 @@ type compileError string
func (e compileError) Error() string { return string(e) }

func (cl *compiler) compileFunc(fn *ast.FuncDecl) *Func {
if cl.fnType.Results().Len() != 1 {
panic(cl.errorf(fn.Name, "only functions with a single non-void results are supported"))
switch cl.fnType.Results().Len() {
case 0:
cl.retType = voidType
case 1:
cl.retType = cl.fnType.Results().At(0).Type()
default:
panic(cl.errorf(fn.Name, "multi-result functions are not supported"))
}
cl.retType = cl.fnType.Results().At(0).Type()

if !cl.isSupportedType(cl.retType) {
panic(cl.errorUnsupportedType(fn.Name, cl.retType, "function result"))
Expand Down Expand Up @@ -136,6 +142,9 @@ func (cl *compiler) compileStmt(stmt ast.Stmt) {
case *ast.BranchStmt:
cl.compileBranchStmt(stmt)

case *ast.ExprStmt:
cl.compileExprStmt(stmt)

case *ast.BlockStmt:
for i := range stmt.List {
cl.compileStmt(stmt.List[i])
Expand Down Expand Up @@ -172,6 +181,19 @@ func (cl *compiler) compileBranchStmt(branch *ast.BranchStmt) {
}
}

func (cl *compiler) compileExprStmt(stmt *ast.ExprStmt) {
if call, ok := stmt.X.(*ast.CallExpr); ok {
sig := cl.ctx.Types.TypeOf(call.Fun).(*types.Signature)
if sig.Results() != nil {
panic(cl.errorf(call, "only void funcs can be used in stmt context"))
}
cl.compileCallExpr(call)
return
}

panic(cl.errorf(stmt.X, "can't compile this expr stmt yet: %T", stmt.X))
}

func (cl *compiler) compileForStmt(stmt *ast.ForStmt) {
labelBreak := cl.newLabel()
labelContinue := cl.newLabel()
Expand Down Expand Up @@ -279,6 +301,11 @@ func (cl *compiler) getLocal(v ast.Expr, varname string) int {
}

func (cl *compiler) compileReturnStmt(ret *ast.ReturnStmt) {
if cl.retType == voidType {
cl.emit(opReturn)
return
}

if ret.Results == nil {
panic(cl.errorf(ret, "'naked' return statements are not allowed"))
}
Expand Down Expand Up @@ -471,6 +498,20 @@ func (cl *compiler) compileBuiltinCall(fn *ast.Ident, call *ast.CallExpr) {
panic(cl.errorf(s, "can't compile len() with non-string argument yet"))
}
cl.emit(opStringLen)

case `println`:
if len(call.Args) != 1 {
panic(cl.errorf(call, "only 1-arg form of println() is supported"))
}
funcName := "Print"
if typeIsInt(cl.ctx.Types.TypeOf(call.Args[0])) {
funcName = "PrintInt"
}
key := funcKey{qualifier: "builtin", name: funcName}
if !cl.compileNativeCall(key, nil, call.Args) {
panic(cl.errorf(fn, "builtin.%s native func is not registered", funcName))
}

default:
panic(cl.errorf(fn, "can't compile %s() builtin function call yet", fn))
}
Expand Down Expand Up @@ -499,18 +540,24 @@ func (cl *compiler) compileCallExpr(call *ast.CallExpr) {
key.qualifier = fn.Pkg().Path()
}

if funcID, ok := cl.ctx.Env.nameToNativeFuncID[key]; ok {
if expr != nil {
cl.compileExpr(expr)
}
for _, arg := range call.Args {
cl.compileExpr(arg)
}
cl.emit16(opCallNative, int(funcID))
return
if !cl.compileNativeCall(key, expr, call.Args) {
panic(cl.errorf(call.Fun, "can't compile a call to %s func", key))
}
}

panic(cl.errorf(call.Fun, "can't compile a call to %s func", key))
func (cl *compiler) compileNativeCall(key funcKey, expr ast.Expr, args []ast.Expr) bool {
funcID, ok := cl.ctx.Env.nameToNativeFuncID[key]
if !ok {
return false
}
if expr != nil {
cl.compileExpr(expr)
}
for _, arg := range args {
cl.compileExpr(arg)
}
cl.emit16(opCallNative, int(funcID))
return true
}

func (cl *compiler) compileUnaryOp(op opcode, e *ast.UnaryExpr) {
Expand Down Expand Up @@ -681,6 +728,10 @@ func (cl *compiler) isUncondJump(op opcode) bool {
}

func (cl *compiler) isSupportedType(typ types.Type) bool {
if typ == voidType {
return true
}

switch typ := typ.Underlying().(type) {
case *types.Pointer:
// 1. Pointers to structs are supported.
Expand Down
2 changes: 2 additions & 0 deletions ruleguard/quasigo/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ func eval(env *EvalEnv, fn *Func, args []interface{}) CallResult {
return CallResult{value: stack.top()}
case opReturnIntTop:
return CallResult{scalarValue: uint64(stack.topInt())}
case opReturn:
return CallResult{}

case opCallNative:
id := decode16(code, pc+1)
Expand Down
95 changes: 95 additions & 0 deletions ruleguard/quasigo/eval_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package quasigo

import (
"bytes"
"errors"
"fmt"
"go/ast"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/quasilyte/go-ruleguard/ruleguard/quasigo/internal/evaltest"
)

Expand Down Expand Up @@ -221,3 +229,90 @@ func TestEval(t *testing.T) {
}
}
}

func TestEvalFile(t *testing.T) {
files, err := ioutil.ReadDir("testdata")
if err != nil {
t.Fatal(err)
}

runGo := func(main string) (string, error) {
out, err := exec.Command("go", "run", main).CombinedOutput()
if err != nil {
return "", fmt.Errorf("%v: %s", err, out)
}
return string(out), nil
}

runQuasigo := func(main string) (string, error) {
src, err := os.ReadFile(main)
if err != nil {
return "", err
}
env := NewEnv()
parsed, err := parseGoFile(string(src))
if err != nil {
return "", fmt.Errorf("parse: %v", err)
}

var stdout bytes.Buffer
env.AddNativeFunc("builtin", "Print", func(stack *ValueStack) {
arg := stack.Pop()
fmt.Fprintln(&stdout, arg)
})
env.AddNativeFunc("builtin", "PrintInt", func(stack *ValueStack) {
fmt.Fprintln(&stdout, stack.PopInt())
})

var mainFunc *Func
for _, decl := range parsed.ast.Decls {
decl, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
ctx := &CompileContext{
Env: env,
Types: parsed.types,
Fset: parsed.fset,
}
fn, err := Compile(ctx, decl)
if err != nil {
return "", fmt.Errorf("compile %s func: %v", decl.Name, err)
}
if decl.Name.String() == "main" {
mainFunc = fn
}
}
if mainFunc == nil {
return "", errors.New("can't find main() function")
}

Call(env.GetEvalEnv(), mainFunc)
return stdout.String(), nil
}

runTest := func(t *testing.T, mainFile string) {
goResult, err := runGo(mainFile)
if err != nil {
t.Fatalf("run go: %v", err)
}
quasigoResult, err := runQuasigo(mainFile)
if err != nil {
t.Fatalf("run quasigo: %v", err)
}
if diff := cmp.Diff(quasigoResult, goResult); diff != "" {
t.Errorf("output mismatch:\nhave (+): `%s`\nwant (-): `%s`\ndiff: %s", quasigoResult, goResult, diff)
}
}

for _, f := range files {
if !f.IsDir() {
continue
}
mainFile := filepath.Join("testdata", f.Name(), "main.go")
t.Run(f.Name(), func(t *testing.T) {
runTest(t, mainFile)
})
}

}
3 changes: 2 additions & 1 deletion ruleguard/quasigo/gen_opcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var opcodePrototypes = []opcodeProto{
{"ReturnIntTop", "op", "(value) -> (value)"},
{"ReturnFalse", "op", stackUnchanged},
{"ReturnTrue", "op", stackUnchanged},
{"Return", "op", stackUnchanged},

{"Jump", "op offset:i16", stackUnchanged},
{"JumpFalse", "op offset:i16", "(cond:bool) -> ()"},
Expand All @@ -45,7 +46,7 @@ var opcodePrototypes = []opcodeProto{
{"IsNotNil", "op", "(value) -> (result:bool)"},

{"Not", "op", "(value:bool) -> (result:bool)"},

{"EqInt", "op", "(x:int y:int) -> (result:bool)"},
{"NotEqInt", "op", "(x:int y:int) -> (result:bool)"},
{"GtInt", "op", "(x:int y:int) -> (result:bool)"},
Expand Down
49 changes: 25 additions & 24 deletions ruleguard/quasigo/opcode_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2d7358b

Please sign in to comment.