From b5516699c74e538e5ceec89616668bb1c7d2785a Mon Sep 17 00:00:00 2001 From: Torin Sandall Date: Thu, 20 Dec 2018 16:56:35 -0800 Subject: [PATCH] Fix panic caused by assignment rewriting During the assignment rewriting stages, variables that have been assigned locally are rewritten, e.g., k := "foo"; {k: 1} becomes __local0__ = "foo"; {__local0__: 1}. Previously, the rewriting would simply mutate all of the term values without considering where the term was defined. The problem with this was that it would corrupt the object's hashtable if the rewritten term was an object key. This commit resolves the issue by making a copy of the object key before mutating the term value. This approach requires the compiler to copy the object as well. If this proves to be a performance bottleneck, we can do something more clever. Note, it's arguable that the Object struct should not allow keys to be mutated at all. This would be tricky given the current Object interface. Perhaps we can improve this in the future to avoid similar issues. These changes also resolve an issue with the rewritten var set returned by the query compiler. Previously, all seen vars were returned in the set (as opposed to only those that had been redeclared.) When the query compiler inverts the map before returning it could accidentally elide vars, e.g., given {a: b, b: b} if it inverted this to {b: a} then the output would be correct but if it inverted this to {b: b} (which it could depending on the iteration order) the output would be incorrect. Fixes #1125 Signed-off-by: Torin Sandall --- ast/compile.go | 44 +++++++++++++++++++++++++++++++++++++------- ast/compile_test.go | 18 ++++++++++++++++++ 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index 0a0518931e..16f0356ab2 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -841,15 +841,32 @@ func (c *Compiler) rewriteLocalAssignments() { // Rewrite vars in head that refer to locally declared vars in the body. vis := NewGenericVisitor(func(x interface{}) bool { - switch x := x.(type) { - case *Term: - if v, ok := x.Value.(Var); ok { - if gv, ok := declared[v]; ok { - x.Value = gv - return true + + term, ok := x.(*Term) + if !ok { + return false + } + + switch v := term.Value.(type) { + case Object: + // Make a copy of the object because the keys may be mutated. + cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) { + if vark, ok := k.Value.(Var); ok { + if gv, ok := declared[vark]; ok { + k = k.Copy() + k.Value = gv + } } + return k, v, nil + }) + term.Value = cpy + case Var: + if gv, ok := declared[v]; ok { + term.Value = gv + return true } } + return false }) @@ -1029,7 +1046,12 @@ func (qc *queryCompiler) rewriteLocalAssignments(_ *QueryContext, body Body) (Bo } qc.rewritten = make(map[Var]Var, len(declared)) for k, v := range declared { - qc.rewritten[v] = k + // The vars returned during the rewrite will include all seen vars, + // even if they're not declared with an assignment operation. We don't + // want to include these inside the rewritten set though. + if Compare(k, v) != 0 { + qc.rewritten[v] = k + } } return body, nil } @@ -2554,6 +2576,14 @@ func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, t return true, errs } return false, errs + case Object: + cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) { + kcpy := k.Copy() + _, errs = rewriteDeclaredVarsInTerm(g, stack, kcpy, errs) + _, errs = rewriteDeclaredVarsInTerm(g, stack, v, errs) + return kcpy, v, nil + }) + term.Value = cpy case *ArrayComprehension: errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs) case *SetComprehension: diff --git a/ast/compile_test.go b/ast/compile_test.go index 8e1be4fd3e..99399e335a 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -1106,6 +1106,19 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) { head_array_comprehensions = [[x] | x := 1] head_set_comprehensions = {[x] | x := 1} head_object_comprehensions = {k: [x] | k := "foo"; x := 1} + + rewritten_object_key { + k := "foo" + {k: 1} + } + + rewritten_object_key_head[[{k: 1}]] { + k := "foo" + } + + rewritten_object_key_head_value = [{k: 1}] { + k := "foo" + } `) c.Modules["test2"] = MustParseModule(`package test @@ -1163,6 +1176,10 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) { head_array_comprehensions = [[__local21__] | __local21__ = 1] head_set_comprehensions = {[__local22__] | __local22__ = 1} head_object_comprehensions = {__local23__: [__local24__] | __local23__ = "foo"; __local24__ = 1} + + rewritten_object_key = true { __local25__ = "foo"; {__local25__: 1} } + rewritten_object_key_head[[{__local26__: 1}]] { __local26__ = "foo" } + rewritten_object_key_head_value = [{__local27__: 1}] { __local27__ = "foo" } `) if len(module1.Rules) != len(expectedModule.Rules) { @@ -2209,6 +2226,7 @@ func TestQueryCompilerRewrittenVars(t *testing.T) { vars map[string]string }{ {"assign", "a := 1", map[string]string{"__local0__": "a"}}, + {"suppress only seen", "b = 1; a := b", map[string]string{"__local0__": "a"}}, } for _, tc := range tests { t.Run(tc.note, func(t *testing.T) {