Skip to content

Commit

Permalink
ast+rego: disable compiler stages for IR-based eval paths (#6335)
Browse files Browse the repository at this point in the history
Only topdown can make sense of rules and comprehension indices, so Wasm and any
eval plugins should instruct the compiler to avoid that work.

Signed-off-by: Stephan Renatus <[email protected]>
  • Loading branch information
srenatus authored Oct 24, 2023
1 parent c76d5d6 commit 544fd03
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 19 deletions.
68 changes: 49 additions & 19 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,10 @@ type Compiler struct {
// with the key being the generated name and value being the original.
RewrittenVars map[Var]Var

localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
stages []struct {
name string
metricName string
f func()
}
localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
stages []stage
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
Expand All @@ -145,11 +141,27 @@ type Compiler struct {
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
evalMode CompilerEvalMode
}

// CompilerStage defines the interface for stages in the compiler.
type CompilerStage func(*Compiler) *Error

// CompilerEvalMode allows toggling certain stages that are only
// needed for certain modes, Concretely, only "topdown" mode will
// have the compiler build comprehension and rule indices.
type CompilerEvalMode int

const (
// EvalModeTopdown (default) instructs the compiler to build rule
// and comprehension indices used by topdown evaluation.
EvalModeTopdown CompilerEvalMode = iota

// EvalModeIR makes the compiler skip the stages for comprehension
// and rule indices.
EvalModeIR
)

// CompilerStageDefinition defines a compiler stage
type CompilerStageDefinition struct {
Name string
Expand Down Expand Up @@ -266,6 +278,12 @@ type QueryCompilerStageDefinition struct {
Stage QueryCompilerStage
}

type stage struct {
name string
metricName string
f func()
}

// NewCompiler returns a new empty compiler.
func NewCompiler() *Compiler {

Expand All @@ -289,11 +307,7 @@ func NewCompiler() *Compiler {
c.ModuleTree = NewModuleTree(nil)
c.RuleTree = NewRuleTree(c.ModuleTree)

c.stages = []struct {
name string
metricName string
f func()
}{
c.stages = []stage{
// Reference resolution should run first as it may be used to lazily
// load additional modules. If any stages run before resolution, they
// need to be re-run after resolution.
Expand Down Expand Up @@ -436,6 +450,12 @@ func (c *Compiler) WithUseTypeCheckAnnotations(enabled bool) *Compiler {
return c
}

// WithEvalMode allows setting the CompilerEvalMode of the compiler
func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler {
c.evalMode = e
return c
}

// ParsedModules returns the parsed, unprocessed modules from the compiler.
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
// The map includes all modules loaded via the ModuleLoader, if one was used.
Expand Down Expand Up @@ -1469,6 +1489,12 @@ func (c *Compiler) compile() {
}()

for _, s := range c.stages {
if c.evalMode == EvalModeIR {
switch s.name {
case "BuildRuleIndices", "BuildComprehensionIndices":
continue // skip these stages
}
}
c.runStage(s.metricName, s.f)
if c.Failed() {
return
Expand Down Expand Up @@ -2620,18 +2646,20 @@ func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCom
return s(qc, query)
}

type queryStage = struct {
name string
metricName string
f func(*QueryContext, Body) (Body, error)
}

func (qc *queryCompiler) Compile(query Body) (Body, error) {
if len(query) == 0 {
return nil, Errors{NewError(CompileErr, nil, "empty query cannot be compiled")}
}

query = query.Copy()

stages := []struct {
name string
metricName string
f func(*QueryContext, Body) (Body, error)
}{
stages := []queryStage{
{"CheckKeywordOverrides", "query_compile_stage_check_keyword_overrides", qc.checkKeywordOverrides},
{"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs},
{"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars},
Expand All @@ -2646,7 +2674,9 @@ func (qc *queryCompiler) Compile(query Body) (Body, error) {
{"CheckTypes", "query_compile_stage_check_types", qc.checkTypes},
{"CheckUnsafeBuiltins", "query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins},
{"CheckDeprecatedBuiltins", "query_compile_stage_check_deprecated_builtins", qc.checkDeprecatedBuiltins},
{"BuildComprehensionIndex", "query_compile_stage_build_comprehension_index", qc.buildComprehensionIndices},
}
if qc.compiler.evalMode == EvalModeTopdown {
stages = append(stages, queryStage{"BuildComprehensionIndex", "query_compile_stage_build_comprehension_index", qc.buildComprehensionIndices})
}

qctx := qc.qctx.Copy()
Expand Down
11 changes: 11 additions & 0 deletions rego/rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,12 @@ func New(options ...func(r *Rego)) *Rego {
WithEnablePrintStatements(r.enablePrintStatements).
WithStrict(r.strict).
WithUseTypeCheckAnnotations(true)

// topdown could be target "" or "rego", but both could be overridden by
// a target plugin (checked below)
if r.target == targetWasm {
r.compiler = r.compiler.WithEvalMode(ast.EvalModeIR)
}
}

if r.store == nil {
Expand Down Expand Up @@ -1252,6 +1258,11 @@ func New(options ...func(r *Rego)) *Rego {
}
}
}

if t := r.targetPlugin(r.target); t != nil {
r.compiler = r.compiler.WithEvalMode(ast.EvalModeIR)
}

return r
}

Expand Down
32 changes: 32 additions & 0 deletions rego/rego_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,38 @@ func TestRegoInstrumentExtraPartialCompilerStage(t *testing.T) {
}
}

func TestRegoTargetWasmAndTargetPluginDisablesIndexingTopdownStages(t *testing.T) {
tp := testPlugin{}
RegisterPlugin("rego.target.foo", &tp)
t.Cleanup(resetPlugins)

for _, tgt := range []string{"wasm", "foo"} {
t.Run(tgt, func(t *testing.T) {
m := metrics.New()
r := New(Query("foo = 1"), Module("foo.rego", "package x"), Metrics(m), Instrument(true), Target(tgt))
ctx := context.Background()
_, err := r.Eval(ctx)
if err != nil {
t.Fatal(err)
}

expAbsent := []string{
"timer_query_compile_stage_build_comprehension_index_ns",
"timer_compile_stage_rebuild_comprehension_indices_ns",
"timer_compile_stage_rebuild_indices_ns",
}

all := m.All()

for _, name := range expAbsent {
if _, ok := all[name]; ok {
t.Errorf("Expected NOT to find %v but did", name)
}
}
})
}
}

func TestRegoInstrumentExtraPartialResultCompilerStage(t *testing.T) {
m := metrics.New()
r := New(Query("input.x"), Module("foo.rego", "package x"), Metrics(m), Instrument(true))
Expand Down

0 comments on commit 544fd03

Please sign in to comment.