diff --git a/ast/compile.go b/ast/compile.go index df4695331d..bb86d97011 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -4663,6 +4663,17 @@ func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, t return true, errs } return false, errs + case Call: + ref := v[0] + WalkVars(ref, func(v Var) bool { + if gv, ok := stack.Declared(v); ok && !gv.Equal(v) { + // We will rewrite the ref of a function call, which is never ok since we don't have first-class functions. + errs = append(errs, NewError(CompileErr, term.Location, "called function %s shadowed", ref)) + return true + } + return false + }) + return false, errs case *object: cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) { kcpy := k.Copy() diff --git a/ast/compile_test.go b/ast/compile_test.go index d00dfffb9c..7456fddde5 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -1515,6 +1515,118 @@ func TestCompilerRewriteExprTerms(t *testing.T) { } } +func TestIllegalFunctionCallRewrite(t *testing.T) { + cases := []struct { + note string + module string + expectedErrors []string + }{ + /*{ + note: "function call override in function value", + module: `package test + foo(x) := x + + p := foo(bar) { + #foo := 1 + bar := 2 + }`, + expectedErrors: []string{ + "undefined function foo", + }, + },*/ + { + note: "function call override in array comprehension value", + module: `package test +p := [foo(bar) | foo := 1; bar := 2]`, + expectedErrors: []string{ + "called function foo shadowed", + }, + }, + { + note: "function call override in set comprehension value", + module: `package test +p := {foo(bar) | foo := 1; bar := 2}`, + expectedErrors: []string{ + "called function foo shadowed", + }, + }, + { + note: "function call override in object comprehension value", + module: `package test +p := {foo(bar): bar(foo) | foo := 1; bar := 2}`, + expectedErrors: []string{ + "called function bar shadowed", + "called function foo shadowed", + }, + }, + { + note: "function call override in array comprehension value", + module: `package test +p := [foo.bar(baz) | foo := 1; bar := 2; baz := 3]`, + expectedErrors: []string{ + "called function foo.bar shadowed", + }, + }, + { + note: "nested function call override in array comprehension value", + module: `package test +p := [baz(foo(bar)) | foo := 1; bar := 2]`, + expectedErrors: []string{ + "called function foo shadowed", + }, + }, + { + note: "function call override of 'input' root document", + module: `package test +p := [input() | input := 1]`, + expectedErrors: []string{ + "called function input shadowed", + }, + }, + { + note: "function call override of 'data' root document", + module: `package test +p := [data() | data := 1]`, + expectedErrors: []string{ + "called function data shadowed", + }, + }, + } + + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + compiler := NewCompiler() + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + + compiler.Modules = map[string]*Module{ + "test": MustParseModuleWithOpts(tc.module, opts), + } + compileStages(compiler, compiler.rewriteLocalVars) + + result := make([]string, 0, len(compiler.Errors)) + for i := range compiler.Errors { + result = append(result, compiler.Errors[i].Message) + } + + sort.Strings(tc.expectedErrors) + sort.Strings(result) + + if len(tc.expectedErrors) != len(result) { + t.Fatalf("Expected %d errors but got %d:\n\n%v\n\nGot:\n\n%v", + len(tc.expectedErrors), len(result), + strings.Join(tc.expectedErrors, "\n"), strings.Join(result, "\n")) + } + + for i := range result { + if result[i] != tc.expectedErrors[i] { + t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", + strings.Join(tc.expectedErrors, "\n"), strings.Join(result, "\n")) + } + } + }) + } +} + func TestCompilerCheckUnusedImports(t *testing.T) { cases := []strictnessTestCase{ {