Skip to content

Commit

Permalink
More stringent ast.ref_to_string (#1106)
Browse files Browse the repository at this point in the history
While I started looking into #1104 I encountered some cases where
`ast.ref_to_string`, and it's `static` counterpart would return wrong
representations. We already had a bit of a homegrown format, which was
convenient for some things, but would also leak out in places where
it shouldn't be. Now use the same format as OPA for representing refs
as strings.

Also added the query in error messages where we fail to prepare or
eval in a couple of places, as the lack of those made debugging this
issue much harder than it had to be.

Signed-off-by: Anders Eknert <[email protected]>
  • Loading branch information
anderseknert authored Sep 12, 2024
1 parent 822e42c commit df31dda
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 28 deletions.
26 changes: 17 additions & 9 deletions bundle/regal/ast/ast.rego
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ is_constant(value) if {
not has_term_var(value.value)
}

default builtin_names := set()

builtin_names := object.keys(config.capabilities.builtins)

builtin_namespaces contains namespace if {
Expand Down Expand Up @@ -183,28 +185,34 @@ _exclude_arg("assign", 0, _)

# METADATA
# description: returns the "path" string of any given ref value
ref_to_string(ref) := concat(".", [_ref_part_to_string(i, part) | some i, part in ref])
ref_to_string(ref) := concat("", [_ref_part_to_string(i, part) | some i, part in ref])

_ref_part_to_string(0, ref) := ref.value
_ref_part_to_string(0, part) := part.value

_ref_part_to_string(_, ref) := ref.value if ref.type == "string"
_ref_part_to_string(i, part) := _format_part(part) if i > 0

_ref_part_to_string(i, ref) := concat("", ["$", ref.value]) if {
ref.type != "string"
i > 0
}
_format_part(part) := sprintf(".%s", [part.value]) if {
part.type == "string"
regex.match(`^[a-zA-Z_][a-zA-Z1-9_]*$`, part.value)
} else := sprintf(`["%v"]`, [part.value]) if {
part.type == "string"
} else := sprintf(`[%v]`, [part.value])

# METADATA
# description: |
# returns the string representation of a ref up until its first
# non-static (i.e. variable) value, if any:
# foo.bar -> foo.bar
# foo.bar[baz] -> foo.bar
ref_static_to_string(ref) := ss if {
ref_static_to_string(ref) := str if {
rs := ref_to_string(ref)
ss := substring(rs, 0, indexof(rs, ".$"))
str := _trim_from_var(rs, regex.find_n(`\[[^"]`, rs, 1))
}

_trim_from_var(ref_str, vars) := ref_str if {
count(vars) == 0
} else := substring(ref_str, 0, indexof(ref_str, vars[0]))

static_ref(ref) if every t in array.slice(ref.value, 1, count(ref.value)) {
t.type != "var"
}
Expand Down
51 changes: 51 additions & 0 deletions bundle/regal/ast/ast_test.rego
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,54 @@ test_function_calls if {
test_implicit_boolean_assignment if {
ast.implicit_boolean_assignment(ast.with_rego_v1(`a.b if true`).rules[0])
}

test_ref_to_string if {
ast.ref_to_string([{"type": "var", "value": "data"}]) == `data`
ast.ref_to_string([{"type": "var", "value": "foo"}, {"type": "var", "value": "bar"}]) == `foo[bar]`
ast.ref_to_string([{"type": "var", "value": "data"}, {"type": "string", "value": "/foo/"}]) == `data["/foo/"]`
ast.ref_to_string([
{"type": "var", "value": "foo"},
{"type": "var", "value": "bar"},
{"type": "var", "value": "baz"},
]) == `foo[bar][baz]`
ast.ref_to_string([
{"type": "var", "value": "foo"},
{"type": "var", "value": "bar"},
{"type": "var", "value": "baz"},
{"type": "string", "value": "qux"},
]) == `foo[bar][baz].qux`
ast.ref_to_string([
{"type": "var", "value": "foo"},
{"type": "string", "value": "~bar~"},
{"type": "string", "value": "boo"},
{"type": "var", "value": "baz"},
]) == `foo["~bar~"].boo[baz]`
ast.ref_to_string([
{"type": "var", "value": "data"},
{"type": "string", "value": "regal"},
{"type": "string", "value": "lsp"},
{"type": "string", "value": "completion_test"},
]) == `data.regal.lsp.completion_test`
}

test_ref_static_to_string if {
ast.ref_static_to_string([{"type": "var", "value": "data"}]) == `data`
ast.ref_static_to_string([{"type": "var", "value": "foo"}, {"type": "var", "value": "bar"}]) == `foo`
ast.ref_static_to_string([{"type": "var", "value": "data"}, {"type": "string", "value": "/foo/"}]) == `data["/foo/"]`
ast.ref_static_to_string([
{"type": "var", "value": "foo"},
{"type": "string", "value": "bar"},
{"type": "var", "value": "baz"},
]) == `foo.bar`
ast.ref_static_to_string([
{"type": "var", "value": "foo"},
{"type": "string", "value": "~bar~"},
{"type": "string", "value": "qux"},
]) == `foo["~bar~"].qux`
ast.ref_static_to_string([
{"type": "var", "value": "data"},
{"type": "string", "value": "regal"},
{"type": "string", "value": "lsp"},
{"type": "string", "value": "completion_test"},
]) == `data.regal.lsp.completion_test`
}
1 change: 0 additions & 1 deletion bundle/regal/ast/rule_head_locations_test.rego
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ ref_rule[foo] := true if {
some foo in [1,2,3]
}
`

result := ast.rule_head_locations with input as regal.parse_module("p.rego", policy)

result == {
Expand Down
2 changes: 1 addition & 1 deletion bundle/regal/lsp/completion/ref_names.rego
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import data.regal.ast
# ref_names returns a list of ref names that are used in the module.
# built-in functions are not included as they are provided by another completions provider.
ref_names contains name if {
name := ast.ref_to_string(ast.found.refs[_][_].value)
name := ast.ref_static_to_string(ast.found.refs[_][_].value)

not name in ast.builtin_names
}
Expand Down
4 changes: 2 additions & 2 deletions bundle/regal/lsp/completion/ref_names_test.rego
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ test_ref_names if {
ref_names == {
"imp",
"bb",
"input.foo.$x",
"data.bar.$x",
"input.foo",
"data.bar",
"imp.foo",
"data.x",
}
Expand Down
2 changes: 1 addition & 1 deletion internal/lsp/completions/providers/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func prepareQuery(store storage.Store, query string) (*rego.PreparedEvalQuery, e
// and how to present it if enabled.
pq, err := rego.New(regoArgs...).PrepareForEval(context.Background())
if err != nil {
return nil, fmt.Errorf("failed preparing query: %w", err)
return nil, fmt.Errorf("failed preparing query: %s, %w", query, err)
}

err = store.Commit(context.Background(), txn)
Expand Down
18 changes: 11 additions & 7 deletions internal/lsp/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (l *LanguageServer) Eval(

pq, err := rego.New(regoArgs...).PrepareForEval(ctx)
if err != nil {
return nil, fmt.Errorf("failed preparing query: %w", err)
return nil, fmt.Errorf("failed preparing query %s: %w", query, err)
}

if input != nil {
Expand All @@ -88,9 +88,9 @@ func (l *LanguageServer) Eval(
}

type EvalPathResult struct {
Value any `json:"value"`
IsUndefined bool `json:"isUndefined"`
PrintOutput map[int][]string `json:"printOutput"`
Value any `json:"value"`
IsUndefined bool `json:"isUndefined"`
PrintOutput map[string]map[int][]string `json:"printOutput"`
}

func (l *LanguageServer) EvalWorkspacePath(
Expand All @@ -100,7 +100,7 @@ func (l *LanguageServer) EvalWorkspacePath(
) (EvalPathResult, error) {
resultQuery := "result := " + query

hook := PrintHook{Output: make(map[int][]string)}
hook := PrintHook{Output: make(map[string]map[int][]string)}

var bs map[string]bundle.Bundle
if l.bundleCache != nil {
Expand Down Expand Up @@ -142,11 +142,15 @@ func prepareRegoArgs(query ast.Body, bundles map[string]bundle.Bundle, printHook
}

type PrintHook struct {
Output map[int][]string
Output map[string]map[int][]string
}

func (h PrintHook) Print(ctx print.Context, msg string) error {
h.Output[ctx.Location.Row] = append(h.Output[ctx.Location.Row], msg)
if _, ok := h.Output[ctx.Location.File]; !ok {
h.Output[ctx.Location.File] = make(map[int][]string)
}

h.Output[ctx.Location.File][ctx.Location.Row] = append(h.Output[ctx.Location.File][ctx.Location.Row], msg)

return nil
}
2 changes: 1 addition & 1 deletion internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ func (l *LanguageServer) StartCommandWorker(ctx context.Context) { // nolint:mai

responseParams := map[string]any{
"type": "opa-debug",
"name": "Debug " + path,
"name": path,
"request": "launch",
"command": "eval",
"query": path,
Expand Down
14 changes: 8 additions & 6 deletions pkg/linter/linter.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,19 +363,21 @@ func (l Linter) DetermineEnabledRules(ctx context.Context) ([]string, error) {
enabledRules = append(enabledRules, rule.Name())
}

query := ast.MustParseBody(`[rule|
data.regal.rules[cat][rule]
data.regal.config.for_rule(cat, rule).level != "ignore"
]`)
queryStr := `[rule |
data.regal.rules[cat][rule]
data.regal.config.for_rule(cat, rule).level != "ignore"
]`

query := ast.MustParseBody(queryStr)

regoArgs, err := l.prepareRegoArgs(query)
if err != nil {
return nil, fmt.Errorf("failed preparing query: %w", err)
return nil, fmt.Errorf("failed preparing query %s: %w", queryStr, err)
}

rs, err := rego.New(regoArgs...).Eval(ctx)
if err != nil {
return nil, fmt.Errorf("failed preparing query: %w", err)
return nil, fmt.Errorf("failed evaluating query %s: %w", queryStr, err)
}

if len(rs) != 1 || len(rs[0].Expressions) != 1 {
Expand Down

0 comments on commit df31dda

Please sign in to comment.