Skip to content

Commit

Permalink
Fix panic caused by assignment rewriting
Browse files Browse the repository at this point in the history
During the assignment rewriting stages, variables that have been
assigned locally are rewritten, e.g., k := "foo"; {k: 1} becomes
__local0__ = "foo"; {__local0__: 1}. Previously, the rewriting would
simply mutate all of the term values without considering where the term
was defined. The problem with this was that it would corrupt the
object's hashtable if the rewritten term was an object key.

This commit resolves the issue by making a copy of the object key before
mutating the term value. This approach requires the compiler to copy the
object as well. If this proves to be a performance bottleneck, we can do
something more clever.

Note, it's arguable that the Object struct should not allow keys to be
mutated at all. This would be tricky given the current Object interface.
Perhaps we can improve this in the future to avoid similar issues.

These changes also resolve an issue with the rewritten var set returned
by the query compiler. Previously, all seen vars were returned in the
set (as opposed to only those that had been redeclared.) When the query
compiler inverts the map before returning it could accidentally elide
vars, e.g., given {a: b, b: b} if it inverted this to {b: a} then the
output would be correct but if it inverted this to {b: b} (which it
could depending on the iteration order) the output would be incorrect.

Fixes open-policy-agent#1125

Signed-off-by: Torin Sandall <[email protected]>
  • Loading branch information
tsandall committed Dec 21, 2018
1 parent a9f819f commit b551669
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 7 deletions.
44 changes: 37 additions & 7 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -841,15 +841,32 @@ func (c *Compiler) rewriteLocalAssignments() {

// Rewrite vars in head that refer to locally declared vars in the body.
vis := NewGenericVisitor(func(x interface{}) bool {
switch x := x.(type) {
case *Term:
if v, ok := x.Value.(Var); ok {
if gv, ok := declared[v]; ok {
x.Value = gv
return true

term, ok := x.(*Term)
if !ok {
return false
}

switch v := term.Value.(type) {
case Object:
// Make a copy of the object because the keys may be mutated.
cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
if vark, ok := k.Value.(Var); ok {
if gv, ok := declared[vark]; ok {
k = k.Copy()
k.Value = gv
}
}
return k, v, nil
})
term.Value = cpy
case Var:
if gv, ok := declared[v]; ok {
term.Value = gv
return true
}
}

return false
})

Expand Down Expand Up @@ -1029,7 +1046,12 @@ func (qc *queryCompiler) rewriteLocalAssignments(_ *QueryContext, body Body) (Bo
}
qc.rewritten = make(map[Var]Var, len(declared))
for k, v := range declared {
qc.rewritten[v] = k
// The vars returned during the rewrite will include all seen vars,
// even if they're not declared with an assignment operation. We don't
// want to include these inside the rewritten set though.
if Compare(k, v) != 0 {
qc.rewritten[v] = k
}
}
return body, nil
}
Expand Down Expand Up @@ -2554,6 +2576,14 @@ func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, t
return true, errs
}
return false, errs
case Object:
cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
kcpy := k.Copy()
_, errs = rewriteDeclaredVarsInTerm(g, stack, kcpy, errs)
_, errs = rewriteDeclaredVarsInTerm(g, stack, v, errs)
return kcpy, v, nil
})
term.Value = cpy
case *ArrayComprehension:
errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs)
case *SetComprehension:
Expand Down
18 changes: 18 additions & 0 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,19 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) {
head_array_comprehensions = [[x] | x := 1]
head_set_comprehensions = {[x] | x := 1}
head_object_comprehensions = {k: [x] | k := "foo"; x := 1}
rewritten_object_key {
k := "foo"
{k: 1}
}
rewritten_object_key_head[[{k: 1}]] {
k := "foo"
}
rewritten_object_key_head_value = [{k: 1}] {
k := "foo"
}
`)

c.Modules["test2"] = MustParseModule(`package test
Expand Down Expand Up @@ -1163,6 +1176,10 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) {
head_array_comprehensions = [[__local21__] | __local21__ = 1]
head_set_comprehensions = {[__local22__] | __local22__ = 1}
head_object_comprehensions = {__local23__: [__local24__] | __local23__ = "foo"; __local24__ = 1}
rewritten_object_key = true { __local25__ = "foo"; {__local25__: 1} }
rewritten_object_key_head[[{__local26__: 1}]] { __local26__ = "foo" }
rewritten_object_key_head_value = [{__local27__: 1}] { __local27__ = "foo" }
`)

if len(module1.Rules) != len(expectedModule.Rules) {
Expand Down Expand Up @@ -2209,6 +2226,7 @@ func TestQueryCompilerRewrittenVars(t *testing.T) {
vars map[string]string
}{
{"assign", "a := 1", map[string]string{"__local0__": "a"}},
{"suppress only seen", "b = 1; a := b", map[string]string{"__local0__": "a"}},
}
for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
Expand Down

0 comments on commit b551669

Please sign in to comment.