diff --git a/ast/parser_ext.go b/ast/parser_ext.go index d20bb8a51c..9151fa378c 100644 --- a/ast/parser_ext.go +++ b/ast/parser_ext.go @@ -142,6 +142,10 @@ func ParseRuleFromExpr(module *Module, expr *Expr) (*Rule, error) { return nil, fmt.Errorf("negated expressions cannot be used for rule head") } + if _, ok := expr.Terms.(*SomeDecl); ok { + return nil, errors.New("some declarations cannot be used for rule head") + } + if term, ok := expr.Terms.(*Term); ok { switch v := term.Value.(type) { case Ref: @@ -151,6 +155,12 @@ func ParseRuleFromExpr(module *Module, expr *Expr) (*Rule, error) { } } + if _, ok := expr.Terms.([]*Term); !ok { + // This is a defensive check in case other kinds of expression terms are + // introduced in the future. + return nil, errors.New("expression cannot be used for rule head") + } + if expr.IsAssignment() { lhs, rhs := expr.Operand(0), expr.Operand(1) diff --git a/ast/parser_test.go b/ast/parser_test.go index 4b83c90ae7..251f1cb170 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -1427,6 +1427,11 @@ data = {"bar": 2} { true }` "foo" := 1` + someDecl := ` + package a + + some x` + assertParseModuleError(t, "multiple expressions", multipleExprs) assertParseModuleError(t, "non-equality", nonEquality) assertParseModuleError(t, "non-var name", nonVarName) @@ -1437,6 +1442,13 @@ data = {"bar": 2} { true }` assertParseModuleError(t, "non ref term", nonRefTerm) assertParseModuleError(t, "zero args", zeroArgs) assertParseModuleError(t, "assign to term", assignToTerm) + assertParseModuleError(t, "some decl", someDecl) + + if _, err := ParseRuleFromExpr(&Module{}, &Expr{ + Terms: struct{}{}, + }); err == nil { + t.Fatal("expected error for unknown expression term type") + } } func TestWildcards(t *testing.T) {