diff --git a/internal/ast/rule.go b/internal/ast/rule.go index 0b0a8835..94fa68ce 100644 --- a/internal/ast/rule.go +++ b/internal/ast/rule.go @@ -5,14 +5,12 @@ import ( "strings" "github.com/open-policy-agent/opa/ast" - - "github.com/styrainc/regal/internal/lsp/rego" ) // GetRuleDetail returns a short descriptive string value for a given rule stating // if the rule is constant, multi-value, single-value etc and the type of the rule's // value if known. -func GetRuleDetail(rule *ast.Rule) string { +func GetRuleDetail(rule *ast.Rule, builtins map[string]*ast.Builtin) string { if rule.Head.Args != nil { return "function" + rule.Head.Args.String() } @@ -53,9 +51,7 @@ func GetRuleDetail(rule *ast.Rule) string { case ast.Call: name := v[0].String() - bis := rego.GetBuiltins() - - if builtin, ok := bis[name]; ok { + if builtin, ok := builtins[name]; ok { retType := builtin.Decl.NamedResult().String() detail += fmt.Sprintf(" (%s)", simplifyType(retType)) diff --git a/internal/ast/rule_test.go b/internal/ast/rule_test.go index a950dc24..b328f0ac 100644 --- a/internal/ast/rule_test.go +++ b/internal/ast/rule_test.go @@ -3,6 +3,9 @@ package ast import ( "testing" + "github.com/open-policy-agent/opa/ast" + + "github.com/styrainc/regal/internal/lsp/rego" "github.com/styrainc/regal/internal/parse" ) @@ -47,7 +50,9 @@ func TestGetRuleDetail(t *testing.T) { rule := mod.Rules[0] - result := GetRuleDetail(rule) + bis := rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) + + result := GetRuleDetail(rule, bis) if result != tc.expected { t.Errorf("Expected %s, got %s", tc.expected, result) } diff --git a/internal/lsp/completions/providers/builtins.go b/internal/lsp/completions/providers/builtins.go index 259419e3..30973e51 100644 --- a/internal/lsp/completions/providers/builtins.go +++ b/internal/lsp/completions/providers/builtins.go @@ -6,7 +6,6 @@ import ( "github.com/styrainc/regal/internal/lsp/cache" "github.com/styrainc/regal/internal/lsp/hover" - "github.com/styrainc/regal/internal/lsp/rego" "github.com/styrainc/regal/internal/lsp/types" "github.com/styrainc/regal/internal/lsp/types/completion" ) @@ -21,7 +20,7 @@ func (*BuiltIns) Run( _ context.Context, c *cache.Cache, params types.CompletionParams, - _ *Options, + opts *Options, ) ([]types.CompletionItem, error) { fileURI := params.TextDocument.URI @@ -45,9 +44,7 @@ func (*BuiltIns) Run( items := []types.CompletionItem{} - bis := rego.GetBuiltins() - - for _, builtIn := range bis { + for _, builtIn := range opts.Builtins { key := builtIn.Name if builtIn.Infix != "" { diff --git a/internal/lsp/completions/providers/builtins_test.go b/internal/lsp/completions/providers/builtins_test.go index 7b406ef3..7b97ea98 100644 --- a/internal/lsp/completions/providers/builtins_test.go +++ b/internal/lsp/completions/providers/builtins_test.go @@ -6,7 +6,10 @@ import ( "strings" "testing" + "github.com/open-policy-agent/opa/ast" + "github.com/styrainc/regal/internal/lsp/cache" + "github.com/styrainc/regal/internal/lsp/rego" "github.com/styrainc/regal/internal/lsp/types" ) @@ -33,7 +36,9 @@ allow if c` }, } - completions, err := p.Run(context.Background(), c, completionParams, nil) + completions, err := p.Run(context.Background(), c, completionParams, &Options{ + Builtins: rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()), + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -68,7 +73,9 @@ allow := c` }, } - completions, err := p.Run(context.Background(), c, completionParams, nil) + completions, err := p.Run(context.Background(), c, completionParams, &Options{ + Builtins: rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()), + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -105,7 +112,9 @@ allow if { }, } - completions, err := p.Run(context.Background(), c, completionParams, nil) + completions, err := p.Run(context.Background(), c, completionParams, &Options{ + Builtins: rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()), + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -140,7 +149,9 @@ allow if gt` }, } - completions, err := p.Run(context.Background(), c, completionParams, nil) + completions, err := p.Run(context.Background(), c, completionParams, &Options{ + Builtins: rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()), + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -173,7 +184,9 @@ allow if c` }, } - completions, err := p.Run(context.Background(), c, completionParams, nil) + completions, err := p.Run(context.Background(), c, completionParams, &Options{ + Builtins: rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()), + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -208,7 +221,9 @@ default allow := f` }, } - completions, err := p.Run(context.Background(), c, completionParams, nil) + completions, err := p.Run(context.Background(), c, completionParams, &Options{ + Builtins: rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()), + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/internal/lsp/completions/providers/options.go b/internal/lsp/completions/providers/options.go index ad05c452..d2136125 100644 --- a/internal/lsp/completions/providers/options.go +++ b/internal/lsp/completions/providers/options.go @@ -1,10 +1,15 @@ package providers import ( + "github.com/open-policy-agent/opa/ast" + "github.com/styrainc/regal/internal/lsp/clients" ) type Options struct { ClientIdentifier clients.Identifier RootURI string + // Builtins is a map of built-in functions to their definitions required in + // the context of the current completion request. + Builtins map[string]*ast.Builtin } diff --git a/internal/lsp/completions/providers/packagerefs_test.go b/internal/lsp/completions/providers/packagerefs_test.go index f9db464c..7d32885b 100644 --- a/internal/lsp/completions/providers/packagerefs_test.go +++ b/internal/lsp/completions/providers/packagerefs_test.go @@ -7,8 +7,11 @@ import ( "strings" "testing" + "github.com/open-policy-agent/opa/ast" + "github.com/styrainc/regal/internal/lsp/cache" "github.com/styrainc/regal/internal/lsp/completions/refs" + "github.com/styrainc/regal/internal/lsp/rego" "github.com/styrainc/regal/internal/lsp/types" "github.com/styrainc/regal/internal/parse" ) @@ -42,11 +45,13 @@ import c.SetFileContents("file:///bar/file2.rego", fileContents) + builtins := rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) + for uri, contents := range regoFiles { mod := parse.MustParseModule(contents) c.SetModule(uri, mod) - c.SetFileRefs(uri, refs.DefinedInModule(mod)) + c.SetFileRefs(uri, refs.DefinedInModule(mod, builtins)) } p := &PackageRefs{} @@ -116,11 +121,13 @@ import c.SetFileContents("file:///file.rego", fileContents) + builtins := rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) + for uri, contents := range regoFiles { mod := parse.MustParseModule(contents) c.SetModule(uri, mod) - c.SetFileRefs(uri, refs.DefinedInModule(mod)) + c.SetFileRefs(uri, refs.DefinedInModule(mod, builtins)) } p := &PackageRefs{} diff --git a/internal/lsp/completions/providers/rulehead_test.go b/internal/lsp/completions/providers/rulehead_test.go index 0a2c3f3c..11394608 100644 --- a/internal/lsp/completions/providers/rulehead_test.go +++ b/internal/lsp/completions/providers/rulehead_test.go @@ -5,8 +5,11 @@ import ( "slices" "testing" + "github.com/open-policy-agent/opa/ast" + "github.com/styrainc/regal/internal/lsp/cache" "github.com/styrainc/regal/internal/lsp/completions/refs" + "github.com/styrainc/regal/internal/lsp/rego" "github.com/styrainc/regal/internal/lsp/types" "github.com/styrainc/regal/internal/parse" ) @@ -37,6 +40,8 @@ funckyfunc := true `, } + builtins := rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) + for uri, contents := range regoFiles { mod, err := parse.Module(uri, contents) if err != nil { @@ -45,7 +50,7 @@ funckyfunc := true c.SetFileContents(uri, contents) c.SetModule(uri, mod) - c.SetFileRefs(uri, refs.DefinedInModule(mod)) + c.SetFileRefs(uri, refs.DefinedInModule(mod, builtins)) } p := &RuleHead{} diff --git a/internal/lsp/completions/refs/defined.go b/internal/lsp/completions/refs/defined.go index 228b8b47..e5a8c1d4 100644 --- a/internal/lsp/completions/refs/defined.go +++ b/internal/lsp/completions/refs/defined.go @@ -15,7 +15,7 @@ import ( // DefinedInModule returns a map of refs and details about them to be used in completions that // were found in the given module. -func DefinedInModule(module *ast.Module) map[string]types.Ref { +func DefinedInModule(module *ast.Module, builtins map[string]*ast.Builtin) map[string]types.Ref { modKey := module.Package.Path.String() // first, create a reference for the package using the metadata @@ -92,7 +92,7 @@ func DefinedInModule(module *ast.Module) map[string]types.Ref { items[ruleKey] = types.Ref{ Kind: kind, Label: ruleKey, - Detail: rast.GetRuleDetail(rs[0]), + Detail: rast.GetRuleDetail(rs[0], builtins), Description: ruleDescription, } } diff --git a/internal/lsp/completions/refs/defined_test.go b/internal/lsp/completions/refs/defined_test.go index 8e40fafa..ddbac2bd 100644 --- a/internal/lsp/completions/refs/defined_test.go +++ b/internal/lsp/completions/refs/defined_test.go @@ -5,6 +5,9 @@ import ( "strings" "testing" + "github.com/open-policy-agent/opa/ast" + + "github.com/styrainc/regal/internal/lsp/rego" "github.com/styrainc/regal/internal/lsp/types" rparse "github.com/styrainc/regal/internal/parse" ) @@ -32,7 +35,9 @@ func TestForModule_Package(t *testing.T) { package example `) - items := DefinedInModule(mod) + bis := rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) + + items := DefinedInModule(mod, bis) expectedRefs := map[string]types.Ref{ "data.example": { @@ -121,8 +126,9 @@ deny contains "strings" if true pi := 3.14 `) + bis := rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) - items := DefinedInModule(mod) + items := DefinedInModule(mod, bis) expectedRefs := map[string]types.Ref{ "data.example": { diff --git a/internal/lsp/documentsymbol.go b/internal/lsp/documentsymbol.go index 79904e74..cc97caa4 100644 --- a/internal/lsp/documentsymbol.go +++ b/internal/lsp/documentsymbol.go @@ -15,6 +15,7 @@ import ( func documentSymbols( contents string, module *ast.Module, + builtins map[string]*ast.Builtin, ) []types.DocumentSymbol { // Only pkgSymbols would likely suffice, but we're keeping docSymbols around in case // we ever want to add more top-level symbols than the package. @@ -62,7 +63,7 @@ func documentSymbols( SelectionRange: ruleRange, } - if detail := rast.GetRuleDetail(rule); detail != "" { + if detail := rast.GetRuleDetail(rule, builtins); detail != "" { ruleSymbol.Detail = &detail } @@ -88,7 +89,7 @@ func documentSymbols( SelectionRange: groupRange, } - detail := rast.GetRuleDetail(rules[0]) + detail := rast.GetRuleDetail(rules[0], builtins) if detail != "" { groupSymbol.Detail = &detail } @@ -104,7 +105,7 @@ func documentSymbols( SelectionRange: childRange, } - childDetail := rast.GetRuleDetail(rule) + childDetail := rast.GetRuleDetail(rule, builtins) if childDetail != "" { childRule.Detail = &childDetail } diff --git a/internal/lsp/documentsymbol_test.go b/internal/lsp/documentsymbol_test.go index 2af904fe..1d291813 100644 --- a/internal/lsp/documentsymbol_test.go +++ b/internal/lsp/documentsymbol_test.go @@ -5,6 +5,7 @@ import ( "github.com/open-policy-agent/opa/ast" + "github.com/styrainc/regal/internal/lsp/rego" "github.com/styrainc/regal/internal/lsp/types" "github.com/styrainc/regal/internal/lsp/types/symbols" ) @@ -69,7 +70,9 @@ func TestDocumentSymbols(t *testing.T) { t.Fatal(err) } - syms := documentSymbols(tc.policy, module) + bis := rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) + + syms := documentSymbols(tc.policy, module, bis) pkg := syms[0] if pkg.Name != tc.expected.Name { diff --git a/internal/lsp/hover/hover.go b/internal/lsp/hover/hover.go index 6025866d..4d9d5be5 100644 --- a/internal/lsp/hover/hover.go +++ b/internal/lsp/hover/hover.go @@ -135,7 +135,7 @@ func CreateHoverContent(builtin *ast.Builtin) string { return result } -func UpdateBuiltinPositions(cache *cache.Cache, uri string) error { +func UpdateBuiltinPositions(cache *cache.Cache, uri string, builtins map[string]*ast.Builtin) error { module, ok := cache.GetModule(uri) if !ok { return fmt.Errorf("failed to update builtin positions: no parsed module for uri %q", uri) @@ -143,7 +143,7 @@ func UpdateBuiltinPositions(cache *cache.Cache, uri string) error { builtinsOnLine := map[uint][]types2.BuiltinPosition{} - for _, call := range rego.AllBuiltinCalls(module) { + for _, call := range rego.AllBuiltinCalls(module, builtins) { line := uint(call.Location.Row) builtinsOnLine[line] = append(builtinsOnLine[line], types2.BuiltinPosition{ diff --git a/internal/lsp/inlayhint.go b/internal/lsp/inlayhint.go index 1ac7a26d..c64d6054 100644 --- a/internal/lsp/inlayhint.go +++ b/internal/lsp/inlayhint.go @@ -18,10 +18,10 @@ func createInlayTooltip(named *types.NamedType) string { return fmt.Sprintf("%s\n\nType: `%s`", named.Descr, named.Type.String()) } -func getInlayHints(module *ast.Module) []types2.InlayHint { +func getInlayHints(module *ast.Module, builtins map[string]*ast.Builtin) []types2.InlayHint { inlayHints := make([]types2.InlayHint, 0) - for _, call := range rego.AllBuiltinCalls(module) { + for _, call := range rego.AllBuiltinCalls(module, builtins) { for i, arg := range call.Builtin.Decl.NamedFuncArgs().Args { if len(call.Args) <= i { // avoid panic if provided a builtin function where the args diff --git a/internal/lsp/inlayhint_test.go b/internal/lsp/inlayhint_test.go index 5a89716b..6d0243e7 100644 --- a/internal/lsp/inlayhint_test.go +++ b/internal/lsp/inlayhint_test.go @@ -4,6 +4,8 @@ import ( "testing" "github.com/open-policy-agent/opa/ast" + + "github.com/styrainc/regal/internal/lsp/rego" ) // A function call may either be represented as an ast.Call. @@ -16,7 +18,8 @@ func TestGetInlayHintsAstCall(t *testing.T) { module := ast.MustParseModule(policy) - inlayHints := getInlayHints(module) + bis := rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) + inlayHints := getInlayHints(module, bis) if len(inlayHints) != 2 { t.Fatalf("Expected 2 inlay hints, got %d", len(inlayHints)) @@ -65,7 +68,9 @@ func TestGetInlayHintsAstTerms(t *testing.T) { module := ast.MustParseModule(policy) - inlayHints := getInlayHints(module) + bis := rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) + + inlayHints := getInlayHints(module, bis) if len(inlayHints) != 1 { t.Fatalf("Expected 1 inlay hints, got %d", len(inlayHints)) diff --git a/internal/lsp/lint.go b/internal/lsp/lint.go index ab975764..0be14d21 100644 --- a/internal/lsp/lint.go +++ b/internal/lsp/lint.go @@ -24,7 +24,13 @@ import ( // updateParse updates the module cache with the latest parse result for a given URI, // if the module cannot be parsed, the parse errors are saved as diagnostics for the // URI instead. -func updateParse(ctx context.Context, cache *cache.Cache, store storage.Store, fileURI string) (bool, error) { +func updateParse( + ctx context.Context, + cache *cache.Cache, + store storage.Store, + fileURI string, + builtins map[string]*ast.Builtin, +) (bool, error) { content, ok := cache.GetFileContents(fileURI) if !ok { return false, fmt.Errorf("failed to get file contents for uri %q", fileURI) @@ -44,7 +50,7 @@ func updateParse(ctx context.Context, cache *cache.Cache, store storage.Store, f return false, fmt.Errorf("failed to update rego store with parsed module: %w", err) } - definedRefs := refs.DefinedInModule(module) + definedRefs := refs.DefinedInModule(module, builtins) cache.SetFileRefs(fileURI, definedRefs) diff --git a/internal/lsp/rego/builtins.go b/internal/lsp/rego/builtins.go index 7f5f88bb..262ebe0d 100644 --- a/internal/lsp/rego/builtins.go +++ b/internal/lsp/rego/builtins.go @@ -1,35 +1,15 @@ package rego import ( - "maps" "strings" - "sync" "github.com/open-policy-agent/opa/ast" ) -var ( - builtInsLock = &sync.RWMutex{} //nolint:gochecknoglobals - builtIns = builtinMap(ast.CapabilitiesForThisVersion()) //nolint:gochecknoglobals -) - -// Update updates the builtins database with the provided capabilities. -func UpdateBuiltins(caps *ast.Capabilities) { - builtInsLock.Lock() - builtIns = builtinMap(caps) - builtInsLock.Unlock() -} - -func GetBuiltins() map[string]*ast.Builtin { - builtInsLock.RLock() - defer builtInsLock.RUnlock() - - return maps.Clone(builtIns) -} - -func builtinMap(caps *ast.Capabilities) map[string]*ast.Builtin { +// BuiltinsForCapabilities returns a list of builtins from the provided capabilities. +func BuiltinsForCapabilities(capabilities *ast.Capabilities) map[string]*ast.Builtin { m := make(map[string]*ast.Builtin) - for _, b := range caps.Builtins { + for _, b := range capabilities.Builtins { m[b.Name] = b } diff --git a/internal/lsp/rego/rego.go b/internal/lsp/rego/rego.go index 819d2fb7..859d951c 100644 --- a/internal/lsp/rego/rego.go +++ b/internal/lsp/rego/rego.go @@ -53,11 +53,9 @@ func LocationFromPosition(pos types.Position) *ast.Location { // AllBuiltinCalls returns all built-in calls in the module, excluding operators // and any other function identified by an infix. -func AllBuiltinCalls(module *ast.Module) []BuiltInCall { +func AllBuiltinCalls(module *ast.Module, builtins map[string]*ast.Builtin) []BuiltInCall { builtinCalls := make([]BuiltInCall, 0) - bis := GetBuiltins() - callVisitor := ast.NewGenericVisitor(func(x interface{}) bool { var terms []*ast.Term @@ -76,7 +74,7 @@ func AllBuiltinCalls(module *ast.Module) []BuiltInCall { return false } - if b, ok := bis[terms[0].Value.String()]; ok { + if b, ok := builtins[terms[0].Value.String()]; ok { // Exclude operators and similar builtins if b.Infix != "" { return false diff --git a/internal/lsp/server.go b/internal/lsp/server.go index b29b5974..151ad965 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -83,6 +83,7 @@ func NewLanguageServer(opts *LanguageServerOptions) *LanguageServer { configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{ErrorWriter: opts.ErrorLog}), completionsManager: completions.NewDefaultManager(c, store), webServer: web.NewServer(c), + loadedBuiltins: make(map[string]map[string]*ast.Builtin), } return ls @@ -93,9 +94,11 @@ type LanguageServer struct { errorLog io.Writer - configWatcher *lsconfig.Watcher - loadedConfig *config.Config - loadedConfigLock sync.Mutex + configWatcher *lsconfig.Watcher + loadedConfig *config.Config + loadedConfigLock sync.Mutex + loadedBuiltins map[string]map[string]*ast.Builtin + loadedBuiltinsLock sync.RWMutex workspaceRootURI string clientIdentifier clients.Identifier @@ -215,9 +218,11 @@ func (l *LanguageServer) StartDiagnosticsWorker(ctx context.Context) { case <-ctx.Done(): return case evt := <-l.diagnosticRequestFile: + bis := l.builtinsForCurrentCapabilities() + // updateParse will not return an error when the parsing failed, // but only when it was impossible - _, err := updateParse(ctx, l.cache, l.regoStore, evt.URI) + _, err := updateParse(ctx, l.cache, l.regoStore, evt.URI, bis) if err != nil { l.logError(fmt.Errorf("failed to update module for %s: %w", evt.URI, err)) @@ -282,7 +287,9 @@ func (l *LanguageServer) StartHoverWorker(ctx context.Context) { continue } - success, err := updateParse(ctx, l.cache, l.regoStore, fileURI) + bis := l.builtinsForCurrentCapabilities() + + success, err := updateParse(ctx, l.cache, l.regoStore, fileURI, bis) if err != nil { l.logError(fmt.Errorf("failed to update parse: %w", err)) @@ -293,7 +300,7 @@ func (l *LanguageServer) StartHoverWorker(ctx context.Context) { continue } - err = hover.UpdateBuiltinPositions(l.cache, fileURI) + err = hover.UpdateBuiltinPositions(l.cache, fileURI, bis) if err != nil { l.logError(fmt.Errorf("failed to update builtin positions: %w", err)) @@ -337,7 +344,7 @@ func (l *LanguageServer) StartConfigWorker(ctx context.Context) { case <-ctx.Done(): return case path := <-l.configWatcher.Reload: - configFile, err := os.Open(path) + configFileBs, err := os.ReadFile(path) if err != nil { l.logError(fmt.Errorf("failed to open config file: %w", err)) @@ -346,7 +353,7 @@ func (l *LanguageServer) StartConfigWorker(ctx context.Context) { var userConfig config.Config - err = yaml.NewDecoder(configFile).Decode(&userConfig) + err = yaml.Unmarshal(configFileBs, &userConfig) if err != nil && !errors.Is(err, io.EOF) { l.logError(fmt.Errorf("failed to reload config: %w", err)) @@ -371,23 +378,21 @@ func (l *LanguageServer) StartConfigWorker(ctx context.Context) { l.loadedConfigLock.Unlock() - // Capabilities URL may have changed, so we should - // reload it. + // Capabilities URL may have changed, so we should reload it. capsURL := l.getLoadedConfig().CapabilitiesURL - if capsURL == "" { - // This can happen if we have an empty config. - capsURL = "regal:///capabilities/default" - } - caps, err := capabilities.Lookup(ctx, capsURL) if err != nil { - l.logError(fmt.Errorf("failed to lookup capabilities: %w", err)) + l.logError(fmt.Errorf("failed to load capabilities for URL %q: %w", capsURL, err)) return } - rego.UpdateBuiltins(caps) + bis := rego.BuiltinsForCapabilities(caps) + + l.loadedBuiltinsLock.Lock() + l.loadedBuiltins[capsURL] = bis + l.loadedBuiltinsLock.Unlock() // the config may now ignore files that existed in the cache before, // in which case we need to remove them to stop their contents being @@ -426,7 +431,7 @@ func (l *LanguageServer) StartConfigWorker(ctx context.Context) { // updating the parse here will enable things like go-to definition // to start working right away without the need for a file content // update to run updateParse. - _, err = updateParse(ctx, l.cache, l.regoStore, k) + _, err = updateParse(ctx, l.cache, l.regoStore, k, bis) if err != nil { l.logError(fmt.Errorf("failed to update parse for previously ignored file %q: %w", k, err)) } @@ -1133,7 +1138,9 @@ func (l *LanguageServer) processHoverContentUpdate(ctx context.Context, fileURI l.cache.SetFileContents(fileURI, content) - success, err := updateParse(ctx, l.cache, l.regoStore, fileURI) + bis := l.builtinsForCurrentCapabilities() + + success, err := updateParse(ctx, l.cache, l.regoStore, fileURI, bis) if err != nil { return fmt.Errorf("failed to update parse: %w", err) } @@ -1142,7 +1149,7 @@ func (l *LanguageServer) processHoverContentUpdate(ctx context.Context, fileURI return nil } - err = hover.UpdateBuiltinPositions(l.cache, fileURI) + err = hover.UpdateBuiltinPositions(l.cache, fileURI, bis) if err != nil { return fmt.Errorf("failed to update builtin positions: %w", err) } @@ -1429,6 +1436,8 @@ func (l *LanguageServer) handleTextDocumentInlayHint( return []types.InlayHint{}, nil } + bis := l.builtinsForCurrentCapabilities() + // when a file cannot be parsed, we do a best effort attempt to provide inlay hints // by finding the location of the first parse error and attempting to parse up to that point parseErrors, ok := l.cache.GetParseErrors(params.TextDocument.URI) @@ -1439,7 +1448,7 @@ func (l *LanguageServer) handleTextDocumentInlayHint( return []types.InlayHint{}, nil } - return partialInlayHints(parseErrors, contents, params.TextDocument.URI), nil + return partialInlayHints(parseErrors, contents, params.TextDocument.URI, bis), nil } module, ok := l.cache.GetModule(params.TextDocument.URI) @@ -1449,7 +1458,7 @@ func (l *LanguageServer) handleTextDocumentInlayHint( return []types.InlayHint{}, nil } - inlayHints := getInlayHints(module) + inlayHints := getInlayHints(module, bis) return inlayHints, nil } @@ -1500,6 +1509,7 @@ func (l *LanguageServer) handleTextDocumentCompletion( items, err := l.completionsManager.Run(ctx, params, &providers.Options{ ClientIdentifier: l.clientIdentifier, RootURI: l.workspaceRootURI, + Builtins: l.builtinsForCurrentCapabilities(), }) if err != nil { return nil, fmt.Errorf("failed to find completions: %w", err) @@ -1519,7 +1529,12 @@ func (l *LanguageServer) handleTextDocumentCompletion( }, nil } -func partialInlayHints(parseErrors []types.Diagnostic, contents, fileURI string) []types.InlayHint { +func partialInlayHints( + parseErrors []types.Diagnostic, + contents, + fileURI string, + builtins map[string]*ast.Builtin, +) []types.InlayHint { firstErrorLine := uint(0) for _, parseError := range parseErrors { if parseError.Range.Start.Line > firstErrorLine { @@ -1547,7 +1562,7 @@ func partialInlayHints(parseErrors []types.Diagnostic, contents, fileURI string) return []types.InlayHint{} } - return getInlayHints(module) + return getInlayHints(module, builtins) } func (l *LanguageServer) handleWorkspaceSymbol( @@ -1568,9 +1583,11 @@ func (l *LanguageServer) handleWorkspaceSymbol( // But perhaps a good one to do at some point, and I'm not sure all clients // do this filtering. + bis := l.builtinsForCurrentCapabilities() + for moduleURL, module := range l.cache.GetAllModules() { content := contents[moduleURL] - docSyms := documentSymbols(content, module) + docSyms := documentSymbols(content, module, bis) wrkSyms := make([]types.WorkspaceSymbol, 0) toWorkspaceSymbols(docSyms, moduleURL, &wrkSyms) @@ -1811,7 +1828,9 @@ func (l *LanguageServer) handleTextDocumentDocumentSymbol( return []types.DocumentSymbol{}, nil } - return documentSymbols(contents, module), nil + bis := l.builtinsForCurrentCapabilities() + + return documentSymbols(contents, module, bis), nil } func (l *LanguageServer) handleTextDocumentFoldingRange( @@ -2289,7 +2308,9 @@ func (l *LanguageServer) loadWorkspaceContents(ctx context.Context, newOnly bool return nil } - _, err = updateParse(ctx, l.cache, l.regoStore, fileURI) + bis := l.builtinsForCurrentCapabilities() + + _, err = updateParse(ctx, l.cache, l.regoStore, fileURI, bis) if err != nil { return fmt.Errorf("failed to update parse: %w", err) } @@ -2429,6 +2450,26 @@ func (l *LanguageServer) workspacePath() string { return uri.ToPath(l.clientIdentifier, l.workspaceRootURI) } +// builtinsForCurrentCapabilities returns the map of builtins for use +// in the server based on the currently loaded capabilities. If there is no +// config, then the default for the Regal OPA version is used. +func (l *LanguageServer) builtinsForCurrentCapabilities() map[string]*ast.Builtin { + l.loadedBuiltinsLock.RLock() + defer l.loadedBuiltinsLock.RUnlock() + + cfg := l.getLoadedConfig() + if cfg == nil { + return rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) + } + + bis, ok := l.loadedBuiltins[cfg.CapabilitiesURL] + if !ok { + return rego.BuiltinsForCapabilities(ast.CapabilitiesForThisVersion()) + } + + return bis +} + func positionToOffset(text string, p types.Position) int { bytesRead := 0 lines := strings.Split(text, "\n") diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index 76b6112f..9da8996c 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -22,7 +22,6 @@ import ( "github.com/styrainc/regal/internal/lsp/cache" "github.com/styrainc/regal/internal/lsp/clients" - "github.com/styrainc/regal/internal/lsp/rego" "github.com/styrainc/regal/internal/lsp/types" "github.com/styrainc/regal/pkg/config" "github.com/styrainc/regal/pkg/fixer/fixes" @@ -33,7 +32,7 @@ const mainRegoFileName = "/main.rego" // defaultTimeout is set based on the investigation done as part of // https://github.com/StyraInc/regal/issues/931. 20 seconds is 10x the // maximum time observed for an operation to complete. -const defaultTimeout = 20 * time.Second +const defaultTimeout = 5 * time.Second const defaultBufferedChannelSize = 5 @@ -336,36 +335,6 @@ capabilities: } } - // manually inspect the server's list of builtins to ensure that the EOPA - // capabilities were loaded correctly. - timeout = time.NewTimer(defaultTimeout) - defer timeout.Stop() - - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - success := false - - select { - case <-timeout.C: - t.Fatalf("timed out waiting for builtins map to be updated") - case <-ticker.C: - bis := rego.GetBuiltins() - - // Search for a builtin we know is only in the EOPA capabilities. - if _, ok := bis["neo4j.query"]; ok { - success = true - } - - t.Logf("waiting for neo4j.query builtin to be present, got %v", bis) - } - - if success { - break - } - } - // 6. Client sends textDocument/didChange notification with new // contents for main.rego no response to the call is expected. We added // the start of an EOPA-specific call, so if the capabilities were @@ -425,7 +394,7 @@ allow := neo4j.q timeout = time.NewTimer(defaultTimeout) defer timeout.Stop() - ticker = time.NewTicker(100 * time.Millisecond) + ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() for { @@ -1071,11 +1040,13 @@ func testRequestDataCodes(t *testing.T, requestData types.FileDiagnostics, fileU sort.Strings(codes) if !slices.Equal(requestCodes, codes) { - t.Logf("expected items: %v, got: %v", codes, requestCodes) + t.Logf("waiting for items: %v, got: %v", codes, requestCodes) return false } + t.Logf("got expected items") + return true }