Skip to content

Commit

Permalink
fix compiling value arguments to bind inside function body (close #107)
Browse files Browse the repository at this point in the history
  • Loading branch information
itchyny committed Jul 25, 2021
1 parent 49a818b commit 7de7578
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 62 deletions.
14 changes: 13 additions & 1 deletion cli/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1749,10 +1749,11 @@
- name: function declaration with duplicate variable names
args:
- 'def f($g;$g): $g; f(1;2)'
- 'def f($g;$g): $g; f(1;2,3)'
input: 'null'
expected: |
2
3
- name: function declaration with argument name conflict
args:
Expand Down Expand Up @@ -1787,6 +1788,17 @@
[111,121,112,122,211,221,212,222]
[111,112,121,122,211,212,221,222]
- name: function declaration with function reference to value argument
args:
- -c
- 'def f($x; $y): [$x, x, $y, y]; f(1,2; 3,4)'
input: 'null'
expected: |
[1,1,2,3,3,4]
[1,1,2,4,3,4]
[2,1,2,3,3,4]
[2,1,2,4,3,4]
- name: function declaration inside query
args:
- -c
Expand Down
103 changes: 43 additions & 60 deletions compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,9 @@ type varinfo struct {
}

type funcinfo struct {
name string
pc int
args []string
argsorder []int
name string
pc int
args []string
}

// Compile compiles a query.
Expand Down Expand Up @@ -227,18 +226,21 @@ func (c *compiler) compileModule(q *Query, alias string) error {
}

func (c *compiler) newVariable() [2]int {
return c.pushVariable("")
return c.createVariable("")
}

func (c *compiler) pushVariable(name string) [2]int {
s := c.scopes[len(c.scopes)-1]
if name != "" {
for _, v := range s.variables {
if v.name == name && v.depth == s.depth {
return v.index
}
for _, v := range s.variables {
if v.name == name && v.depth == s.depth {
return v.index
}
}
return c.createVariable(name)
}

func (c *compiler) createVariable(name string) [2]int {
s := c.scopes[len(c.scopes)-1]
v := [2]int{s.id, s.variablecnt}
s.variablecnt++
s.variables = append(s.variables, &varinfo{name, v, s.depth})
Expand Down Expand Up @@ -310,8 +312,8 @@ func (c *compiler) compileFuncDef(e *FuncDef, builtin bool) error {
})()
c.appendCodeInfo(e.Name)
defer c.appendCodeInfo("end of " + e.Name)
pc, argsorder := c.pc(), getArgsOrder(e.Args)
scope.funcs = append(scope.funcs, &funcinfo{e.Name, pc, e.Args, argsorder})
pc := c.pc()
scope.funcs = append(scope.funcs, &funcinfo{e.Name, pc, e.Args})
defer func(l int, variables []string) {
c.scopes, c.variables = c.scopes[:l], variables
}(len(c.scopes), c.variables)
Expand All @@ -322,46 +324,36 @@ func (c *compiler) compileFuncDef(e *FuncDef, builtin bool) error {
return &code{op: opscope, v: [3]int{scope.id, scope.variablecnt, len(e.Args)}}
})()
if len(e.Args) > 0 {
type varIndex struct {
name string
index [2]int
}
vis := make([]varIndex, 0, len(e.Args))
v := c.newVariable()
c.append(&code{op: opstore, v: v})
skip := make([]bool, len(e.Args))
for i, name := range e.Args {
for j := 0; j < i; j++ {
if name == e.Args[j] {
skip[j] = true
break
}
}
}
for _, i := range argsorder {
if skip[i] {
c.append(&code{op: oppop})
for _, arg := range e.Args {
if arg[0] == '$' {
c.appendCodeInfo(arg[1:])
w := c.createVariable(arg[1:])
c.append(&code{op: opstore, v: w})
vis = append(vis, varIndex{arg, w})
} else {
c.append(&code{op: opstore, v: c.pushVariable(e.Args[i])})
c.appendCodeInfo(arg)
c.append(&code{op: opstore, v: c.createVariable(arg)})
}
}
for _, w := range vis {
c.append(&code{op: opload, v: v})
c.append(&code{op: opload, v: w.index})
c.append(&code{op: opcallpc})
c.appendCodeInfo(w.name)
c.append(&code{op: opstore, v: c.pushVariable(w.name)})
}
c.append(&code{op: opload, v: v})
}
return c.compile(e.Body)
}

func getArgsOrder(args []string) []int {
xs := make([]int, len(args))
if len(xs) > 0 {
for i := range xs {
xs[i] = i
}
sort.Slice(xs, func(i, j int) bool {
xi, xj := xs[i], xs[j]
if args[xi][0] == '$' {
return args[xj][0] == '$' && xi > xj // reverse the order of variables
}
return args[xj][0] == '$' || xi < xj
})
}
return xs
}

func (c *compiler) compileQuery(e *Query) error {
for _, fd := range e.FuncDefs {
if err := c.compileFuncDef(fd, false); err != nil {
Expand Down Expand Up @@ -941,7 +933,7 @@ func (c *compiler) compileFunc(e *Func) error {
return c.compileCallInternal(
[3]interface{}{c.funcBuiltins, 0, e.Name},
e.Args,
nil,
true,
false,
)
case "input":
Expand All @@ -951,14 +943,14 @@ func (c *compiler) compileFunc(e *Func) error {
return c.compileCallInternal(
[3]interface{}{c.funcInput, 0, e.Name},
e.Args,
nil,
true,
false,
)
case "modulemeta":
return c.compileCallInternal(
[3]interface{}{c.funcModulemeta, 0, e.Name},
e.Args,
nil,
true,
false,
)
default:
Expand All @@ -969,7 +961,7 @@ func (c *compiler) compileFunc(e *Func) error {
if err := c.compileCallInternal(
[3]interface{}{fn.callback, len(e.Args), e.Name},
e.Args,
nil,
true,
false,
); err != nil {
return err
Expand Down Expand Up @@ -1344,26 +1336,17 @@ func (c *compiler) compileCall(name string, args []*Query) error {
return c.compileCallInternal(
[3]interface{}{internalFuncs[name].callback, len(args), name},
args,
nil,
true,
name == "_index" || name == "_slice",
)
}

func (c *compiler) compileCallPc(fn *funcinfo, args []*Query) error {
if len(args) == 0 {
return c.compileCallInternal(fn.pc, args, nil, false)
}
xs, vars := make([]*Query, len(args)), make(map[int]bool, len(fn.args))
for i, j := range fn.argsorder {
xs[i] = args[j]
if fn.args[j][0] == '$' {
vars[i] = true
}
}
return c.compileCallInternal(fn.pc, xs, vars, false)
return c.compileCallInternal(fn.pc, args, false, false)
}

func (c *compiler) compileCallInternal(fn interface{}, args []*Query, vars map[int]bool, indexing bool) error {
func (c *compiler) compileCallInternal(
fn interface{}, args []*Query, internal, indexing bool) error {
if len(args) == 0 {
c.append(&code{op: opcall, v: fn})
return nil
Expand All @@ -1379,7 +1362,7 @@ func (c *compiler) compileCallInternal(fn interface{}, args []*Query, vars map[i
if err := c.compileFuncDef(&FuncDef{Name: name, Body: args[i]}, false); err != nil {
return err
}
if vars == nil || vars[i] {
if internal {
switch c.pc() - pc {
case 2: // optimize identity argument (opscope, opret)
j := len(c.codes) - 3
Expand Down
2 changes: 1 addition & 1 deletion compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func TestCodeCompile_OptimizeTailRec_Range(t *testing.T) {
t.Fatal(err)
}
codes := reflect.ValueOf(code).Elem().FieldByName("codes")
if got, expected := codes.Len(), 85; expected != got {
if got, expected := codes.Len(), 103; expected != got {
t.Errorf("expected: %v, got: %v", expected, got)
}
op1 := codes.Index(1).Elem().FieldByName("op")
Expand Down

0 comments on commit 7de7578

Please sign in to comment.