Skip to content

Commit

Permalink
ast: Disallow shadowing of called functions in comprehension heads
Browse files Browse the repository at this point in the history
If a local var in a comprehension overrides a function call in the comprehension head, it will never be possible to make that call.

Fixes: open-policy-agent#4762

Signed-off-by: Johan Fylling <[email protected]>
  • Loading branch information
johanfylling committed Jun 9, 2022
1 parent 4ea50cb commit 26c23c0
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 0 deletions.
11 changes: 11 additions & 0 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -4663,6 +4663,17 @@ func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, t
return true, errs
}
return false, errs
case Call:
ref := v[0]
WalkVars(ref, func(v Var) bool {
if gv, ok := stack.Declared(v); ok && !gv.Equal(v) {
// We will rewrite the ref of a function call, which is never ok since we don't have first-class functions.
errs = append(errs, NewError(CompileErr, term.Location, "called function %s shadowed", ref))
return true
}
return false
})
return false, errs
case *object:
cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
kcpy := k.Copy()
Expand Down
112 changes: 112 additions & 0 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1515,6 +1515,118 @@ func TestCompilerRewriteExprTerms(t *testing.T) {
}
}

func TestIllegalFunctionCallRewrite(t *testing.T) {
cases := []struct {
note string
module string
expectedErrors []string
}{
/*{
note: "function call override in function value",
module: `package test
foo(x) := x
p := foo(bar) {
#foo := 1
bar := 2
}`,
expectedErrors: []string{
"undefined function foo",
},
},*/
{
note: "function call override in array comprehension value",
module: `package test
p := [foo(bar) | foo := 1; bar := 2]`,
expectedErrors: []string{
"called function foo shadowed",
},
},
{
note: "function call override in set comprehension value",
module: `package test
p := {foo(bar) | foo := 1; bar := 2}`,
expectedErrors: []string{
"called function foo shadowed",
},
},
{
note: "function call override in object comprehension value",
module: `package test
p := {foo(bar): bar(foo) | foo := 1; bar := 2}`,
expectedErrors: []string{
"called function bar shadowed",
"called function foo shadowed",
},
},
{
note: "function call override in array comprehension value",
module: `package test
p := [foo.bar(baz) | foo := 1; bar := 2; baz := 3]`,
expectedErrors: []string{
"called function foo.bar shadowed",
},
},
{
note: "nested function call override in array comprehension value",
module: `package test
p := [baz(foo(bar)) | foo := 1; bar := 2]`,
expectedErrors: []string{
"called function foo shadowed",
},
},
{
note: "function call override of 'input' root document",
module: `package test
p := [input() | input := 1]`,
expectedErrors: []string{
"called function input shadowed",
},
},
{
note: "function call override of 'data' root document",
module: `package test
p := [data() | data := 1]`,
expectedErrors: []string{
"called function data shadowed",
},
},
}

for _, tc := range cases {
t.Run(tc.note, func(t *testing.T) {
compiler := NewCompiler()
opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true}

compiler.Modules = map[string]*Module{
"test": MustParseModuleWithOpts(tc.module, opts),
}
compileStages(compiler, compiler.rewriteLocalVars)

result := make([]string, 0, len(compiler.Errors))
for i := range compiler.Errors {
result = append(result, compiler.Errors[i].Message)
}

sort.Strings(tc.expectedErrors)
sort.Strings(result)

if len(tc.expectedErrors) != len(result) {
t.Fatalf("Expected %d errors but got %d:\n\n%v\n\nGot:\n\n%v",
len(tc.expectedErrors), len(result),
strings.Join(tc.expectedErrors, "\n"), strings.Join(result, "\n"))
}

for i := range result {
if result[i] != tc.expectedErrors[i] {
t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v",
strings.Join(tc.expectedErrors, "\n"), strings.Join(result, "\n"))
}
}
})
}
}

func TestCompilerCheckUnusedImports(t *testing.T) {
cases := []strictnessTestCase{
{
Expand Down

0 comments on commit 26c23c0

Please sign in to comment.