From 67162e66d0e54fa2319b570a4e425496314db1f0 Mon Sep 17 00:00:00 2001 From: Charlie Egan Date: Thu, 3 Oct 2024 17:10:07 +0100 Subject: [PATCH] lsp: Update LSP linting to run incrementally after file change (#1146) * lsp: Run aggregate-triggered lints incrementally This makes a change to how workspace lint runs are run. Now, when a file with aggregate violations is changed, a full workspace lint will still run but only for the aggregate rules the file has. This will also be done using cached intermediate aggregate data. * WIP * add a polling workspace ticker * remove prints * Increase completions timeout --- cmd/languageserver.go | 2 +- docs/rules/imports/prefer-package-imports.md | 4 +- go.mod | 2 +- internal/lsp/cache/cache.go | 144 ++- internal/lsp/cache/cache_test.go | 161 +++ internal/lsp/completions/manager.go | 4 +- internal/lsp/completions/providers/policy.go | 12 +- .../lsp/completions/providers/policy_test.go | 4 +- internal/lsp/eval_test.go | 2 +- internal/lsp/lint.go | 162 +-- internal/lsp/lint_test.go | 67 ++ internal/lsp/race_off.go | 8 + internal/lsp/race_on.go | 8 + internal/lsp/server.go | 423 +++++-- internal/lsp/server_aggregates_test.go | 478 ++++++++ internal/lsp/server_builtins_test.go | 32 + internal/lsp/server_config_test.go | 307 +++++ internal/lsp/server_formatting_test.go | 78 ++ internal/lsp/server_multi_file_test.go | 166 +++ internal/lsp/server_rename_test.go | 79 ++ internal/lsp/server_single_file_test.go | 364 ++++++ internal/lsp/server_template_test.go | 57 +- internal/lsp/server_test.go | 1054 ++--------------- pkg/linter/linter.go | 160 ++- pkg/linter/linter_test.go | 148 ++- pkg/report/report.go | 36 + 26 files changed, 2680 insertions(+), 1282 deletions(-) create mode 100644 internal/lsp/cache/cache_test.go create mode 100644 internal/lsp/lint_test.go create mode 100644 internal/lsp/race_off.go create mode 100644 internal/lsp/race_on.go create mode 100644 internal/lsp/server_aggregates_test.go create mode 100644 internal/lsp/server_builtins_test.go create mode 100644 internal/lsp/server_config_test.go create mode 100644 internal/lsp/server_formatting_test.go create mode 100644 internal/lsp/server_multi_file_test.go create mode 100644 internal/lsp/server_rename_test.go create mode 100644 internal/lsp/server_single_file_test.go diff --git a/cmd/languageserver.go b/cmd/languageserver.go index 52d8d536..c0f945ab 100644 --- a/cmd/languageserver.go +++ b/cmd/languageserver.go @@ -28,7 +28,7 @@ func init() { ErrorLog: os.Stderr, } - ls := lsp.NewLanguageServer(opts) + ls := lsp.NewLanguageServer(ctx, opts) conn := lsp.NewConnectionFromLanguageServer(ctx, ls.Handle, &lsp.ConnectionOptions{ LoggingConfig: lsp.ConnectionLoggingConfig{ diff --git a/docs/rules/imports/prefer-package-imports.md b/docs/rules/imports/prefer-package-imports.md index 56a6ded9..0cc44de9 100644 --- a/docs/rules/imports/prefer-package-imports.md +++ b/docs/rules/imports/prefer-package-imports.md @@ -15,7 +15,7 @@ import rego.v1 # Rule imported directly import data.users.first_names -has_waldo { +has_waldo if { # Not obvious where "first_names" comes from "Waldo" in first_names } @@ -30,7 +30,7 @@ import rego.v1 # Package imported rather than rule import data.users -has_waldo { +has_waldo if { # Obvious where "first_names" comes from "Waldo" in users.first_names } diff --git a/go.mod b/go.mod index e248ddfd..37095651 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/sourcegraph/jsonrpc2 v0.2.0 github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 + 
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -89,6 +90,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20240820151423-278611b39280 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect sigs.k8s.io/yaml v1.4.0 // indirect ) diff --git a/internal/lsp/cache/cache.go b/internal/lsp/cache/cache.go index d46fe9ae..a8c5971f 100644 --- a/internal/lsp/cache/cache.go +++ b/internal/lsp/cache/cache.go @@ -9,6 +9,7 @@ import ( "github.com/open-policy-agent/opa/ast" "github.com/styrainc/regal/internal/lsp/types" + "github.com/styrainc/regal/pkg/report" ) // Cache is used to store: current file contents (which includes unsaved changes), the latest parsed modules, and @@ -26,12 +27,15 @@ type Cache struct { // modules is a map of file URI to parsed AST modules from the latest file contents value modules map[string]*ast.Module + // aggregateData stores the aggregate data from evaluations for each file. + // This is used to cache the results of expensive evaluations and can be used + // to update aggregate diagnostics incrementally. + aggregateData map[string][]report.Aggregate + aggregateDataMu sync.Mutex + // diagnosticsFile is a map of file URI to diagnostics for that file diagnosticsFile map[string][]types.Diagnostic - // diagnosticsAggregate is a map of file URI to aggregate diagnostics for that file - diagnosticsAggregate map[string][]types.Diagnostic - // diagnosticsParseErrors is a map of file URI to parse errors for that file diagnosticsParseErrors map[string][]types.Diagnostic @@ -54,8 +58,6 @@ type Cache struct { diagnosticsFileMu sync.Mutex - diagnosticsAggregateMu sync.Mutex - diagnosticsParseMu sync.Mutex builtinPositionsMu sync.Mutex @@ -72,8 +74,9 @@ func NewCache() *Cache { modules: make(map[string]*ast.Module), + aggregateData: make(map[string][]report.Aggregate), + diagnosticsFile: make(map[string][]types.Diagnostic), - diagnosticsAggregate: make(map[string][]types.Diagnostic), diagnosticsParseErrors: make(map[string][]types.Diagnostic), builtinPositionsFile: make(map[string]map[uint][]types.BuiltinPosition), @@ -83,27 +86,6 @@ func NewCache() *Cache { } } -func (c *Cache) GetAllDiagnosticsForURI(fileURI string) []types.Diagnostic { - parseDiags, ok := c.GetParseErrors(fileURI) - if ok && len(parseDiags) > 0 { - return parseDiags - } - - allDiags := make([]types.Diagnostic, 0) - - aggDiags, ok := c.GetAggregateDiagnostics(fileURI) - if ok { - allDiags = append(allDiags, aggDiags...) - } - - fileDiags, ok := c.GetFileDiagnostics(fileURI) - if ok { - allDiags = append(allDiags, fileDiags...) - } - - return allDiags -} - func (c *Cache) GetAllFiles() map[string]string { c.fileContentsMu.Lock() defer c.fileContentsMu.Unlock() @@ -180,6 +162,69 @@ func (c *Cache) SetModule(fileURI string, module *ast.Module) { c.modules[fileURI] = module } +// SetFileAggregates will only set aggregate data for the provided URI. Even if +// data for other files is provided, only the specified URI is updated.
+func (c *Cache) SetFileAggregates(fileURI string, data map[string][]report.Aggregate) { + c.aggregateDataMu.Lock() + defer c.aggregateDataMu.Unlock() + + flattenedAggregates := make([]report.Aggregate, 0) + + for _, aggregates := range data { + for _, aggregate := range aggregates { + if aggregate.SourceFile() != fileURI { + continue + } + + flattenedAggregates = append(flattenedAggregates, aggregate) + } + } + + c.aggregateData[fileURI] = flattenedAggregates +} + +func (c *Cache) SetAggregates(data map[string][]report.Aggregate) { + c.aggregateDataMu.Lock() + defer c.aggregateDataMu.Unlock() + + // clear the state + c.aggregateData = make(map[string][]report.Aggregate) + + for _, aggregates := range data { + for _, aggregate := range aggregates { + c.aggregateData[aggregate.SourceFile()] = append(c.aggregateData[aggregate.SourceFile()], aggregate) + } + } +} + +// GetFileAggregates is used to get aggregate data for a given list of files. +// This is only used in tests to validate the cache state. +func (c *Cache) GetFileAggregates(fileURIs ...string) map[string][]report.Aggregate { + c.aggregateDataMu.Lock() + defer c.aggregateDataMu.Unlock() + + includedFiles := make(map[string]struct{}, len(fileURIs)) + for _, fileURI := range fileURIs { + includedFiles[fileURI] = struct{}{} + } + + getAll := len(fileURIs) == 0 + + allAggregates := make(map[string][]report.Aggregate) + + for sourceFile, aggregates := range c.aggregateData { + if _, included := includedFiles[sourceFile]; !included && !getAll { + continue + } + + for _, aggregate := range aggregates { + allAggregates[aggregate.IndexKey()] = append(allAggregates[aggregate.IndexKey()], aggregate) + } + } + + return allAggregates +} + func (c *Cache) GetFileDiagnostics(uri string) ([]types.Diagnostic, bool) { c.diagnosticsFileMu.Lock() defer c.diagnosticsFileMu.Unlock() @@ -196,34 +241,33 @@ func (c *Cache) SetFileDiagnostics(fileURI string, diags []types.Diagnostic) { c.diagnosticsFile[fileURI] = diags } -func (c *Cache) ClearFileDiagnostics() { +// SetFileDiagnosticsForRules will perform a partial update of the diagnostics +// for a file given a list of evaluated rules. +func (c *Cache) SetFileDiagnosticsForRules(fileURI string, rules []string, diags []types.Diagnostic) { c.diagnosticsFileMu.Lock() defer c.diagnosticsFileMu.Unlock() - c.diagnosticsFile = make(map[string][]types.Diagnostic) -} - -func (c *Cache) GetAggregateDiagnostics(fileURI string) ([]types.Diagnostic, bool) { - c.diagnosticsAggregateMu.Lock() - defer c.diagnosticsAggregateMu.Unlock() - - val, ok := c.diagnosticsAggregate[fileURI] + ruleKeys := make(map[string]struct{}, len(rules)) + for _, rule := range rules { + ruleKeys[rule] = struct{}{} + } - return val, ok -} + preservedDiagnostics := make([]types.Diagnostic, 0) -func (c *Cache) SetAggregateDiagnostics(fileURI string, diags []types.Diagnostic) { - c.diagnosticsAggregateMu.Lock() - defer c.diagnosticsAggregateMu.Unlock() + for _, diag := range c.diagnosticsFile[fileURI] { + if _, ok := ruleKeys[diag.Code]; !ok { + preservedDiagnostics = append(preservedDiagnostics, diag) + } + } - c.diagnosticsAggregate[fileURI] = diags + c.diagnosticsFile[fileURI] = append(preservedDiagnostics, diags...) 
} -func (c *Cache) ClearAggregateDiagnostics() { - c.diagnosticsAggregateMu.Lock() - defer c.diagnosticsAggregateMu.Unlock() +func (c *Cache) ClearFileDiagnostics() { + c.diagnosticsFileMu.Lock() + defer c.diagnosticsFileMu.Unlock() - c.diagnosticsAggregate = make(map[string][]types.Diagnostic) + c.diagnosticsFile = make(map[string][]types.Diagnostic) } func (c *Cache) GetParseErrors(uri string) ([]types.Diagnostic, bool) { @@ -313,14 +357,14 @@ func (c *Cache) Delete(fileURI string) { delete(c.modules, fileURI) c.moduleMu.Unlock() + c.aggregateDataMu.Lock() + delete(c.aggregateData, fileURI) + c.aggregateDataMu.Unlock() + c.diagnosticsFileMu.Lock() delete(c.diagnosticsFile, fileURI) c.diagnosticsFileMu.Unlock() - c.diagnosticsAggregateMu.Lock() - delete(c.diagnosticsAggregate, fileURI) - c.diagnosticsAggregateMu.Unlock() - c.diagnosticsParseMu.Lock() delete(c.diagnosticsParseErrors, fileURI) c.diagnosticsParseMu.Unlock() diff --git a/internal/lsp/cache/cache_test.go b/internal/lsp/cache/cache_test.go new file mode 100644 index 00000000..7dd19f03 --- /dev/null +++ b/internal/lsp/cache/cache_test.go @@ -0,0 +1,161 @@ +package cache + +import ( + "reflect" + "testing" + + "github.com/styrainc/regal/internal/lsp/types" + "github.com/styrainc/regal/pkg/report" +) + +func TestManageAggregates(t *testing.T) { + t.Parallel() + + reportAggregatesFile1 := map[string][]report.Aggregate{ + "my-rule-name": { + { + "aggregate_data": map[string]any{ + "foo": "bar", + }, + "aggregate_source": map[string]any{ + "file": "file1.rego", + "package_path": []string{"p"}, + }, + "rule": map[string]any{ + "category": "my-rule-category", + "title": "my-rule-name", + }, + }, + { + "aggregate_data": map[string]any{ + "more": "things", + }, + "aggregate_source": map[string]any{ + "file": "file1.rego", + "package_path": []string{"p"}, + }, + "rule": map[string]any{ + "category": "my-rule-category", + "title": "my-rule-name", + }, + }, + }, + } + + reportAggregatesFile2 := map[string][]report.Aggregate{ + "my-rule-name": { + { + "aggregate_data": map[string]any{ + "foo": "baz", + }, + "aggregate_source": map[string]any{ + "file": "file2.rego", + "package_path": []string{"p"}, + }, + "rule": map[string]any{ + "category": "my-rule-category", + "title": "my-rule-name", + }, + }, + }, + "my-other-rule-name": { + { + "aggregate_data": map[string]any{ + "foo": "bax", + }, + "aggregate_source": map[string]any{ + "file": "file2.rego", + "package_path": []string{"p"}, + }, + "rule": map[string]any{ + "category": "my-other-rule-category", + "title": "my-other-rule-name", + }, + }, + }, + } + + c := NewCache() + + c.SetFileAggregates("file1.rego", reportAggregatesFile1) + c.SetFileAggregates("file2.rego", reportAggregatesFile2) + + aggs1 := c.GetFileAggregates("file1.rego") + if len(aggs1) != 1 { // there is one cat/rule for file1 + t.Fatalf("unexpected number of aggregates for file1.rego: %d", len(aggs1)) + } + + aggs2 := c.GetFileAggregates("file2.rego") + if len(aggs2) != 2 { + t.Fatalf("unexpected number of aggregates for file2.rego: %d", len(aggs2)) + } + + allAggs := c.GetFileAggregates() + + if len(allAggs) != 2 { + t.Fatalf("unexpected number of aggregates: %d", len(allAggs)) + } + + if _, ok := allAggs["my-other-rule-category/my-other-rule-name"]; !ok { + t.Fatalf("missing aggregate my-other-rule-name") + } + + c.SetAggregates(reportAggregatesFile1) // update aggregates to only contain file1.rego's aggregates + + allAggs = c.GetFileAggregates() + + if len(allAggs) != 1 { + t.Fatalf("unexpected number of aggregates: %d", 
len(allAggs)) + } + + if _, ok := allAggs["my-rule-category/my-rule-name"]; !ok { + t.Fatalf("missing aggregate my-rule-name") + } + + // remove file1 from the cache + c.Delete("file1.rego") + + allAggs = c.GetFileAggregates() + + if len(allAggs) != 0 { + t.Fatalf("unexpected number of aggregates: %d", len(allAggs)) + } +} + +func TestPartialDiagnosticsUpdate(t *testing.T) { + t.Parallel() + + c := NewCache() + + diag1 := types.Diagnostic{Code: "code1"} + diag2 := types.Diagnostic{Code: "code2"} + diag3 := types.Diagnostic{Code: "code3"} + + c.SetFileDiagnostics("foo.rego", []types.Diagnostic{ + diag1, diag2, + }) + + foundDiags, ok := c.GetFileDiagnostics("foo.rego") + if !ok { + t.Fatalf("expected to get diags for foo.rego") + } + + if !reflect.DeepEqual(foundDiags, []types.Diagnostic{diag1, diag2}) { + t.Fatalf("unexpected diagnostics: %v", foundDiags) + } + + c.SetFileDiagnosticsForRules( + "foo.rego", + []string{"code2", "code3"}, + []types.Diagnostic{diag3}, + ) + + foundDiags, ok = c.GetFileDiagnostics("foo.rego") + if !ok { + t.Fatalf("expected to get diags for foo.rego") + } + + if !reflect.DeepEqual(foundDiags, []types.Diagnostic{diag1, diag3}) { + t.Fatalf("unexpected diagnostics: %v", foundDiags) + } +} diff --git a/internal/lsp/completions/manager.go b/internal/lsp/completions/manager.go index f5c634d8..4d1c5983 100644 --- a/internal/lsp/completions/manager.go +++ b/internal/lsp/completions/manager.go @@ -29,7 +29,7 @@ func NewManager(c *cache.Cache, opts *ManagerOptions) *Manager { return &Manager{c: c, opts: opts} } -func NewDefaultManager(c *cache.Cache, store storage.Store) *Manager { +func NewDefaultManager(ctx context.Context, c *cache.Cache, store storage.Store) *Manager { m := NewManager(c, &ManagerOptions{}) m.RegisterProvider(&providers.BuiltIns{}) @@ -38,7 +38,7 @@ func NewDefaultManager(c *cache.Cache, store storage.Store) *Manager { m.RegisterProvider(&providers.RuleHeadKeyword{}) m.RegisterProvider(&providers.Input{}) - m.RegisterProvider(providers.NewPolicy(store)) + m.RegisterProvider(providers.NewPolicy(ctx, store)) return m } diff --git a/internal/lsp/completions/providers/policy.go b/internal/lsp/completions/providers/policy.go index 08195b42..6a849e20 100644 --- a/internal/lsp/completions/providers/policy.go +++ b/internal/lsp/completions/providers/policy.go @@ -30,8 +30,8 @@ type Policy struct { // NewPolicy creates a new Policy provider. This provider is distinctly different from the other providers // as it acts like the entrypoint for all Rego-based providers, and not a single provider "function" like // the Go providers do. 
-func NewPolicy(store storage.Store) *Policy { - pq, err := prepareQuery(store, "completions := data.regal.lsp.completion.items") +func NewPolicy(ctx context.Context, store storage.Store) *Policy { + pq, err := prepareQuery(ctx, store, "completions := data.regal.lsp.completion.items") if err != nil { panic(fmt.Sprintf("failed preparing query for static bundle: %v", err)) } @@ -112,10 +112,10 @@ func (p *Policy) Run( return completions, nil } -func prepareQuery(store storage.Store, query string) (*rego.PreparedEvalQuery, error) { +func prepareQuery(ctx context.Context, store storage.Store, query string) (*rego.PreparedEvalQuery, error) { regoArgs := prepareRegoArgs(store, ast.MustParseBody(query)) - txn, err := store.NewTransaction(context.TODO(), storage.WriteParams) + txn, err := store.NewTransaction(ctx, storage.WriteParams) if err != nil { return nil, fmt.Errorf("failed creating transaction: %w", err) } @@ -125,12 +125,12 @@ func prepareQuery(store storage.Store, query string) (*rego.PreparedEvalQuery, e // Note that we currently don't provide metrics or profiling here, and // most likely we should — need to consider how to best make that conditional // and how to present it if enabled. - pq, err := rego.New(regoArgs...).PrepareForEval(context.Background()) + pq, err := rego.New(regoArgs...).PrepareForEval(ctx) if err != nil { return nil, fmt.Errorf("failed preparing query: %s, %w", query, err) } - if err = store.Commit(context.Background(), txn); err != nil { + if err = store.Commit(ctx, txn); err != nil { return nil, fmt.Errorf("failed committing transaction: %w", err) } diff --git a/internal/lsp/completions/providers/policy_test.go b/internal/lsp/completions/providers/policy_test.go index 997858a8..1e5037f8 100644 --- a/internal/lsp/completions/providers/policy_test.go +++ b/internal/lsp/completions/providers/policy_test.go @@ -47,7 +47,7 @@ allow if { }, }, inmem.OptRoundTripOnWrite(false)) - locals := NewPolicy(store) + locals := NewPolicy(context.Background(), store) params := types.CompletionParams{ TextDocument: types.TextDocumentIdentifier{ @@ -109,7 +109,7 @@ import data.example }, }) - locals := NewPolicy(store) + locals := NewPolicy(context.Background(), store) fileEdited := `package example2 import rego.v1 diff --git a/internal/lsp/eval_test.go b/internal/lsp/eval_test.go index 6c9b7dfe..909da62f 100644 --- a/internal/lsp/eval_test.go +++ b/internal/lsp/eval_test.go @@ -14,7 +14,7 @@ import ( func TestEvalWorkspacePath(t *testing.T) { t.Parallel() - ls := NewLanguageServer(&LanguageServerOptions{ErrorLog: os.Stderr}) + ls := NewLanguageServer(context.Background(), &LanguageServerOptions{ErrorLog: os.Stderr}) policy1 := `package policy1 diff --git a/internal/lsp/lint.go b/internal/lsp/lint.go index bd5e8151..70db7dd7 100644 --- a/internal/lsp/lint.go +++ b/internal/lsp/lint.go @@ -148,25 +148,30 @@ func updateFileDiagnostics( ctx context.Context, cache *cache.Cache, regalConfig *config.Config, - uri string, - rootDir string, + fileURI string, + workspaceRootDir string, + updateDiagnosticsForRules []string, ) error { - module, ok := cache.GetModule(uri) + module, ok := cache.GetModule(fileURI) if !ok { // then there must have been a parse error return nil } - contents, ok := cache.GetFileContents(uri) + contents, ok := cache.GetFileContents(fileURI) if !ok { - return fmt.Errorf("failed to get file contents for uri %q", uri) + return fmt.Errorf("failed to get file contents for uri %q", fileURI) } - input := rules.NewInput(map[string]string{uri: contents}, 
map[string]*ast.Module{uri: module}) + input := rules.NewInput(map[string]string{fileURI: contents}, map[string]*ast.Module{fileURI: module}) regalInstance := linter.NewLinter(). + // needed to get the aggregateData for this file + WithCollectQuery(true). + // needed to get the aggregateData out so we can update the cache + WithExportAggregates(true). WithInputModules(&input). - WithRootDir(rootDir) + WithRootDir(workspaceRootDir) if regalConfig != nil { regalInstance = regalInstance.WithUserConfig(*regalConfig) @@ -177,33 +182,29 @@ func updateFileDiagnostics( return fmt.Errorf("failed to lint: %w", err) } - diags := make([]types.Diagnostic, 0) + fileDiags := convertReportToDiagnostics(&rpt, workspaceRootDir) - for _, item := range rpt.Violations { - // here errors are presented as warnings, and warnings as info - // to differentiate from parse errors - severity := uint(2) - if item.Level == "warning" { - severity = 3 + files := cache.GetAllFiles() + + for uri := range files { + // if a file has parse errors, continue to show these until they're addressed + parseErrs, ok := cache.GetParseErrors(uri) + if ok && len(parseErrs) > 0 { + continue } - diags = append(diags, types.Diagnostic{ - Severity: severity, - Range: getRangeForViolation(item), - Message: item.Description, - Source: "regal/" + item.Category, - Code: item.Title, - CodeDescription: &types.CodeDescription{ - Href: fmt.Sprintf( - "https://docs.styra.com/regal/rules/%s/%s", - item.Category, - item.Title, - ), - }, - }) + // For updateFileDiagnostics, we only update the file in question. + if uri == fileURI { + fd, ok := fileDiags[uri] + if !ok { + fd = []types.Diagnostic{} + } + + cache.SetFileDiagnosticsForRules(uri, updateDiagnosticsForRules, fd) + } } - cache.SetFileDiagnostics(uri, diags) + cache.SetFileAggregates(fileURI, rpt.Aggregates) return nil } @@ -212,25 +213,73 @@ func updateAllDiagnostics( ctx context.Context, cache *cache.Cache, regalConfig *config.Config, - detachedURI string, + workspaceRootDir string, + overwriteAggregates bool, + aggregatesReportOnly bool, + updateDiagnosticsForRules []string, ) error { + var err error + modules := cache.GetAllModules() files := cache.GetAllFiles() - input := rules.NewInput(files, modules) - - regalInstance := linter.NewLinter().WithInputModules(&input).WithRootDir(detachedURI) + regalInstance := linter.NewLinter(). + WithRootDir(workspaceRootDir). + // aggregates need only be exported if they're to be used to overwrite. + WithExportAggregates(overwriteAggregates) if regalConfig != nil { regalInstance = regalInstance.WithUserConfig(*regalConfig) } + if aggregatesReportOnly { + regalInstance = regalInstance. + WithAggregates(cache.GetFileAggregates()) + } else { + input := rules.NewInput(files, modules) + regalInstance = regalInstance.WithInputModules(&input) + } + rpt, err := regalInstance.Lint(ctx) if err != nil { return fmt.Errorf("failed to lint: %w", err) } - aggDiags := make(map[string][]types.Diagnostic) + fileDiags := convertReportToDiagnostics(&rpt, workspaceRootDir) + + for uri := range files { + parseErrs, ok := cache.GetParseErrors(uri) + if ok && len(parseErrs) > 0 { + continue + } + + fd, ok := fileDiags[uri] + if !ok { + fd = []types.Diagnostic{} + } + + // when only an aggregate report was run, then we must make sure to + // only update diagnostics from these rules. So the report is + // authoritative, but for those rules only.
+ if aggregatesReportOnly { + cache.SetFileDiagnosticsForRules(uri, updateDiagnosticsForRules, fd) + } else { + cache.SetFileDiagnostics(uri, fd) + } + } + + if overwriteAggregates { + // clear all cached aggregates, and use the fresh ones from this run + cache.SetAggregates(rpt.Aggregates) + } + + return nil +} + +func convertReportToDiagnostics( + rpt *report.Report, + workspaceRootURI string, +) map[string][]types.Diagnostic { fileDiags := make(map[string][]types.Diagnostic) for _, item := range rpt.Violations { @@ -256,53 +305,14 @@ }, } - // TODO(charlieegan3): it'd be nice to be able to only run aggregate rules in some cases, but for now, we - // can just run all rules each time. - if item.IsAggregate { - if item.Location.File == "" { - aggDiags[detachedURI] = append(aggDiags[detachedURI], diag) - } else { - aggDiags[item.Location.File] = append(aggDiags[item.Location.File], diag) - } + if item.Location.File == "" { + fileDiags[workspaceRootURI] = append(fileDiags[workspaceRootURI], diag) } else { fileDiags[item.Location.File] = append(fileDiags[item.Location.File], diag) } } - // this lint contains authoritative information about all files - // all diagnostics are cleared and replaced with the new lint - for uri := range files { - // if a file has parse errors, then we continue to show these until they're addressed - // as if there are lint results they must be based on an old, parsed version of the file - parseErrs, ok := cache.GetParseErrors(uri) - if ok && len(parseErrs) > 0 { - continue - } - - ad, ok := aggDiags[uri] - if !ok { - ad = []types.Diagnostic{} - } - - cache.SetAggregateDiagnostics(uri, ad) - - fd, ok := fileDiags[uri] - if !ok { - fd = []types.Diagnostic{} - } - - cache.SetFileDiagnostics(uri, fd) - } - - // handle the diagnostics for the workspace, under the detachedURI - ad, ok := aggDiags[detachedURI] - if !ok { - ad = []types.Diagnostic{} - } - - cache.SetAggregateDiagnostics(detachedURI, ad) - - return nil + return fileDiags } // astError is copied from OPA but drop details as I (charlieegan3) had issues unmarshalling the field.
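The reworked `updateAllDiagnostics` above now has two modes: a full run that lints every module and overwrites the cached aggregate data, and a cheaper aggregate-only run that feeds the cached aggregates back into the linter. Below is a minimal sketch of that calling convention; the helper name and the rule slices are illustrative, not part of this patch.

```go
// seedThenRelint is a sketch only: it shows the intended calling
// convention of updateAllDiagnostics, not code from this patch.
func seedThenRelint(
	ctx context.Context,
	c *cache.Cache,
	cfg *config.Config,
	rootDir string,
	allRules, aggregateRules []string,
) error {
	// At start up: lint all files with all rules, and overwrite the
	// cached aggregate data with this run's exported aggregates.
	if err := updateAllDiagnostics(ctx, c, cfg, rootDir, true, false, allRules); err != nil {
		return err
	}

	// After a single file change: re-run only aggregate rules, feeding the
	// linter cached aggregates for the unchanged files. The aggregate cache
	// itself is refreshed per file by updateFileDiagnostics.
	return updateAllDiagnostics(ctx, c, cfg, rootDir, false, true, aggregateRules)
}
```

In the server code later in this diff, the first form runs once at initialize time (and from the optional poll ticker), while the second runs on every file change.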
diff --git a/internal/lsp/lint_test.go b/internal/lsp/lint_test.go new file mode 100644 index 00000000..9db26d8e --- /dev/null +++ b/internal/lsp/lint_test.go @@ -0,0 +1,67 @@ +package lsp + +import ( + "reflect" + "testing" + + "github.com/styrainc/regal/internal/lsp/types" + "github.com/styrainc/regal/pkg/report" +) + +func TestConvertReportToDiagnostics(t *testing.T) { + t.Parallel() + + violation1 := report.Violation{ + Level: "error", + Description: "Mock Error", + Category: "mock_category", + Title: "mock_title", + Location: report.Location{File: "file1"}, + IsAggregate: false, + } + violation2 := report.Violation{ + Level: "warning", + Description: "Mock Warning", + Category: "mock_category", + Title: "mock_title", + Location: report.Location{File: ""}, + IsAggregate: true, + } + + rpt := &report.Report{ + Violations: []report.Violation{violation1, violation2}, + } + + expectedFileDiags := map[string][]types.Diagnostic{ + "file1": { + { + Severity: 2, + Range: getRangeForViolation(violation1), + Message: "Mock Error", + Source: "regal/mock_category", + Code: "mock_title", + CodeDescription: &types.CodeDescription{ + Href: "https://docs.styra.com/regal/rules/mock_category/mock_title", + }, + }, + }, + "workspaceRootURI": { + { + Severity: 3, + Range: getRangeForViolation(violation2), + Message: "Mock Warning", + Source: "regal/mock_category", + Code: "mock_title", + CodeDescription: &types.CodeDescription{ + Href: "https://docs.styra.com/regal/rules/mock_category/mock_title", + }, + }, + }, + } + + fileDiags := convertReportToDiagnostics(rpt, "workspaceRootURI") + + if !reflect.DeepEqual(fileDiags, expectedFileDiags) { + t.Errorf("Expected file diagnostics: %v, got: %v", expectedFileDiags, fileDiags) + } +} diff --git a/internal/lsp/race_off.go b/internal/lsp/race_off.go new file mode 100644 index 00000000..e7d4a886 --- /dev/null +++ b/internal/lsp/race_off.go @@ -0,0 +1,8 @@ +//go:build !race +// +build !race + +package lsp + +func isRaceEnabled() bool { + return false +} diff --git a/internal/lsp/race_on.go b/internal/lsp/race_on.go new file mode 100644 index 00000000..3d9ba5d0 --- /dev/null +++ b/internal/lsp/race_on.go @@ -0,0 +1,8 @@ +//go:build race
// +build race + +package lsp + +func isRaceEnabled() bool { + return true +}
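The `race_off.go`/`race_on.go` pair above uses Go build tags so helpers can detect whether the race detector is compiled in. One plausible application, sketched here as an assumption: the tests later in this diff call a `determineTimeout` helper whose body is not shown in this patch, and it could well scale its waits along these lines.

```go
// Sketch only: stretch test timeouts when built with -race, since the
// race detector can slow execution down severalfold. The base duration
// and multiplier are assumptions, not values from this patch.
func sketchDetermineTimeout() time.Duration {
	base := 5 * time.Second

	if isRaceEnabled() {
		// builds with -race run much more slowly, so wait longer
		// before declaring a test timed out
		return base * 4
	}

	return base
}
```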
diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 69d3cc37..7535eb3b 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -65,25 +65,33 @@ const ( type LanguageServerOptions struct { ErrorLog io.Writer + + // WorkspaceDiagnosticsPoll, if set > 0, will cause a full workspace lint + // to run on this interval. This is intended to be used where eventing + // is not working as expected, e.g. with a client that does not send + // changes, or when running in extremely slow environments like GHA with + // the Go race detector on. TODO: work out why this is required. + WorkspaceDiagnosticsPoll time.Duration } -func NewLanguageServer(opts *LanguageServerOptions) *LanguageServer { +func NewLanguageServer(ctx context.Context, opts *LanguageServerOptions) *LanguageServer { c := cache.NewCache() store := NewRegalStore() ls := &LanguageServer{ - cache: c, - regoStore: store, - errorLog: opts.ErrorLog, - diagnosticRequestFile: make(chan fileUpdateEvent, 10), - diagnosticRequestWorkspace: make(chan string, 10), - builtinsPositionFile: make(chan fileUpdateEvent, 10), - commandRequest: make(chan types.ExecuteCommandParams, 10), - templateFile: make(chan fileUpdateEvent, 10), - configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{ErrorWriter: opts.ErrorLog}), - completionsManager: completions.NewDefaultManager(c, store), - webServer: web.NewServer(c), - loadedBuiltins: make(map[string]map[string]*ast.Builtin), + cache: c, + regoStore: store, + errorLog: opts.ErrorLog, + lintFileJobs: make(chan lintFileJob, 10), + lintWorkspaceJobs: make(chan lintWorkspaceJob, 10), + builtinsPositionJobs: make(chan lintFileJob, 10), + commandRequest: make(chan types.ExecuteCommandParams, 10), + templateFileJobs: make(chan lintFileJob, 10), + configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{ErrorWriter: opts.ErrorLog}), + completionsManager: completions.NewDefaultManager(ctx, c, store), + webServer: web.NewServer(c), + loadedBuiltins: make(map[string]map[string]*ast.Builtin), + workspaceDiagnosticsPoll: opts.WorkspaceDiagnosticsPoll, } return ls @@ -95,9 +103,11 @@ type LanguageServer struct { regoStore storage.Store conn *jsonrpc2.Conn - configWatcher *lsconfig.Watcher - loadedConfig *config.Config - loadedBuiltins map[string]map[string]*ast.Builtin + configWatcher *lsconfig.Watcher + loadedConfig *config.Config + loadedConfigEnabledNonAggregateRules []string + loadedConfigEnabledAggregateRules []string + loadedBuiltins map[string]map[string]*ast.Builtin clientInitializationOptions types.InitializationOptions @@ -106,11 +116,11 @@ type LanguageServer struct { completionsManager *completions.Manager - diagnosticRequestFile chan fileUpdateEvent - diagnosticRequestWorkspace chan string - builtinsPositionFile chan fileUpdateEvent - commandRequest chan types.ExecuteCommandParams - templateFile chan fileUpdateEvent + commandRequest chan types.ExecuteCommandParams + lintWorkspaceJobs chan lintWorkspaceJob + lintFileJobs chan lintFileJob + builtinsPositionJobs chan lintFileJob + templateFileJobs chan lintFileJob webServer *web.Server @@ -119,15 +129,29 @@ type LanguageServer struct { loadedBuiltinsLock sync.RWMutex + // this is also used to lock the updates to the cache of enabled rules loadedConfigLock sync.Mutex + + workspaceDiagnosticsPoll time.Duration } -// fileUpdateEvent is sent to a channel when an update is required for a file. -type fileUpdateEvent struct { +// lintFileJob is sent to the lintFileJobs channel to trigger a +// diagnostic update for a file. +type lintFileJob struct { Reason string URI string }
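File-level events now flow through this typed job rather than the old `fileUpdateEvent`. As a rough sketch of how the handlers later in this diff (didOpen, didChange) fan a change out to the workers; the method name here is hypothetical:

```go
// Hypothetical helper mirroring the didOpen/didChange handlers shown
// later in this diff: enqueue a per-file lint and a built-in position
// refresh for the changed file.
func (l *LanguageServer) enqueueFileJob(reason, fileURI string) {
	job := lintFileJob{
		Reason: reason,
		URI:    fileURI,
	}

	l.lintFileJobs <- job         // parse + non-aggregate lint, then an aggregate-only workspace run
	l.builtinsPositionJobs <- job // refresh built-in positions used for hover
}
```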
+// lintWorkspaceJob is sent to lintWorkspaceJobs when a full workspace +// diagnostic update is needed. +type lintWorkspaceJob struct { + Reason string + // OverwriteAggregates for a workspace is only set once, at start up. All + // later updates to aggregate state are made as files are changed. + OverwriteAggregates bool + AggregateReportOnly bool +} + func (l *LanguageServer) Handle( ctx context.Context, conn *jsonrpc2.Conn, @@ -214,54 +238,162 @@ func (l *LanguageServer) SetConn(conn *jsonrpc2.Conn) { } func (l *LanguageServer) StartDiagnosticsWorker(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - case evt := <-l.diagnosticRequestFile: - bis := l.builtinsForCurrentCapabilities() + var wg sync.WaitGroup - // updateParse will not return an error when the parsing failed, - // but only when it was impossible - if _, err := updateParse(ctx, l.cache, l.regoStore, evt.URI, bis); err != nil { - l.logError(fmt.Errorf("failed to update module for %s: %w", evt.URI, err)) + wg.Add(1) - continue - } + go func() { + defer wg.Done() - // lint the file and send the diagnostics - if err := updateFileDiagnostics(ctx, l.cache, l.getLoadedConfig(), evt.URI, l.workspaceRootURI); err != nil { - l.logError(fmt.Errorf("failed to update file diagnostics: %w", err)) + for { + select { + case <-ctx.Done(): + return + case job := <-l.lintFileJobs: + bis := l.builtinsForCurrentCapabilities() - continue - } + // updateParse will not return an error when the parsing failed, + // but only when it was impossible + if _, err := updateParse(ctx, l.cache, l.regoStore, job.URI, bis); err != nil { + l.logError(fmt.Errorf("failed to update module for %s: %w", job.URI, err)) - if err := l.sendFileDiagnostics(ctx, evt.URI); err != nil { - l.logError(fmt.Errorf("failed to send diagnostic: %w", err)) + continue + } - continue + // lint the file and send the diagnostics + if err := updateFileDiagnostics( + ctx, + l.cache, + l.getLoadedConfig(), + job.URI, + l.workspaceRootURI, + // updateFileDiagnostics only ever updates the diagnostics + // of non-aggregate rules + l.getEnabledNonAggregateRules(), + ); err != nil { + l.logError(fmt.Errorf("failed to update file diagnostics: %w", err)) + + continue + } + + if err := l.sendFileDiagnostics(ctx, job.URI); err != nil { + l.logError(fmt.Errorf("failed to send diagnostic: %w", err)) + + continue + } + + l.lintWorkspaceJobs <- lintWorkspaceJob{ + Reason: fmt.Sprintf("file %s %s", job.URI, job.Reason), + // this run is expected to use the cached aggregate state + // for other files. + // The aggregate state for this file will still be updated. + OverwriteAggregates: false, + // when a file has changed, there is no need to run + // any rules globally other than aggregate rules. + AggregateReportOnly: true, + } } + } + }() + + wg.Add(1) - // if the file has agg diagnostics, we trigger a run for the workspace as by changing this file, - // these may now be out of date - aggDiags, ok := l.cache.GetAggregateDiagnostics(evt.URI) - if ok && len(aggDiags) > 0 { - l.diagnosticRequestWorkspace <- fmt.Sprintf("file %q with aggregate violation changed", evt.URI) + workspaceLintRunBufferSize := 10 + workspaceLintRuns := make(chan lintWorkspaceJob, workspaceLintRunBufferSize) + + go func() { + defer wg.Done() + + for { + select { + case <-ctx.Done(): + return + case job := <-l.lintWorkspaceJobs: + // AggregateReportOnly is set when updating aggregate + // violations on character changes.
Since these happen so + // frequently, we stop adding to the channel if there are already + // jobs queued, in order to preserve performance + if job.AggregateReportOnly && len(workspaceLintRuns) > workspaceLintRunBufferSize/2 { + fmt.Fprintln(l.errorLog, "rate limiting aggregate reports") + + continue + } + + workspaceLintRuns <- job } - case <-l.diagnosticRequestWorkspace: - // results will be sent in response to the next workspace/diagnostics request - if err := updateAllDiagnostics(ctx, l.cache, l.getLoadedConfig(), l.workspaceRootURI); err != nil { - l.logError(fmt.Errorf("failed to update aggregate diagnostics (trigger): %w", err)) + } + }() + + if l.workspaceDiagnosticsPoll > 0 { + wg.Add(1) + + ticker := time.NewTicker(l.workspaceDiagnosticsPoll) + + go func() { + defer wg.Done() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + workspaceLintRuns <- lintWorkspaceJob{ + Reason: "poll ticker", + OverwriteAggregates: true, + } + } } + }() + } - // send diagnostics for all files - for fileURI := range l.cache.GetAllFiles() { - if err := l.sendFileDiagnostics(ctx, fileURI); err != nil { - l.logError(fmt.Errorf("failed to send diagnostic: %w", err)) + wg.Add(1) + + go func() { + defer wg.Done() + + for { + select { + case <-ctx.Done(): + return + case job := <-workspaceLintRuns: + // if there are no files in the cache, then there is no need to + // run the aggregate report. This can happen if the server is + // very slow to start up. + if len(l.cache.GetAllFiles()) == 0 { + continue + } + + targetRules := l.getEnabledAggregateRules() + if !job.AggregateReportOnly { + targetRules = append(targetRules, l.getEnabledNonAggregateRules()...) + } + + err := updateAllDiagnostics( + ctx, + l.cache, + l.getLoadedConfig(), + l.workspacePath(), + // this is intended to only be set to true once, at start up; + // on following runs, cached aggregate data is used. + job.OverwriteAggregates, + job.AggregateReportOnly, + targetRules, + ) + if err != nil { + l.logError(fmt.Errorf("failed to update all diagnostics: %w", err)) + } + + for fileURI := range l.cache.GetAllFiles() { + if err := l.sendFileDiagnostics(ctx, fileURI); err != nil { + l.logError(fmt.Errorf("failed to send diagnostic: %w", err)) + } } } } - } + }() + + <-ctx.Done() + wg.Wait() }
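The poll ticker branch above only activates when `WorkspaceDiagnosticsPoll` is set. A hedged example of enabling it when constructing the server; the interval is arbitrary, and note that `internal/lsp` is only importable from within the Regal module, so treat this purely as an illustration:

```go
// Illustration only: opt in to the polling fallback for environments
// where file-change events from the client are unreliable.
ls := lsp.NewLanguageServer(ctx, &lsp.LanguageServerOptions{
	ErrorLog:                 os.Stderr,
	WorkspaceDiagnosticsPoll: 30 * time.Second, // arbitrary; 0 disables polling
})

go ls.StartDiagnosticsWorker(ctx)
```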
func (l *LanguageServer) StartHoverWorker(ctx context.Context) { @@ -269,8 +401,8 @@ select { case <-ctx.Done(): return - case evt := <-l.builtinsPositionFile: - fileURI := evt.URI + case job := <-l.builtinsPositionJobs: + fileURI := job.URI if l.ignoreURI(fileURI) { continue @@ -318,6 +450,50 @@ func (l *LanguageServer) getLoadedConfig() *config.Config { return l.loadedConfig } +func (l *LanguageServer) getEnabledNonAggregateRules() []string { + l.loadedConfigLock.Lock() + defer l.loadedConfigLock.Unlock() + + return l.loadedConfigEnabledNonAggregateRules +} + +func (l *LanguageServer) getEnabledAggregateRules() []string { + l.loadedConfigLock.Lock() + defer l.loadedConfigLock.Unlock() + + return l.loadedConfigEnabledAggregateRules +} + +// loadEnabledRulesFromConfig is used to cache the enabled rules for the current +// config. These take some time to compute and only change when config changes, +// so we can store them on the server to speed up diagnostic runs. +func (l *LanguageServer) loadEnabledRulesFromConfig(ctx context.Context, cfg config.Config) error { + l.loadedConfigLock.Lock() + defer l.loadedConfigLock.Unlock() + + enabledRules, err := linter.NewLinter().WithUserConfig(cfg).DetermineEnabledRules(ctx) + if err != nil { + return fmt.Errorf("failed to determine enabled rules: %w", err) + } + + enabledAggregateRules, err := linter.NewLinter().WithUserConfig(cfg).DetermineEnabledAggregateRules(ctx) + if err != nil { + return fmt.Errorf("failed to determine enabled aggregate rules: %w", err) + } + + l.loadedConfigEnabledNonAggregateRules = []string{} + + for _, r := range enabledRules { + if !slices.Contains(enabledAggregateRules, r) { + l.loadedConfigEnabledNonAggregateRules = append(l.loadedConfigEnabledNonAggregateRules, r) + } + } + + l.loadedConfigEnabledAggregateRules = enabledAggregateRules + + return nil +} + func (l *LanguageServer) StartConfigWorker(ctx context.Context) { if err := l.configWatcher.Start(ctx); err != nil { l.logError(fmt.Errorf("failed to start config watcher: %w", err)) @@ -339,30 +515,29 @@ var userConfig config.Config + // EOF errors are ignored here as we then just use the default config if err = yaml.Unmarshal(configFileBs, &userConfig); err != nil && !errors.Is(err, io.EOF) { l.logError(fmt.Errorf("failed to reload config: %w", err)) - return + continue } mergedConfig, err := config.LoadConfigWithDefaultsFromBundle(&rbundle.LoadedBundle, &userConfig) if err != nil { l.logError(fmt.Errorf("failed to load config: %w", err)) - return + continue } - // if the config is now blank, then we need to clear it l.loadedConfigLock.Lock() + l.loadedConfig = &mergedConfig + l.loadedConfigLock.Unlock() - if errors.Is(err, io.EOF) { - l.loadedConfig = nil - } else { - l.loadedConfig = &mergedConfig + err = l.loadEnabledRulesFromConfig(ctx, mergedConfig) + if err != nil { + l.logError(fmt.Errorf("failed to cache enabled rules: %w", err)) } - l.loadedConfigLock.Unlock() - // Capabilities URL may have changed, so we should reload it. cfg := l.getLoadedConfig() @@ -375,7 +550,7 @@ if err != nil { l.logError(fmt.Errorf("failed to load capabilities for URL %q: %w", capsURL, err)) - return + continue } bis := rego.BuiltinsForCapabilities(caps) @@ -441,13 +616,13 @@ } }() - l.diagnosticRequestWorkspace <- "config file changed" + l.lintWorkspaceJobs <- lintWorkspaceJob{Reason: "config file changed"} case <-l.configWatcher.Drop: l.loadedConfigLock.Lock() l.loadedConfig = nil l.loadedConfigLock.Unlock() - l.diagnosticRequestWorkspace <- "config file dropped" + l.lintWorkspaceJobs <- lintWorkspaceJob{Reason: "config file dropped"} } } } @@ -786,15 +961,15 @@ func (l *LanguageServer) StartWorkspaceStateWorker(ctx context.Context) { // next, check if there are any new files that are not ignored and // need to be loaded. We get new only so that files being worked // on are not loaded from disk during editing.
- changedOrNewURIs, err := l.loadWorkspaceContents(ctx, true) + newURIs, err := l.loadWorkspaceContents(ctx, true) if err != nil { l.logError(fmt.Errorf("failed to refresh workspace contents: %w", err)) continue } - for _, cnURI := range changedOrNewURIs { - l.diagnosticRequestFile <- fileUpdateEvent{ + for _, cnURI := range newURIs { + l.lintFileJobs <- lintFileJob{ URI: cnURI, Reason: "internal/workspaceStateWorker/changedOrNewFile", } @@ -810,9 +985,9 @@ func (l *LanguageServer) StartTemplateWorker(ctx context.Context) { select { case <-ctx.Done(): return - case evt := <-l.templateFile: + case job := <-l.templateFileJobs: // determine the new contents for the file, if permitted - newContents, err := l.templateContentsForFile(evt.URI) + newContents, err := l.templateContentsForFile(job.URI) if err != nil { l.logError(fmt.Errorf("failed to template new file: %w", err)) @@ -822,12 +997,12 @@ func (l *LanguageServer) StartTemplateWorker(ctx context.Context) { // set the contents of the new file in the cache immediately as // these must be update to date in order for fixRenameParams // to work - l.cache.SetFileContents(evt.URI, newContents) + l.cache.SetFileContents(job.URI, newContents) var edits []any edits = append(edits, types.TextDocumentEdit{ - TextDocument: types.OptionalVersionedTextDocumentIdentifier{URI: evt.URI}, + TextDocument: types.OptionalVersionedTextDocumentIdentifier{URI: job.URI}, Edits: ComputeEdits("", newContents), }) @@ -835,7 +1010,7 @@ func (l *LanguageServer) StartTemplateWorker(ctx context.Context) { renameParams, err := l.fixRenameParams( "Rename file to match package path", &fixes.DirectoryPackageMismatch{}, - evt.URI, + job.URI, ) if err != nil { l.logError(fmt.Errorf("failed to fix directory package mismatch: %w", err)) @@ -844,7 +1019,7 @@ func (l *LanguageServer) StartTemplateWorker(ctx context.Context) { } // move the file and clean up any empty directories ifd required - fileURI := evt.URI + fileURI := job.URI if len(renameParams.Edit.DocumentChanges) > 0 { edits = append(edits, renameParams.Edit.DocumentChanges[0]) @@ -894,12 +1069,12 @@ func (l *LanguageServer) StartTemplateWorker(ctx context.Context) { } // finally, trigger a diagnostics run for the new file - updateEvent := fileUpdateEvent{ + updateEvent := lintFileJob{ Reason: "internal/templateNewFile", URI: fileURI, } - l.diagnosticRequestFile <- updateEvent + l.lintFileJobs <- updateEvent } } } @@ -1650,14 +1825,14 @@ func (l *LanguageServer) handleTextDocumentDidOpen( l.cache.SetFileContents(params.TextDocument.URI, params.TextDocument.Text) - evt := fileUpdateEvent{ + job := lintFileJob{ Reason: "textDocument/didOpen", URI: params.TextDocument.URI, } - l.diagnosticRequestFile <- evt + l.lintFileJobs <- job - l.builtinsPositionFile <- evt + l.builtinsPositionJobs <- job return struct{}{}, nil } @@ -1712,13 +1887,13 @@ func (l *LanguageServer) handleTextDocumentDidChange( l.cache.SetFileContents(params.TextDocument.URI, params.ContentChanges[0].Text) - evt := fileUpdateEvent{ + job := lintFileJob{ Reason: "textDocument/didChange", URI: params.TextDocument.URI, } - l.diagnosticRequestFile <- evt - l.builtinsPositionFile <- evt + l.lintFileJobs <- job + l.builtinsPositionJobs <- job return struct{}{}, nil } @@ -1861,12 +2036,12 @@ func (l *LanguageServer) handleTextDocumentFormatting( l.cache.SetFileContents(params.TextDocument.URI, newContent) - updateEvent := fileUpdateEvent{ + updateEvent := lintFileJob{ Reason: "internal/templateFormattingFallback", URI: params.TextDocument.URI, } - 
l.diagnosticRequestFile <- updateEvent + l.lintFileJobs <- updateEvent return ComputeEdits(oldContent, newContent), nil } @@ -1992,14 +2167,14 @@ func (l *LanguageServer) handleWorkspaceDidCreateFiles( return nil, fmt.Errorf("failed to update cache for uri %q: %w", createOp.URI, err) } - evt := fileUpdateEvent{ + job := lintFileJob{ Reason: "textDocument/didCreate", URI: createOp.URI, } - l.diagnosticRequestFile <- evt - l.builtinsPositionFile <- evt - l.templateFile <- evt + l.lintFileJobs <- job + l.builtinsPositionJobs <- job + l.templateFileJobs <- job } return struct{}{}, nil @@ -2068,15 +2243,15 @@ func (l *LanguageServer) handleWorkspaceDidRenameFiles( l.cache.SetFileContents(renameOp.NewURI, content) - evt := fileUpdateEvent{ + job := lintFileJob{ Reason: "textDocument/didRename", URI: renameOp.NewURI, } - l.diagnosticRequestFile <- evt - l.builtinsPositionFile <- evt + l.lintFileJobs <- job + l.builtinsPositionJobs <- job // if the file being moved is empty, we template it too (if empty) - l.templateFile <- evt + l.templateFileJobs <- job } return struct{}{}, nil @@ -2098,11 +2273,16 @@ func (l *LanguageServer) handleWorkspaceDiagnostic( return workspaceReport, nil } + wkspceDiags, ok := l.cache.GetFileDiagnostics(l.workspaceRootURI) + if !ok { + wkspceDiags = []types.Diagnostic{} + } + workspaceReport.Items = append(workspaceReport.Items, types.WorkspaceFullDocumentDiagnosticReport{ URI: l.workspaceRootURI, Kind: "full", Version: nil, - Items: l.cache.GetAllDiagnosticsForURI(l.workspaceRootURI), + Items: wkspceDiags, }) return workspaceReport, nil @@ -2207,6 +2387,20 @@ func (l *LanguageServer) handleInitialize( }, } + defaultConfig, err := config.LoadConfigWithDefaultsFromBundle(&rbundle.LoadedBundle, nil) + if err != nil { + return nil, fmt.Errorf("failed to load default config: %w", err) + } + + l.loadedConfigLock.Lock() + l.loadedConfig = &defaultConfig + l.loadedConfigLock.Unlock() + + err = l.loadEnabledRulesFromConfig(ctx, defaultConfig) + if err != nil { + l.logError(fmt.Errorf("failed to cache enabled rules: %w", err)) + } + if l.workspaceRootURI != "" { workspaceRootPath := l.workspacePath() @@ -2226,7 +2420,13 @@ func (l *LanguageServer) handleInitialize( l.webServer.SetWorkspaceURI(l.workspaceRootURI) - l.diagnosticRequestWorkspace <- "server initialize" + l.lintWorkspaceJobs <- lintWorkspaceJob{ + Reason: "server initialize", + // 'OverwriteAggregates' is set to populate the cache's + // initial aggregate state. Subsequent runs of lintWorkspaceJobs + // will not set this and use the cached state. 
+ OverwriteAggregates: true, + } } return initializeResult, nil @@ -2306,7 +2506,7 @@ func (l *LanguageServer) handleInitialized( // if running without config, then we should send the diagnostic request now // otherwise it'll happen when the config is loaded if !l.configWatcher.IsWatching() { - l.diagnosticRequestWorkspace <- "initialized" + l.lintWorkspaceJobs <- lintWorkspaceJob{Reason: "server initialized"} } return struct{}{}, nil @@ -2337,6 +2537,15 @@ func (l *LanguageServer) handleWorkspaceDidChangeWatchedFiles( return nil, fmt.Errorf("failed to unmarshal params: %w", err) } + // this handles the case of a new config file being created when one did + // not exist before + if len(params.Changes) > 0 && strings.HasSuffix(params.Changes[0].URI, ".regal/config.yaml") { + configFile, err := config.FindConfig(l.workspacePath()) + if err == nil { + l.configWatcher.Watch(configFile.Name()) + } + } + // when a file is changed (saved), then we send trigger a full workspace lint regoFiles := make([]string, 0) @@ -2349,17 +2558,23 @@ func (l *LanguageServer) handleWorkspaceDidChangeWatchedFiles( } if len(regoFiles) > 0 { - l.diagnosticRequestWorkspace <- fmt.Sprintf( - "workspace/didChangeWatchedFiles (%s)", strings.Join(regoFiles, ", ")) + l.lintWorkspaceJobs <- lintWorkspaceJob{ + Reason: fmt.Sprintf("workspace/didChangeWatchedFiles (%s)", strings.Join(regoFiles, ", ")), + } } return struct{}{}, nil } func (l *LanguageServer) sendFileDiagnostics(ctx context.Context, fileURI string) error { + fileDiags, ok := l.cache.GetFileDiagnostics(fileURI) + if !ok { + fileDiags = []types.Diagnostic{} + } + resp := types.FileDiagnostics{ - Items: l.cache.GetAllDiagnosticsForURI(fileURI), URI: fileURI, + Items: fileDiags, } if err := l.conn.Notify(ctx, methodTextDocumentPublishDiagnostics, resp); err != nil { diff --git a/internal/lsp/server_aggregates_test.go b/internal/lsp/server_aggregates_test.go new file mode 100644 index 00000000..55c86733 --- /dev/null +++ b/internal/lsp/server_aggregates_test.go @@ -0,0 +1,478 @@ +package lsp + +import ( + "context" + "path/filepath" + "slices" + "strings" + "testing" + "time" + + "github.com/sourcegraph/jsonrpc2" + + "github.com/styrainc/regal/internal/lsp/types" + "github.com/styrainc/regal/pkg/report" +) + +//nolint:maintidx +func TestLanguageServerLintsUsingAggregateState(t *testing.T) { + t.Parallel() + + files := map[string]string{ + "foo.rego": `package foo + +import rego.v1 + +import data.bar +import data.baz +`, + "bar.rego": `package bar + +import rego.v1 +`, + "baz.rego": `package baz + +import rego.v1 +`, + ".regal/config.yaml": ``, + } + + messages := createMessageChannels(files) + + clientHandler := createClientHandler(t, messages) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tempDir := t.TempDir() + + _, connClient, err := createAndInitServer(ctx, newTestLogger(t), tempDir, files, clientHandler) + if err != nil { + t.Fatalf("failed to create and init language server: %s", err) + } + + timeout := time.NewTimer(determineTimeout()) + defer timeout.Stop() + + // no unresolved-imports at this stage + for { + var success bool + select { + case violations := <-messages["foo.rego"]: + if slices.Contains(violations, "unresolved-import") { + t.Logf("waiting for violations to not contain unresolved-import") + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for expected foo.rego diagnostics") + } + + if success { + break + } + } + + barURI := fileURIScheme + filepath.Join(tempDir, 
"bar.rego") + + err = connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: barURI, + }, + ContentChanges: []types.TextDocumentContentChangeEvent{ + { + Text: `package qux + +import rego.v1 +`, + }, + }, + }, nil) + if err != nil { + t.Fatalf("failed to send didChange notification: %s", err) + } + + // unresolved-imports is now expected + timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case violations := <-messages["foo.rego"]: + if !slices.Contains(violations, "unresolved-import") { + t.Log("waiting for violations to contain unresolved-import") + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for expected foo.rego diagnostics") + } + + if success { + break + } + } + + fooURI := fileURIScheme + filepath.Join(tempDir, "foo.rego") + + err = connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: fooURI, + }, + ContentChanges: []types.TextDocumentContentChangeEvent{ + { + Text: `package foo + +import rego.v1 + +import data.baz +import data.qux # new name for bar.rego package +`, + }, + }, + }, nil) + if err != nil { + t.Fatalf("failed to send didChange notification: %s", err) + } + + // unresolved-imports is again not expected + timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case violations := <-messages["foo.rego"]: + if slices.Contains(violations, "unresolved-import") { + t.Log("waiting for violations to not contain unresolved-import") + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for expected foo.rego diagnostics") + } + + if success { + break + } + } +} + +func TestLanguageServerUpdatesAggregateState(t *testing.T) { + t.Parallel() + + clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { + t.Logf("message %s", req.Method) + + return struct{}{}, nil + } + + files := map[string]string{ + "foo.rego": `package foo + +import rego.v1 + +import data.baz +`, + "bar.rego": `package bar + +import rego.v1 + +import data.quz +`, + ".regal/config.yaml": ``, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tempDir := t.TempDir() + + ls, connClient, err := createAndInitServer(ctx, newTestLogger(t), tempDir, files, clientHandler) + if err != nil { + t.Fatalf("failed to create and init language server: %s", err) + } + + // 1. 
check the Aggregates are set at start up + timeout := time.NewTimer(determineTimeout()) + + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + success := false + + select { + case <-ticker.C: + aggs := ls.cache.GetFileAggregates() + if len(aggs) == 0 { + t.Logf("server aggregates %d", len(aggs)) + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for file aggregates to be set") + } + + if success { + break + } + } + + determineImports := func(aggs map[string][]report.Aggregate) []string { + imports := []string{} + + unresolvedImportAggs, ok := aggs["imports/unresolved-import"] + if !ok { + t.Fatalf("expected imports/unresolved-import aggregate data") + } + + for _, entry := range unresolvedImportAggs { + if aggregateData, ok := entry["aggregate_data"].(map[string]any); ok { + if importsList, ok := aggregateData["imports"].([]any); ok { + for _, imp := range importsList { + if impMap, ok := imp.(map[string]any); ok { + if pathList, ok := impMap["path"].([]any); ok { + pathParts := []string{} + + for _, p := range pathList { + if pathStr, ok := p.(string); ok { + pathParts = append(pathParts, pathStr) + } + } + + imports = append(imports, strings.Join(pathParts, ".")) + } + } + } + } + } + } + + slices.Sort(imports) + + return imports + } + + imports := determineImports(ls.cache.GetFileAggregates()) + + if exp, got := []string{"baz", "quz"}, imports; !slices.Equal(exp, got) { + t.Fatalf("global state imports unexpected, got %v exp %v", got, exp) + } + + // 2. check the aggregates for a file are updated after an update + err = connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: fileURIScheme + filepath.Join(tempDir, "bar.rego"), + }, + ContentChanges: []types.TextDocumentContentChangeEvent{ + { + Text: `package bar + +import rego.v1 + +import data.qux # changed +import data.wow # new +`, + }, + }, + }, nil) + if err != nil { + t.Fatalf("failed to send didChange notification: %s", err) + } + + timeout.Reset(determineTimeout()) + + for { + success := false + + select { + case <-ticker.C: + imports = determineImports(ls.cache.GetFileAggregates()) + + if exp, got := []string{"baz", "qux", "wow"}, imports; !slices.Equal(exp, got) { + t.Logf("global state imports unexpected, got %v exp %v", got, exp) + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for file aggregates to be set") + } + + if success { + break + } + } +} + +// nolint:maintidx +func TestLanguageServerAggregateViolationFixedAndReintroducedInUnviolatingFileChange(t *testing.T) { + t.Parallel() + + var err error + + tempDir := t.TempDir() + files := map[string]string{ + "foo.rego": `package foo + +import rego.v1 + +import data.bax # initially unresolved-import + +variable = "string" # use-assignment-operator +`, + "bar.rego": `package bar + +import rego.v1 +`, + ".regal/config.yaml": ``, + } + + messages := createMessageChannels(files) + + clientHandler := createClientHandler(t, messages) + + // set up the server and client connections + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, connClient, err := createAndInitServer(ctx, newTestLogger(t), tempDir, files, clientHandler) + if err != nil { + t.Fatalf("failed to create and init language server: %s", err) + } + + // wait for foo.rego to have the correct violations + timeout := time.NewTimer(determineTimeout()) + defer timeout.Stop() + + for { + var success bool + select { 
+ case violations := <-messages["foo.rego"]: + if !slices.Contains(violations, "unresolved-import") { + t.Logf("waiting for violations to contain unresolved-import") + + continue + } + + if !slices.Contains(violations, "use-assignment-operator") { + t.Logf("waiting for violations to contain use-assignment-operator") + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for foo.rego diagnostics") + } + + if success { + break + } + } + + // update the contents of the bar.rego file to address the unresolved-import + barURI := fileURIScheme + filepath.Join(tempDir, "bar.rego") + + err = connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: barURI, + }, + ContentChanges: []types.TextDocumentContentChangeEvent{ + { + Text: `package bax # package imported in foo.rego + +import rego.v1 +`, + }, + }, + }, nil) + if err != nil { + t.Fatalf("failed to send didChange notification: %s", err) + } + + // wait for foo.rego to have the correct violations + timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case violations := <-messages["foo.rego"]: + if slices.Contains(violations, "unresolved-import") { + t.Logf("waiting for violations to not contain unresolved-import") + + continue + } + + if !slices.Contains(violations, "use-assignment-operator") { + t.Logf("use-assignment-operator should still be present") + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for foo.rego diagnostics") + } + + if success { + break + } + } + + // update the contents of the bar.rego to bring back the violation + err = connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: barURI, + }, + ContentChanges: []types.TextDocumentContentChangeEvent{ + { + Text: `package bar # original package to bring back the violation + +import rego.v1 +`, + }, + }, + }, nil) + if err != nil { + t.Fatalf("failed to send didChange notification: %s", err) + } + + // check the violation is back + timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case violations := <-messages["foo.rego"]: + if !slices.Contains(violations, "unresolved-import") { + t.Logf("waiting for violations to contain unresolved-import") + + continue + } + + if !slices.Contains(violations, "use-assignment-operator") { + t.Logf("use-assignment-operator should still be present") + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for foo.rego diagnostics") + } + + if success { + break + } + } +} diff --git a/internal/lsp/server_builtins_test.go b/internal/lsp/server_builtins_test.go new file mode 100644 index 00000000..98584b09 --- /dev/null +++ b/internal/lsp/server_builtins_test.go @@ -0,0 +1,32 @@ +package lsp + +import ( + "context" + "testing" +) + +// https://github.com/StyraInc/regal/issues/679 +func TestProcessBuiltinUpdateExitsOnMissingFile(t *testing.T) { + t.Parallel() + + ls := NewLanguageServer(context.Background(), &LanguageServerOptions{ + ErrorLog: newTestLogger(t), + }) + + if err := ls.processHoverContentUpdate(context.Background(), "file://missing.rego", "foo"); err != nil { + t.Fatal(err) + } + + if l := len(ls.cache.GetAllBuiltInPositions()); l != 0 { + t.Errorf("expected builtin positions to be empty, got %d items", l) + } + + contents, ok := ls.cache.GetFileContents("file://missing.rego") + if ok { + t.Errorf("expected file contents to be empty, got %s", 
contents) + } + + if len(ls.cache.GetAllFiles()) != 0 { + t.Errorf("expected files to be empty, got %v", ls.cache.GetAllFiles()) + } +} diff --git a/internal/lsp/server_config_test.go b/internal/lsp/server_config_test.go new file mode 100644 index 00000000..5d72bf64 --- /dev/null +++ b/internal/lsp/server_config_test.go @@ -0,0 +1,307 @@ +package lsp + +import ( + "context" + "os" + "path/filepath" + "slices" + "testing" + "time" + + "github.com/anderseknert/roast/pkg/encoding" + "github.com/sourcegraph/jsonrpc2" + + "github.com/styrainc/regal/internal/lsp/types" + "github.com/styrainc/regal/internal/lsp/uri" +) + +// TestLanguageServerParentDirConfig tests that regal config is loaded as it is for the +// Regal CLI, and that config files in a parent directory are loaded correctly +// even when the workspace is a child directory. +func TestLanguageServerParentDirConfig(t *testing.T) { + t.Parallel() + + var err error + + // this is the top level directory for the test + tempDir := t.TempDir() + // childDir will be the directory that the client is using as its workspace + + childDirName := "child" + childDir := filepath.Join(tempDir, childDirName) + + mainRegoContents := `package main + +import rego.v1 +allow := true +` + + files := map[string]string{ + childDirName + mainRegoFileName: mainRegoContents, + ".regal/config.yaml": `rules: + idiomatic: + directory-package-mismatch: + level: ignore + style: + opa-fmt: + level: error +`, + } + + // mainRegoFileURI is used throughout the test to refer to the main.rego file + // and so it is defined here for convenience + mainRegoFileURI := fileURIScheme + childDir + mainRegoFileName + + // set up the server and client connections + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + receivedMessages := make(chan types.FileDiagnostics, defaultBufferedChannelSize) + clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { + if req.Method == methodTextDocumentPublishDiagnostics { + var requestData types.FileDiagnostics + + if err2 := encoding.JSON().Unmarshal(*req.Params, &requestData); err2 != nil { + t.Fatalf("failed to unmarshal diagnostics: %s", err2) + } + + receivedMessages <- requestData + + return struct{}{}, nil + } + + t.Logf("unexpected request from server: %v", req) + + return struct{}{}, nil + } + + ls, _, err := createAndInitServer(ctx, newTestLogger(t), tempDir, files, clientHandler) + if err != nil { + t.Fatalf("failed to create and init language server: %s", err) + } + + if got, exp := ls.workspaceRootURI, uri.FromPath(ls.clientIdentifier, tempDir); exp != got { + t.Fatalf("expected client root URI to be %s, got %s", exp, got) + } + + timeout := time.NewTimer(determineTimeout()) + defer timeout.Stop() + + for { + var success bool + select { + case requestData := <-receivedMessages: + success = testRequestDataCodes(t, requestData, mainRegoFileURI, []string{"opa-fmt"}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") + } + + if success { + break + } + } + + // User updates config file contents in parent directory that is not + // part of the workspace + newConfigContents := `rules: + idiomatic: + directory-package-mismatch: + level: ignore + style: + opa-fmt: + level: ignore +` + + path := filepath.Join(tempDir, ".regal/config.yaml") + if err := os.WriteFile(path, []byte(newConfigContents), 0o600); err != nil { + t.Fatalf("failed to write new config file: %s", err) + } + + // validate that the client received a new, empty 
diagnostics notification for the file + timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case requestData := <-receivedMessages: + success = testRequestDataCodes(t, requestData, mainRegoFileURI, []string{}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") + } + + if success { + break + } + } +} + +func TestLanguageServerCachesEnabledRulesAndUsesDefaultConfig(t *testing.T) { + t.Parallel() + + var err error + + tempDir := t.TempDir() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // no op handler + clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { + t.Logf("message received: %s", req.Method) + + return struct{}{}, nil + } + + ls, connClient, err := createAndInitServer(ctx, newTestLogger(t), tempDir, map[string]string{}, clientHandler) + if err != nil { + t.Fatalf("failed to create and init language server: %s", err) + } + + if got, exp := ls.workspaceRootURI, uri.FromPath(ls.clientIdentifier, tempDir); exp != got { + t.Fatalf("expected client root URI to be %s, got %s", exp, got) + } + + timeout := time.NewTimer(3 * time.Second) + ticker := time.NewTicker(500 * time.Millisecond) + + for { + var success bool + select { + case <-ticker.C: + enabledRules := ls.getEnabledNonAggregateRules() + enabledAggRules := ls.getEnabledAggregateRules() + + if len(enabledRules) == 0 || len(enabledAggRules) == 0 { + t.Log("no enabled rules yet...") + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for enabled rules to be correct") + } + + if success { + break + } + } + + err = os.MkdirAll(filepath.Join(tempDir, ".regal"), 0o755) + if err != nil { + t.Fatalf("failed to create regal config dir: %s", err) + } + + configContents := ` +rules: + idiomatic: + directory-package-mismatch: + level: ignore + imports: + unresolved-import: + level: ignore +` + + err = os.WriteFile(filepath.Join(tempDir, ".regal/config.yaml"), []byte(configContents), 0o600) + if err != nil { + t.Fatalf("failed to write regal config file: %s", err) + } + + // this event is sent to allow the server to detect the new config + if err := connClient.Call(ctx, "workspace/didChangeWatchedFiles", types.WorkspaceDidChangeWatchedFilesParams{ + Changes: []types.FileEvent{ + { + URI: fileURIScheme + filepath.Join(tempDir, ".regal/config.yaml"), + Type: 1, // created + }, + }, + }, nil); err != nil { + t.Fatalf("failed to send didChange notification: %s", err) + } + + timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case <-ticker.C: + enabledRules := ls.getEnabledNonAggregateRules() + enabledAggRules := ls.getEnabledAggregateRules() + + if slices.Contains(enabledRules, "directory-package-mismatch") { + t.Log("enabledRules still contains directory-package-mismatch") + + continue + } + + if slices.Contains(enabledAggRules, "unresolved-import") { + t.Log("enabledAggRules still contains unresolved-import") + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for enabled rules to be correct") + } + + if success { + break + } + } + + configContents2 := ` +rules: + style: + opa-fmt: + level: ignore + idiomatic: + directory-package-mismatch: + level: error + imports: + unresolved-import: + level: error +` + + err = os.WriteFile(filepath.Join(tempDir, ".regal/config.yaml"), []byte(configContents2), 0o600) + if err != nil { + t.Fatalf("failed to write regal config file: %s", err) + } + + 
timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case <-ticker.C: + enabledRules := ls.getEnabledNonAggregateRules() + enabledAggRules := ls.getEnabledAggregateRules() + + if slices.Contains(enabledRules, "opa-fmt") { + t.Log("enabledRules still contains opa-fmt") + + continue + } + + if !slices.Contains(enabledRules, "directory-package-mismatch") { + t.Log("enabledRules must contain directory-package-mismatch") + + continue + } + + if !slices.Contains(enabledAggRules, "unresolved-import") { + t.Log("enabledAggRules must contain unresolved-import") + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for enabled rules to be correct") + } + + if success { + break + } + } +} diff --git a/internal/lsp/server_formatting_test.go b/internal/lsp/server_formatting_test.go new file mode 100644 index 00000000..03f7138a --- /dev/null +++ b/internal/lsp/server_formatting_test.go @@ -0,0 +1,78 @@ +package lsp + +import ( + "context" + "encoding/json" + "testing" + + "github.com/sourcegraph/jsonrpc2" + + "github.com/styrainc/regal/internal/lsp/types" +) + +func TestFormatting(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + + // set up the server and client connections + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { + t.Fatalf("unexpected request: %v", req) + + return struct{}{}, nil + } + + ls, connClient, err := createAndInitServer(ctx, newTestLogger(t), tempDir, map[string]string{}, clientHandler) + if err != nil { + t.Fatalf("failed to create and init language server: %s", err) + } + + mainRegoURI := fileURIScheme + tempDir + mainRegoFileName + + // Simple as possible — opa fmt should just remove a newline + content := `package main + +` + ls.cache.SetFileContents(mainRegoURI, content) + + bs, err := json.Marshal(&types.DocumentFormattingParams{ + TextDocument: types.TextDocumentIdentifier{URI: mainRegoURI}, + Options: types.FormattingOptions{}, + }) + if err != nil { + t.Fatalf("failed to marshal document formatting params: %v", err) + } + + var msg json.RawMessage = bs + + req := &jsonrpc2.Request{Params: &msg} + + res, err := ls.handleTextDocumentFormatting(ctx, connClient, req) + if err != nil { + t.Fatalf("failed to format document: %s", err) + } + + if edits, ok := res.([]types.TextEdit); ok { + if len(edits) != 1 { + t.Fatalf("expected 1 edit, got %d", len(edits)) + } + + expectRange := types.Range{ + Start: types.Position{Line: 1, Character: 0}, + End: types.Position{Line: 2, Character: 0}, + } + + if edits[0].Range != expectRange { + t.Fatalf("expected range to be %v, got %v", expectRange, edits[0].Range) + } + + if edits[0].NewText != "" { + t.Fatalf("expected new text to be empty, got %s", edits[0].NewText) + } + } else { + t.Fatalf("expected edits to be []types.TextEdit, got %T", res) + } +} diff --git a/internal/lsp/server_multi_file_test.go b/internal/lsp/server_multi_file_test.go new file mode 100644 index 00000000..31720514 --- /dev/null +++ b/internal/lsp/server_multi_file_test.go @@ -0,0 +1,166 @@ +package lsp + +import ( + "context" + "path/filepath" + "slices" + "testing" + "time" + + "github.com/styrainc/regal/internal/lsp/types" +) + +// TestLanguageServerMultipleFiles tests that changes to multiple files are handled correctly. 
When there are multiple
+// files in the workspace, the diagnostics worker also processes aggregate violations. There are also changes to when
+// workspace diagnostics are run, so this test validates that the correct diagnostics are sent to the client in this
+// scenario.
+//
+//nolint:maintidx
+func TestLanguageServerMultipleFiles(t *testing.T) {
+	t.Parallel()
+
+	// set up the workspace content with some example rego and regal config
+	tempDir := t.TempDir()
+
+	files := map[string]string{
+		"authz.rego": `package authz
+
+import rego.v1
+
+import data.admins.users
+
+default allow := false
+
+allow if input.user in users
+`,
+		"admins.rego": `package admins
+
+import rego.v1
+
+users = {"alice", "bob"}
+`,
+		"ignored/foo.rego": `package ignored
+
+foo = 1
+`,
+		".regal/config.yaml": `
+rules:
+  idiomatic:
+    directory-package-mismatch:
+      level: ignore
+ignore:
+  files:
+    - ignored/*.rego
+`,
+	}
+
+	messages := createMessageChannels(files)
+
+	clientHandler := createClientHandler(t, messages)
+
+	// set up the server and client connections
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	_, connClient, err := createAndInitServer(ctx, newTestLogger(t), tempDir, files, clientHandler)
+	if err != nil {
+		t.Fatalf("failed to create and init language server: %s", err)
+	}
+
+	// validate that the client received a diagnostics notification for authz.rego
+	timeout := time.NewTimer(determineTimeout())
+	defer timeout.Stop()
+
+	for {
+		var success bool
+		select {
+		case violations := <-messages["authz.rego"]:
+			if !slices.Contains(violations, "prefer-package-imports") {
+				t.Logf("waiting for violations to contain prefer-package-imports, have: %v", violations)
+
+				continue
+			}
+
+			success = true
+		case <-timeout.C:
+			t.Fatalf("timed out waiting for authz.rego diagnostics to be sent")
+		}
+
+		if success {
+			break
+		}
+	}
+
+	// validate that the client received a diagnostics notification for admins.rego
+	timeout.Reset(determineTimeout())
+
+	for {
+		var success bool
+		select {
+		case violations := <-messages["admins.rego"]:
+			if !slices.Contains(violations, "use-assignment-operator") {
+				t.Logf("waiting for violations to contain use-assignment-operator, have: %v", violations)
+
+				continue
+			}
+
+			success = true
+		case <-timeout.C:
+			t.Fatalf("timed out waiting for admins.rego diagnostics to be sent")
+		}
+
+		if success {
+			break
+		}
+	}
+
+	// 
Client sends textDocument/didChange notification with new contents + // for authz.rego no response to the call is expected + if err := connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: fileURIScheme + filepath.Join(tempDir, "authz.rego"), + }, + ContentChanges: []types.TextDocumentContentChangeEvent{ + { + Text: `package authz + +import rego.v1 + +import data.admins # fixes prefer-package-imports + +default allow := false + +# METADATA +# description: Allow only admins +# entrypoint: true # fixes no-defined-entrypoint +allow if input.user in admins.users +`, + }, + }, + }, nil); err != nil { + t.Fatalf("failed to send didChange notification: %s", err) + } + + // authz.rego should now have no violations + timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case violations := <-messages["authz.rego"]: + if len(violations) > 0 { + t.Logf("waiting for violations to be empty for authz.rego, have: %v", violations) + + continue + } + + success = true + case <-timeout.C: + t.Fatalf("timed out waiting for authz.rego diagnostics to be sent") + } + + if success { + break + } + } +} diff --git a/internal/lsp/server_rename_test.go b/internal/lsp/server_rename_test.go new file mode 100644 index 00000000..04b02107 --- /dev/null +++ b/internal/lsp/server_rename_test.go @@ -0,0 +1,79 @@ +package lsp + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/styrainc/regal/internal/lsp/cache" + "github.com/styrainc/regal/internal/lsp/clients" + "github.com/styrainc/regal/pkg/config" + "github.com/styrainc/regal/pkg/fixer/fixes" +) + +func TestLanguageServerFixRenameParams(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + + if err := os.MkdirAll(filepath.Join(tmpDir, "workspace/foo/bar"), 0o755); err != nil { + t.Fatalf("failed to create directory: %s", err) + } + + ctx := context.Background() + + l := NewLanguageServer(ctx, &LanguageServerOptions{ErrorLog: newTestLogger(t)}) + c := cache.NewCache() + f := &fixes.DirectoryPackageMismatch{} + + fileURL := fmt.Sprintf("file://%s/workspace/foo/bar/policy.rego", tmpDir) + + c.SetFileContents(fileURL, "package authz.main.rules") + + l.clientIdentifier = clients.IdentifierVSCode + l.workspaceRootURI = fmt.Sprintf("file://%s/workspace", tmpDir) + l.cache = c + l.loadedConfig = &config.Config{ + Rules: map[string]config.Category{ + "idiomatic": { + "directory-package-mismatch": config.Rule{ + Level: "ignore", + Extra: map[string]any{ + "exclude-test-suffix": true, + }, + }, + }, + }, + } + + params, err := l.fixRenameParams("fix my file!", f, fileURL) + if err != nil { + t.Fatalf("failed to fix rename params: %s", err) + } + + if params.Label != "fix my file!" 
{ + t.Fatalf("expected label to be 'Fix my file!', got %s", params.Label) + } + + if len(params.Edit.DocumentChanges) != 1 { + t.Fatalf("expected 1 document change, got %d", len(params.Edit.DocumentChanges)) + } + + change := params.Edit.DocumentChanges[0] + + if change.Kind != "rename" { + t.Fatalf("expected kind to be 'rename', got %s", change.Kind) + } + + if change.OldURI != fileURL { + t.Fatalf("expected old URI to be %s, got %s", fileURL, change.OldURI) + } + + expectedNewURI := fmt.Sprintf("file://%s/workspace/authz/main/rules/policy.rego", tmpDir) + + if change.NewURI != expectedNewURI { + t.Fatalf("expected new URI to be %s, got %s", expectedNewURI, change.NewURI) + } +} diff --git a/internal/lsp/server_single_file_test.go b/internal/lsp/server_single_file_test.go new file mode 100644 index 00000000..7b1306d8 --- /dev/null +++ b/internal/lsp/server_single_file_test.go @@ -0,0 +1,364 @@ +package lsp + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/anderseknert/roast/pkg/encoding" + "github.com/sourcegraph/jsonrpc2" + + "github.com/styrainc/regal/internal/lsp/types" +) + +// TestLanguageServerSingleFile tests that changes to a single file and Regal config are handled correctly by the +// language server by making updates to both and validating that the correct diagnostics are sent to the client. +// +// This test also ensures that updating the config to point to a non-default engine and capabilities version works +// and causes that engine's builtins to work with completions. +// +//nolint:maintidx +func TestLanguageServerSingleFile(t *testing.T) { + t.Parallel() + + // set up the workspace content with some example rego and regal config + tempDir := t.TempDir() + mainRegoURI := fileURIScheme + tempDir + mainRegoFileName + + if err := os.MkdirAll(filepath.Join(tempDir, ".regal"), 0o755); err != nil { + t.Fatalf("failed to create .regal directory: %s", err) + } + + mainRegoContents := `package main + +import rego.v1 +allow = true +` + + files := map[string]string{ + "main.rego": mainRegoContents, + ".regal/config.yaml": ` +rules: + idiomatic: + directory-package-mismatch: + level: ignore`, + } + + for f, fc := range files { + if err := os.WriteFile(filepath.Join(tempDir, f), []byte(fc), 0o600); err != nil { + t.Fatalf("failed to write file %s: %s", f, err) + } + } + + receivedMessages := make(chan types.FileDiagnostics, defaultBufferedChannelSize) + clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { + if req.Method == methodTextDocumentPublishDiagnostics { + var requestData types.FileDiagnostics + + if err := encoding.JSON().Unmarshal(*req.Params, &requestData); err != nil { + t.Fatalf("failed to unmarshal diagnostics: %s", err) + } + + receivedMessages <- requestData + + return struct{}{}, nil + } + + t.Fatalf("unexpected request: %v", req) + + return struct{}{}, nil + } + + // set up the server and client connections + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, connClient, err := createAndInitServer(ctx, newTestLogger(t), tempDir, files, clientHandler) + if err != nil { + t.Fatalf("failed to create and init language server: %s", err) + } + + // validate that the client received a diagnostics notification for the file + timeout := time.NewTimer(determineTimeout()) + defer timeout.Stop() + + for { + var success bool + select { + case requestData := <-receivedMessages: + success = testRequestDataCodes(t, requestData, mainRegoURI, []string{"opa-fmt", 
"use-assignment-operator"}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") + } + + if success { + break + } + } + + // Client sends textDocument/didChange notification with new contents for main.rego + // no response to the call is expected + if err := connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{ + TextDocument: types.TextDocumentIdentifier{ + URI: mainRegoURI, + }, + ContentChanges: []types.TextDocumentContentChangeEvent{ + { + Text: `package main +import rego.v1 +allow := true +`, + }, + }, + }, nil); err != nil { + t.Fatalf("failed to send didChange notification: %s", err) + } + + // validate that the client received a new diagnostics notification for the file + timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case requestData := <-receivedMessages: + success = testRequestDataCodes(t, requestData, mainRegoURI, []string{"opa-fmt"}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") + } + + if success { + break + } + } + + // config update is caught by the config watcher + newConfigContents := ` +rules: + idiomatic: + directory-package-mismatch: + level: ignore + style: + opa-fmt: + level: ignore +` + + if err := os.WriteFile(filepath.Join(tempDir, ".regal/config.yaml"), []byte(newConfigContents), 0o600); err != nil { + t.Fatalf("failed to write new config file: %s", err) + } + + // validate that the client received a new, empty diagnostics notification for the file + timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case requestData := <-receivedMessages: + if requestData.URI != mainRegoURI { + t.Logf("expected diagnostics to be sent for main.rego, got %s", requestData.URI) + + continue + } + + codes := []string{} + for _, d := range requestData.Items { + codes = append(codes, d.Code) + } + + if len(requestData.Items) != 0 { + t.Logf("expected empty diagnostics, got %v", codes) + + continue + } + + success = testRequestDataCodes(t, requestData, mainRegoURI, []string{}) + case <-timeout.C: + t.Fatalf("timed out waiting for main.rego diagnostics to be sent") + } + + if success { + break + } + } + + // Client sends new config with an EOPA capabilities file specified. + newConfigContents = ` +rules: + style: + opa-fmt: + level: ignore + idiomatic: + directory-package-mismatch: + level: ignore +capabilities: + from: + engine: eopa + version: v1.23.0 +` + + if err := os.WriteFile(filepath.Join(tempDir, ".regal/config.yaml"), []byte(newConfigContents), 0o600); err != nil { + t.Fatalf("failed to write new config file: %s", err) + } + + // validate that the client received a new, empty diagnostics notification for the file + timeout.Reset(determineTimeout()) + + for { + var success bool + select { + case requestData := <-receivedMessages: + if requestData.URI != mainRegoURI { + t.Logf("expected diagnostics to be sent for main.rego, got %s", requestData.URI) + + break + } + + codes := []string{} + for _, d := range requestData.Items { + codes = append(codes, d.Code) + } + + if len(requestData.Items) != 0 { + t.Logf("expected empty diagnostics, got %v", codes) + + continue + } + + success = testRequestDataCodes(t, requestData, mainRegoURI, []string{}) + case <-timeout.C: + t.Fatalf("timed out waiting for main.rego diagnostics to be sent") + } + + if success { + break + } + } + + // Client sends textDocument/didChange notification with new + // contents for main.rego no response to the call is expected. 
We added
+	// the start of an EOPA-specific call, so if the capabilities were
+	// loaded correctly, we should see a completion later after we ask for
+	// it.
+	if err := connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{
+		TextDocument: types.TextDocumentIdentifier{
+			URI: mainRegoURI,
+		},
+		ContentChanges: []types.TextDocumentContentChangeEvent{
+			{
+				Text: `package main
+import rego.v1
+
+# METADATA
+# entrypoint: true
+allow := neo4j.q
+`,
+			},
+		},
+	}, nil); err != nil {
+		t.Fatalf("failed to send didChange notification: %s", err)
+	}
+
+	// validate that the client received a new diagnostics notification for the file
+	timeout.Reset(determineTimeout())
+
+	for {
+		var success bool
+		select {
+		case requestData := <-receivedMessages:
+			if requestData.URI != mainRegoURI {
+				t.Logf("expected diagnostics to be sent for main.rego, got %s", requestData.URI)
+
+				break
+			}
+
+			codes := []string{}
+			for _, d := range requestData.Items {
+				codes = append(codes, d.Code)
+			}
+
+			if len(requestData.Items) != 0 {
+				t.Logf("expected empty diagnostics, got %v", codes)
+
+				continue
+			}
+
+			success = testRequestDataCodes(t, requestData, mainRegoURI, []string{})
+		case <-timeout.C:
+			t.Fatalf("timed out waiting for file diagnostics to be sent")
+		}
+
+		if success {
+			break
+		}
+	}
+
+	// With our new config applied, and the file updated, we can ask the
+	// LSP for a completion. We expect to see neo4j.query show up. Since
+	// neo4j.query is an EOPA-specific builtin, it should never appear if
+	// we're using the normal OPA capabilities file.
+	timeout.Reset(determineTimeout())
+
+	ticker := time.NewTicker(500 * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		foundNeo4j := false
+
+		select {
+		case <-ticker.C:
+			// Create a new context with a timeout for each request. The
+			// default timeout is used here because the GHA runner is very
+			// slow under the race detector.
+			reqCtx, reqCtxCancel := context.WithTimeout(ctx, determineTimeout())
+
+			resp := make(map[string]any)
+			err := connClient.Call(reqCtx, "textDocument/completion", types.CompletionParams{
+				TextDocument: types.TextDocumentIdentifier{
+					URI: mainRegoURI,
+				},
+				Position: types.Position{
+					Line:      5,
+					Character: 16,
+				},
+			}, &resp)
+
+			reqCtxCancel()
+
+			if err != nil {
+				t.Fatalf("failed to send completion request: %s", err)
+			}
+
+			itemsList, ok := resp["items"].([]any)
+			if !ok {
+				t.Fatalf("failed to cast resp[items] to []any")
+			}
+
+			for _, itemI := range itemsList {
+				item, ok := itemI.(map[string]any)
+				if !ok {
+					t.Fatalf("completion item '%+v' was not a JSON object", itemI)
+				}
+
+				label, ok := item["label"].(string)
+				if !ok {
+					t.Fatalf("completion item label is not a string: %+v", item["label"])
+				}
+
+				if label == "neo4j.query" {
+					foundNeo4j = true
+
+					break
+				}
+			}
+
+			t.Logf("waiting for neo4j.query in completion results for neo4j.q, got %v", itemsList)
+		case <-timeout.C:
+			t.Fatalf("timed out waiting for file completion to be correct")
+		}
+
+		if foundNeo4j {
+			break
+		}
+	}
+}
diff --git a/internal/lsp/server_template_test.go b/internal/lsp/server_template_test.go
index 8921ff2f..bd9bd24f 100644
--- a/internal/lsp/server_template_test.go
+++ b/internal/lsp/server_template_test.go
@@ -98,8 +98,10 @@ func TestTemplateContentsForFile(t *testing.T) {
 		}
 	}
 
+	ctx := context.Background()
+
 	// create a new language server
-	s := NewLanguageServer(&LanguageServerOptions{ErrorLog: newTestLogger(t)})
+	s := NewLanguageServer(ctx, &LanguageServerOptions{ErrorLog: newTestLogger(t)})
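+	// NewLanguageServer now takes a context; presumably this bounds any
+	// background work started by the server to the lifetime of the test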
s.workspaceRootURI = uri.FromPath(clients.IdentifierGeneric, td)
 
 	fileURI := uri.FromPath(clients.IdentifierGeneric, filepath.Join(td, tc.FileKey))
@@ -137,27 +139,10 @@ func TestNewFileTemplating(t *testing.T) {
 `,
 	}
 
-	for f, fc := range files {
-		if err := os.MkdirAll(filepath.Dir(filepath.Join(tempDir, f)), 0o755); err != nil {
-			t.Fatalf("failed to create directory %s: %s", filepath.Dir(filepath.Join(tempDir, f)), err)
-		}
-
-		if err := os.WriteFile(filepath.Join(tempDir, f), []byte(fc), 0o600); err != nil {
-			t.Fatalf("failed to write file %s: %s", f, err)
-		}
-	}
-
 	// set up the server and client connections
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	ls := NewLanguageServer(&LanguageServerOptions{
-		ErrorLog: newTestLogger(t),
-	})
-
-	go ls.StartConfigWorker(ctx)
-	go ls.StartTemplateWorker(ctx)
-
 	receivedMessages := make(chan []byte, 10)
 	clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) {
 		bs, err := json.MarshalIndent(req.Params, "", "  ")
@@ -170,30 +155,15 @@ func TestNewFileTemplating(t *testing.T) {
 		return struct{}{}, nil
 	}
 
-	connServer, connClient := createConnections(ctx, ls.Handle, clientHandler)
-
-	ls.SetConn(connServer)
-
-	// 1. Client sends initialize request
-	request := types.InitializeParams{
-		RootURI:    fileURIScheme + tempDir,
-		ClientInfo: types.Client{Name: "go test"},
-	}
-
-	var response types.InitializeResult
-
-	if err := connClient.Call(ctx, "initialize", request, &response); err != nil {
-		t.Fatalf("failed to send initialize request: %s", err)
+	ls, connClient, err := createAndInitServer(ctx, newTestLogger(t), tempDir, files, clientHandler)
+	if err != nil {
+		t.Fatalf("failed to create and init language server: %s", err)
 	}
 
-	// 2. Client sends initialized notification no response to the call is
-	// expected
-	if err := connClient.Call(ctx, "initialized", struct{}{}, nil); err != nil {
-		t.Fatalf("failed to send initialized notification: %s", err)
-	}
+	go ls.StartTemplateWorker(ctx)
 
-	// 3. wait for the server to load it's config
-	timeout := time.NewTimer(defaultTimeout)
+	// wait for the server to load its config
+	timeout := time.NewTimer(determineTimeout())
 	select {
 	case <-timeout.C:
 		t.Fatalf("timed out waiting for server to load config")
@@ -207,7 +177,7 @@ func TestNewFileTemplating(t *testing.T) {
 		}
 	}
 
-	// 4. Touch the new file on disk
+	// Touch the new file on disk
 	newFilePath := filepath.Join(tempDir, "foo/bar/policy_test.rego")
 	newFileURI := uri.FromPath(clients.IdentifierGeneric, newFilePath)
 	expectedNewFileURI := uri.FromPath(clients.IdentifierGeneric, filepath.Join(tempDir, "foo/bar_test/policy_test.rego"))
@@ -220,7 +190,7 @@ func TestNewFileTemplating(t *testing.T) {
 		t.Fatalf("failed to write file %s: %s", newFilePath, err)
 	}
 
-	// 5. Client sends workspace/didCreateFiles notification
+	// Client sends workspace/didCreateFiles notification
 	if err := connClient.Call(ctx, "workspace/didCreateFiles", types.WorkspaceDidCreateFilesParams{
 		Files: []types.WorkspaceDidCreateFilesParamsCreatedFile{
 			{URI: newFileURI},
@@ -229,9 +199,8 @@ func TestNewFileTemplating(t *testing.T) {
 		t.Fatalf("failed to send didChange notification: %s", err)
 	}
 
-	// 6. 
Validate that the client received a workspace edit
-	timeout = time.NewTimer(3 * time.Second)
-	defer timeout.Stop()
+	// Validate that the client received a workspace edit
+	timeout.Reset(determineTimeout())
 
 	expectedMessage := fmt.Sprintf(`{
   "edit": {
diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go
index 209e4148..a96e30ef 100644
--- a/internal/lsp/server_test.go
+++ b/internal/lsp/server_test.go
@@ -2,7 +2,6 @@ package lsp
 
 import (
 	"context"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -19,11 +18,8 @@ import (
 	"github.com/anderseknert/roast/pkg/encoding"
 	"github.com/sourcegraph/jsonrpc2"
 
-	"github.com/styrainc/regal/internal/lsp/cache"
-	"github.com/styrainc/regal/internal/lsp/clients"
 	"github.com/styrainc/regal/internal/lsp/types"
-	"github.com/styrainc/regal/pkg/config"
-	"github.com/styrainc/regal/pkg/fixer/fixes"
+	"github.com/styrainc/regal/internal/util"
 )
 
 const mainRegoFileName = "/main.rego"
@@ -35,497 +31,125 @@ const defaultTimeout = 20 * time.Second
 
 const defaultBufferedChannelSize = 5
 
-const fileURIScheme = "file://"
-
-// TestLanguageServerSingleFile tests that changes to a single file and Regal config are handled correctly by the
-// language server by making updates to both and validating that the correct diagnostics are sent to the client.
-//
-// This test also ensures that updating the config to point to a non-default engine and capabilities version works
-// and causes that engine's builtins to work with completions.
-//
-//nolint:maintidx
-func TestLanguageServerSingleFile(t *testing.T) {
-	t.Parallel()
-
-	// set up the workspace content with some example rego and regal config
-	tempDir := t.TempDir()
-	mainRegoURI := fileURIScheme + tempDir + mainRegoFileName
-
-	if err := os.MkdirAll(filepath.Join(tempDir, ".regal"), 0o755); err != nil {
-		t.Fatalf("failed to create .regal directory: %s", err)
+// determineTimeout returns a timeout duration based on whether
+// the test suite is running with race detection; if so, a more permissive
+// timeout is used.
+func determineTimeout() time.Duration {
+	if isRaceEnabled() {
+		// based on the upper bound here, 20x slower
+		// https://go.dev/doc/articles/race_detector#Runtime_Overheads
+		return defaultTimeout * 20
 	}
 
-	mainRegoContents := `package main
-
-import rego.v1
-allow = true
-`
+	return defaultTimeout
+}
 
-	files := map[string]string{
-		"main.rego": mainRegoContents,
-		".regal/config.yaml": `
-rules:
-  idiomatic:
-    directory-package-mismatch:
-      level: ignore`,
-	}
+const fileURIScheme = "file://"
 
-	for f, fc := range files {
-		if err := os.WriteFile(filepath.Join(tempDir, f), []byte(fc), 0o600); err != nil {
-			t.Fatalf("failed to write file %s: %s", f, err)
-		}
-	}
+// newTestLogger returns an io.Writer that logs to the given testing.T.
+// This is helpful as it can be used to have the server log to the test logger
+// in server tests. It is protected from being written to after the test is
+// over. 
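+//
+// A typical use in these tests (a sketch mirroring calls made in the other
+// server tests):
+//
+//	ls := NewLanguageServer(ctx, &LanguageServerOptions{ErrorLog: newTestLogger(t)})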
+func newTestLogger(t *testing.T) io.Writer { + t.Helper() - // set up the server and client connections - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + tl := &testLogger{t: t, open: true} - ls := NewLanguageServer(&LanguageServerOptions{ - ErrorLog: newTestLogger(t), + // using cleanup ensure that no goroutines attempt to write to the logger + // after the test has been cleaned up + t.Cleanup(func() { + tl.mu.Lock() + defer tl.mu.Unlock() + tl.open = false }) - go ls.StartDiagnosticsWorker(ctx) - go ls.StartConfigWorker(ctx) - - receivedMessages := make(chan types.FileDiagnostics, defaultBufferedChannelSize) - clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { - if req.Method == methodTextDocumentPublishDiagnostics { - var requestData types.FileDiagnostics - - if err := encoding.JSON().Unmarshal(*req.Params, &requestData); err != nil { - t.Fatalf("failed to unmarshal diagnostics: %s", err) - } - - receivedMessages <- requestData - - return struct{}{}, nil - } - - t.Fatalf("unexpected request: %v", req) - - return struct{}{}, nil - } - - connServer, connClient := createConnections(ctx, ls.Handle, clientHandler) - - ls.SetConn(connServer) - - // 1. Client sends initialize request - request := types.InitializeParams{ - RootURI: fileURIScheme + tempDir, - ClientInfo: types.Client{Name: "go test"}, - } - - var response types.InitializeResult - - if err := connClient.Call(ctx, "initialize", request, &response); err != nil { - t.Fatalf("failed to send initialize request: %s", err) - } - - // validate that the server responded with the correct capabilities, and that the correct root URI was set on the - // server - if response.Capabilities.DiagnosticProvider.Identifier != "rego" { - t.Fatalf( - "expected diagnostic provider identifier to be rego, got %s", - response.Capabilities.DiagnosticProvider.Identifier, - ) - } - - if ls.workspaceRootURI != request.RootURI { - t.Fatalf("expected client root URI to be %s, got %s", request.RootURI, ls.workspaceRootURI) - } - - // validate that the file contents from the workspace are loaded during the initialize request - contents, ok := ls.cache.GetFileContents(mainRegoURI) - if !ok { - t.Fatalf("expected file contents to be cached") - } - - if contents != mainRegoContents { - t.Fatalf("expected file contents to be %s, got %s", mainRegoContents, contents) - } - - _, ok = ls.cache.GetModule(mainRegoURI) - if !ok { - t.Fatalf("expected module to have been parsed and cached for main.rego") - } - - // 2. Client sends initialized notification - // no response to the call is expected - if err := connClient.Call(ctx, "initialized", struct{}{}, nil); err != nil { - t.Fatalf("failed to send initialized notification: %s", err) - } - - // validate that the client received a diagnostics notification for the file - timeout := time.NewTimer(defaultTimeout) - defer timeout.Stop() - - for { - var success bool - select { - case requestData := <-receivedMessages: - success = testRequestDataCodes(t, requestData, mainRegoURI, []string{"opa-fmt", "use-assignment-operator"}) - case <-timeout.C: - t.Fatalf("timed out waiting for file diagnostics to be sent") - } - - if success { - break - } - } - - // 3. 
Client sends textDocument/didChange notification with new contents for main.rego - // no response to the call is expected - if err := connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{ - TextDocument: types.TextDocumentIdentifier{ - URI: mainRegoURI, - }, - ContentChanges: []types.TextDocumentContentChangeEvent{ - { - Text: `package main -import rego.v1 -allow := true -`, - }, - }, - }, nil); err != nil { - t.Fatalf("failed to send didChange notification: %s", err) - } - - // validate that the client received a new diagnostics notification for the file - timeout = time.NewTimer(defaultTimeout) - defer timeout.Stop() - - for { - var success bool - select { - case requestData := <-receivedMessages: - success = testRequestDataCodes(t, requestData, mainRegoURI, []string{"opa-fmt"}) - case <-timeout.C: - t.Fatalf("timed out waiting for file diagnostics to be sent") - } - - if success { - break - } - } - - // 4. Client sends workspace/didChangeWatchedFiles notification with new config - newConfigContents := ` -rules: - idiomatic: - directory-package-mismatch: - level: ignore - style: - opa-fmt: - level: ignore -` - - if err := os.WriteFile(filepath.Join(tempDir, ".regal/config.yaml"), []byte(newConfigContents), 0o600); err != nil { - t.Fatalf("failed to write new config file: %s", err) - } - - // validate that the client received a new, empty diagnostics notification for the file - timeout = time.NewTimer(defaultTimeout) - defer timeout.Stop() - - for { - var success bool - select { - case requestData := <-receivedMessages: - if requestData.URI != mainRegoURI { - t.Logf("expected diagnostics to be sent for main.rego, got %s", requestData.URI) - - break - } - - if len(requestData.Items) != 0 { - t.Logf("expected 0 diagnostic, got %d", len(requestData.Items)) - - break - } - - success = testRequestDataCodes(t, requestData, mainRegoURI, []string{}) - case <-timeout.C: - t.Fatalf("timed out waiting for file diagnostics to be sent") - } - - if success { - break - } - } - - // 5. Client sends new config with an EOPA capabilities file specified. - newConfigContents = ` -rules: - style: - opa-fmt: - level: ignore - idiomatic: - directory-package-mismatch: - level: ignore -capabilities: - from: - engine: eopa - version: v1.23.0 -` - - if err := os.WriteFile(filepath.Join(tempDir, ".regal/config.yaml"), []byte(newConfigContents), 0o600); err != nil { - t.Fatalf("failed to write new config file: %s", err) - } - - // validate that the client received a new, empty diagnostics notification for the file - timeout = time.NewTimer(defaultTimeout) - defer timeout.Stop() - - for { - var success bool - select { - case requestData := <-receivedMessages: - if requestData.URI != mainRegoURI { - t.Logf("expected diagnostics to be sent for main.rego, got %s", requestData.URI) - - break - } - - if len(requestData.Items) != 0 { - t.Logf("expected 0 diagnostic, got %d", len(requestData.Items)) - - break - } - - success = testRequestDataCodes(t, requestData, mainRegoURI, []string{}) - case <-timeout.C: - t.Fatalf("timed out waiting for file diagnostics to be sent") - } - - if success { - break - } - } - - // 6. Client sends textDocument/didChange notification with new - // contents for main.rego no response to the call is expected. We added - // the start of an EOPA-specific call, so if the capabilities were - // loaded correctly, we should see a completion later after we ask for - // it. 
- if err := connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{ - TextDocument: types.TextDocumentIdentifier{ - URI: mainRegoURI, - }, - ContentChanges: []types.TextDocumentContentChangeEvent{ - { - Text: `package main -import rego.v1 -allow := neo4j.q -`, - }, - }, - }, nil); err != nil { - t.Fatalf("failed to send didChange notification: %s", err) - } - - // validate that the client received a new diagnostics notification for the file - timeout = time.NewTimer(defaultTimeout) - defer timeout.Stop() - - for { - var success bool - select { - case requestData := <-receivedMessages: - if requestData.URI != mainRegoURI { - t.Logf("expected diagnostics to be sent for main.rego, got %s", requestData.URI) - break - } - - if len(requestData.Items) != 0 { - t.Logf("expected 0 diagnostic, got %d", len(requestData.Items)) + return tl +} - break - } +type testLogger struct { + t *testing.T + open bool + mu sync.RWMutex +} - success = testRequestDataCodes(t, requestData, mainRegoURI, []string{}) - case <-timeout.C: - t.Fatalf("timed out waiting for file diagnostics to be sent") - } +func (tl *testLogger) Write(p []byte) (n int, err error) { + tl.mu.RLock() + defer tl.mu.RUnlock() - if success { - break - } + if !tl.open { + return 0, errors.New("cannot log, test is over") } - // 7. With our new config applied, and the file updated, we can ask the - // LSP for a completion. We expect to see neo4j.query show up. Since - // neo4j.query is an EOPA-specific builtin, it should never appear if - // we're using the normal OPA capabilities file. - timeout = time.NewTimer(defaultTimeout) - defer timeout.Stop() - - ticker := time.NewTicker(500 * time.Millisecond) - defer ticker.Stop() - - for { - foundNeo4j := false - - select { - case <-timeout.C: - t.Fatalf("timed out waiting for file completion to correct") - case <-ticker.C: - // Create a new context with timeout for each request - reqCtx, cancel := context.WithTimeout(ctx, time.Millisecond*100) - - resp := make(map[string]any) - err := connClient.Call(reqCtx, "textDocument/completion", types.CompletionParams{ - TextDocument: types.TextDocumentIdentifier{ - URI: mainRegoURI, - }, - Position: types.Position{ - Line: 2, - Character: 16, - }, - }, &resp) - - cancel() - - if err != nil { - t.Fatalf("failed to send completion notification: %s", err) - } - - itemsList, ok := resp["items"].([]any) - if !ok { - t.Fatalf("failed to cast resp[items] to []any") - } - - for _, itemI := range itemsList { - item, ok := itemI.(map[string]any) - if !ok { - t.Fatalf("completion item '%+v' was not a JSON object", itemI) - } - - label, ok := item["label"].(string) - if !ok { - t.Fatalf("completion item label is not a string: %+v", item["label"]) - } + tl.t.Log(strings.TrimSpace(string(p))) - if label == "neo4j.query" { - foundNeo4j = true + return len(p), nil +} - break - } - } +func createAndInitServer( + ctx context.Context, + logger io.Writer, + tempDir string, + files map[string]string, + clientHandler func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error), +) ( + *LanguageServer, + *jsonrpc2.Conn, + error, +) { + var err error - t.Logf("waiting for neo4j.query in completion results for neo4j.q, got %v", itemsList) + for f, fc := range files { + err = os.MkdirAll(filepath.Dir(filepath.Join(tempDir, f)), 0o755) + if err != nil { + return nil, nil, fmt.Errorf("failed to create directory: %w", err) } - if foundNeo4j { - break + err = os.WriteFile(filepath.Join(tempDir, f), []byte(fc), 0o600) + if err != nil { + 
return nil, nil, fmt.Errorf("failed to write file: %w", err)
+		}
+	}
+
+	// This is set because eventing is so slow under go test -race that we
+	// get flakes. TODO: work out how to avoid needing this in LSP tests.
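+	// Assumption: a zero WorkspaceDiagnosticsPoll leaves the periodic
+	// workspace lint ticker disabled, so polling only kicks in under -race.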
pollingInterval := time.Duration(0)
+	if isRaceEnabled() {
+		pollingInterval = 10 * time.Second
 	}
 
 	// set up the server and client connections
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
+	ls := NewLanguageServer(ctx, &LanguageServerOptions{
+		ErrorLog:                 logger,
+		WorkspaceDiagnosticsPoll: pollingInterval,
+	})
 
-	ls := NewLanguageServer(&LanguageServerOptions{ErrorLog: newTestLogger(t)})
 	go ls.StartDiagnosticsWorker(ctx)
 	go ls.StartConfigWorker(ctx)
 
-	authzFileMessages := make(chan types.FileDiagnostics, defaultBufferedChannelSize)
-	adminsFileMessages := make(chan types.FileDiagnostics, defaultBufferedChannelSize)
-	ignoredFileMessages := make(chan types.FileDiagnostics, defaultBufferedChannelSize)
-	clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) {
-		if req.Method != "textDocument/publishDiagnostics" {
-			t.Log("unexpected request method:", req.Method)
-
-			return struct{}{}, nil
-		}
-
-		var requestData types.FileDiagnostics
-		if err := encoding.JSON().Unmarshal(*req.Params, &requestData); err != nil {
-			t.Fatalf("failed to unmarshal diagnostics: %s", err)
-		}
+	netConnServer, netConnClient := net.Pipe()
 
-		switch requestData.URI {
-		case authzRegoURI:
-			authzFileMessages <- requestData
-		case adminsRegoURI:
-			adminsFileMessages <- requestData
-		case ignoredRegoURI:
-			ignoredFileMessages <- requestData
-		default:
-			t.Logf("unexpected diagnostics for file: %s", requestData.URI)
-		}
+	connServer := jsonrpc2.NewConn(
+		ctx,
+		jsonrpc2.NewBufferedStream(netConnServer, jsonrpc2.VSCodeObjectCodec{}),
+		jsonrpc2.HandlerWithError(ls.Handle),
+	)
 
-		return struct{}{}, nil
-	}
+	connClient := jsonrpc2.NewConn(
+		ctx,
+		jsonrpc2.NewBufferedStream(netConnClient, jsonrpc2.VSCodeObjectCodec{}),
+		jsonrpc2.HandlerWithError(clientHandler),
+	)
 
-	connServer, connClient := createConnections(ctx, ls.Handle, clientHandler)
+	go func() {
+		<-ctx.Done()
+		// we need only close the pipe connections, as each jsonrpc2.Conn
+		// accepts the ctx
+		_ = netConnClient.Close()
+		_ = netConnServer.Close()
+	}()
 
 	ls.SetConn(connServer)
 
-	// 1. Client sends initialize request
 	request := types.InitializeParams{
 		RootURI:    fileURIScheme + tempDir,
 		ClientInfo: types.Client{Name: "go test"},
@@ -533,440 +157,68 @@ ignore:
 
 	var response types.InitializeResult
 
-	if err := connClient.Call(ctx, "initialize", request, &response); err != nil {
-		t.Fatalf("failed to send initialize request: %s", err)
+	err = connClient.Call(ctx, "initialize", request, &response)
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to initialize: %w", err)
 	}
 
 	// 2. 
Client sends initialized notification // no response to the call is expected - if err := connClient.Call(ctx, "initialized", struct{}{}, nil); err != nil { - t.Fatalf("failed to send initialized notification: %s", err) - } - - // validate that the client received a diagnostics notification for authz.rego - timeout := time.NewTimer(defaultTimeout) - defer timeout.Stop() - - for { - var success bool - select { - case diags := <-authzFileMessages: - success = testRequestDataCodes(t, diags, authzRegoURI, []string{"prefer-package-imports"}) - case <-timeout.C: - t.Fatalf("timed out waiting for authz.rego diagnostics to be sent") - } - - if success { - break - } - } - - // validate that the client received a diagnostics notification admins.rego - timeout = time.NewTimer(defaultTimeout) - defer timeout.Stop() - - for { - var success bool - select { - case diags := <-adminsFileMessages: - success = testRequestDataCodes(t, diags, adminsRegoURI, []string{"use-assignment-operator"}) - case <-timeout.C: - t.Fatalf("timed out waiting for admins.rego diagnostics to be sent") - } - - if success { - break - } - } - - // 3. Client sends textDocument/didChange notification with new contents for authz.rego - // no response to the call is expected - if err := connClient.Call(ctx, "textDocument/didChange", types.TextDocumentDidChangeParams{ - TextDocument: types.TextDocumentIdentifier{ - URI: authzRegoURI, - }, - ContentChanges: []types.TextDocumentContentChangeEvent{ - { - Text: `package authz - -import rego.v1 - -import data.admins - -default allow := false - -allow if input.user in admins.users -`, - }, - }, - }, nil); err != nil { - t.Fatalf("failed to send didChange notification: %s", err) - } - - // authz.rego should now have no violations - timeout = time.NewTimer(defaultTimeout) - defer timeout.Stop() - - for { - var success bool - select { - case diags := <-authzFileMessages: - success = testRequestDataCodes(t, diags, authzRegoURI, []string{}) - case <-timeout.C: - t.Fatalf("timed out waiting for authz.rego diagnostics to be sent") - } - - if success { - break - } - } - - // we should also receive a diagnostics notification for admins.rego, since it is in the workspace, but it has not - // been changed, so the violations should be the same. 
- timeout = time.NewTimer(defaultTimeout) - defer timeout.Stop() - - for { - var success bool - select { - case requestData := <-adminsFileMessages: - success = testRequestDataCodes(t, requestData, adminsRegoURI, []string{"use-assignment-operator"}) - case <-timeout.C: - t.Fatalf("timed out waiting for admins.rego diagnostics to be sent") - } - - if success { - break - } - } -} - -// https://github.com/StyraInc/regal/issues/679 -func TestProcessBuiltinUpdateExitsOnMissingFile(t *testing.T) { - t.Parallel() - - ls := NewLanguageServer(&LanguageServerOptions{ - ErrorLog: newTestLogger(t), - }) - - if err := ls.processHoverContentUpdate(context.Background(), "file://missing.rego", "foo"); err != nil { - t.Fatal(err) - } - - if l := len(ls.cache.GetAllBuiltInPositions()); l != 0 { - t.Errorf("expected builtin positions to be empty, got %d items", l) - } - - contents, ok := ls.cache.GetFileContents("file://missing.rego") - if ok { - t.Errorf("expected file contents to be empty, got %s", contents) - } - - if len(ls.cache.GetAllFiles()) != 0 { - t.Errorf("expected files to be empty, got %v", ls.cache.GetAllFiles()) - } -} - -func TestFormatting(t *testing.T) { - t.Parallel() - - tempDir := t.TempDir() - - // set up the server and client connections - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ls := NewLanguageServer(&LanguageServerOptions{ - ErrorLog: newTestLogger(t), - }) - go ls.StartConfigWorker(ctx) - - clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { - t.Fatalf("unexpected request: %v", req) - - return struct{}{}, nil - } - - connServer, connClient := createConnections(ctx, ls.Handle, clientHandler) - - ls.SetConn(connServer) - - // 1. Client sends initialize request - request := types.InitializeParams{ - RootURI: fileURIScheme + tempDir, - ClientInfo: types.Client{Name: "go test"}, - } - - var response types.InitializeResult - - if err := connClient.Call(ctx, "initialize", request, &response); err != nil { - t.Fatalf("failed to send initialize request: %s", err) - } - - mainRegoURI := fileURIScheme + tempDir + mainRegoFileName - - // Simple as possible — opa fmt should just remove a newline - content := `package main - -` - ls.cache.SetFileContents(mainRegoURI, content) - - bs, err := json.Marshal(&types.DocumentFormattingParams{ - TextDocument: types.TextDocumentIdentifier{URI: mainRegoURI}, - Options: types.FormattingOptions{}, - }) + err = connClient.Call(ctx, "initialized", struct{}{}, nil) if err != nil { - t.Fatalf("failed to marshal document formatting params: %v", err) + return nil, nil, fmt.Errorf("failed to complete initialized %w", err) } - var msg json.RawMessage = bs - - req := &jsonrpc2.Request{Params: &msg} - - res, err := ls.handleTextDocumentFormatting(ctx, connClient, req) - if err != nil { - t.Fatalf("failed to format document: %s", err) - } - - if edits, ok := res.([]types.TextEdit); ok { - if len(edits) != 1 { - t.Fatalf("expected 1 edit, got %d", len(edits)) - } - - expectRange := types.Range{ - Start: types.Position{Line: 1, Character: 0}, - End: types.Position{Line: 2, Character: 0}, - } - - if edits[0].Range != expectRange { - t.Fatalf("expected range to be %v, got %v", expectRange, edits[0].Range) - } - - if edits[0].NewText != "" { - t.Fatalf("expected new text to be empty, got %s", edits[0].NewText) - } - } else { - t.Fatalf("expected edits to be []types.TextEdit, got %T", res) - } + return ls, connClient, nil } -// TestLanguageServerParentDirConfig tests that regal 
config is loaded as it is for the -// Regal CLI, and that config files in a parent directory are loaded correctly -// even when the workspace is a child directory. -func TestLanguageServerParentDirConfig(t *testing.T) { - t.Parallel() - - var err error - - // this is the top level directory for the test - parentDir := t.TempDir() - // childDir will be the directory that the client is using as its workspace - childDirName := "child" - childDir := filepath.Join(parentDir, childDirName) - - for _, dir := range []string{childDirName, ".regal"} { - err = os.MkdirAll(filepath.Join(parentDir, dir), 0o755) - if err != nil { - t.Fatalf("failed to create %q directory under parent: %s", dir, err) - } - } - - mainRegoContents := `package main - -import rego.v1 -allow := true -` - - files := map[string]string{ - childDirName + mainRegoFileName: mainRegoContents, - ".regal/config.yaml": `rules: - idiomatic: - directory-package-mismatch: - level: ignore - style: - opa-fmt: - level: error -`, - } - - for f, fc := range files { - if err := os.WriteFile(filepath.Join(parentDir, f), []byte(fc), 0o600); err != nil { - t.Fatalf("failed to write file %s: %s", f, err) - } - } - - // mainRegoFileURI is used throughout the test to refer to the main.rego file - // and so it is defined here for convenience - mainRegoFileURI := fileURIScheme + childDir + mainRegoFileName - - // set up the server and client connections - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ls := NewLanguageServer(&LanguageServerOptions{ - ErrorLog: newTestLogger(t), - }) - go ls.StartDiagnosticsWorker(ctx) - go ls.StartConfigWorker(ctx) - - receivedMessages := make(chan types.FileDiagnostics, defaultBufferedChannelSize) - clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { - if req.Method == methodTextDocumentPublishDiagnostics { - var requestData types.FileDiagnostics - - if err2 := encoding.JSON().Unmarshal(*req.Params, &requestData); err2 != nil { - t.Fatalf("failed to unmarshal diagnostics: %s", err2) - } +func createClientHandler( + t *testing.T, + messages map[string]chan []string, +) func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { + t.Helper() - receivedMessages <- requestData + return func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { + if req.Method != "textDocument/publishDiagnostics" { + t.Log("unexpected request method:", req.Method) return struct{}{}, nil } - t.Logf("unexpected request from server: %v", req) - - return struct{}{}, nil - } - - connServer, connClient := createConnections(ctx, ls.Handle, clientHandler) - - ls.SetConn(connServer) - - // Client sends initialize request - request := types.InitializeParams{ - RootURI: fileURIScheme + childDir, - ClientInfo: types.Client{Name: "go test"}, - } - - var response types.InitializeResult - - if err := connClient.Call(ctx, "initialize", request, &response); err != nil { - t.Fatalf("failed to send initialize request: %s", err) - } - - if ls.workspaceRootURI != request.RootURI { - t.Fatalf("expected client root URI to be %s, got %s", request.RootURI, ls.workspaceRootURI) - } - - // Client sends initialized notification - // the response to the call is expected to be empty and is ignored - if err := connClient.Call(ctx, "initialized", struct{}{}, nil); err != nil { - t.Fatalf("failed to send initialized notification: %s", err) - } - - timeout := time.NewTimer(defaultTimeout) - defer timeout.Stop() + var 
requestData types.FileDiagnostics - for { - var success bool - select { - case requestData := <-receivedMessages: - success = testRequestDataCodes(t, requestData, mainRegoFileURI, []string{"opa-fmt"}) - case <-timeout.C: - t.Fatalf("timed out waiting for file diagnostics to be sent") + err = encoding.JSON().Unmarshal(*req.Params, &requestData) + if err != nil { + t.Fatalf("failed to unmarshal diagnostics: %s", err) } - if success { - break + violations := make([]string, len(requestData.Items)) + for i, item := range requestData.Items { + violations[i] = item.Code } - } - // User updates config file contents in parent directory that is not - // part of the workspace - newConfigContents := `rules: - idiomatic: - directory-package-mismatch: - level: ignore - style: - opa-fmt: - level: ignore -` + slices.Sort(violations) - path := filepath.Join(parentDir, ".regal/config.yaml") - if err := os.WriteFile(path, []byte(newConfigContents), 0o600); err != nil { - t.Fatalf("failed to write new config file: %s", err) - } - - // validate that the client received a new, empty diagnostics notification for the file - timeout = time.NewTimer(defaultTimeout) - defer timeout.Stop() + fileBase := filepath.Base(requestData.URI) + t.Log("queue", fileBase, len(messages[fileBase])) - for { - var success bool select { - case requestData := <-receivedMessages: - success = testRequestDataCodes(t, requestData, mainRegoFileURI, []string{}) - case <-timeout.C: - t.Fatalf("timed out waiting for file diagnostics to be sent") + case messages[fileBase] <- violations: + case <-time.After(1 * time.Second): + t.Fatalf("timeout writing to messages channel for %s", fileBase) } - if success { - break - } + return struct{}{}, nil } } -func TestLanguageServerFixRenameParams(t *testing.T) { - t.Parallel() - - tmpDir := t.TempDir() - - if err := os.MkdirAll(filepath.Join(tmpDir, "workspace/foo/bar"), 0o755); err != nil { - t.Fatalf("failed to create directory: %s", err) +func createMessageChannels(files map[string]string) map[string]chan []string { + messages := make(map[string]chan []string) + for _, file := range util.Keys(files) { + messages[file] = make(chan []string, 10) } - l := NewLanguageServer(&LanguageServerOptions{ErrorLog: newTestLogger(t)}) - c := cache.NewCache() - f := &fixes.DirectoryPackageMismatch{} - - fileURL := fmt.Sprintf("file://%s/workspace/foo/bar/policy.rego", tmpDir) - - c.SetFileContents(fileURL, "package authz.main.rules") - - l.clientIdentifier = clients.IdentifierVSCode - l.workspaceRootURI = fmt.Sprintf("file://%s/workspace", tmpDir) - l.cache = c - l.loadedConfig = &config.Config{ - Rules: map[string]config.Category{ - "idiomatic": { - "directory-package-mismatch": config.Rule{ - Level: "ignore", - Extra: map[string]any{ - "exclude-test-suffix": true, - }, - }, - }, - }, - } - - params, err := l.fixRenameParams("fix my file!", f, fileURL) - if err != nil { - t.Fatalf("failed to fix rename params: %s", err) - } - - if params.Label != "fix my file!" 
{ - t.Fatalf("expected label to be 'Fix my file!', got %s", params.Label) - } - - if len(params.Edit.DocumentChanges) != 1 { - t.Fatalf("expected 1 document change, got %d", len(params.Edit.DocumentChanges)) - } - - change := params.Edit.DocumentChanges[0] - - if change.Kind != "rename" { - t.Fatalf("expected kind to be 'rename', got %s", change.Kind) - } - - if change.OldURI != fileURL { - t.Fatalf("expected old URI to be %s, got %s", fileURL, change.OldURI) - } - - expectedNewURI := fmt.Sprintf("file://%s/workspace/authz/main/rules/policy.rego", tmpDir) - - if change.NewURI != expectedNewURI { - t.Fatalf("expected new URI to be %s, got %s", expectedNewURI, change.NewURI) - } + return messages } func testRequestDataCodes(t *testing.T, requestData types.FileDiagnostics, fileURI string, codes []string) bool { @@ -998,71 +250,3 @@ func testRequestDataCodes(t *testing.T, requestData types.FileDiagnostics, fileU return true } - -func createConnections( - ctx context.Context, - serverHandler, clientHandler func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error), -) (*jsonrpc2.Conn, *jsonrpc2.Conn) { - netConnServer, netConnClient := net.Pipe() - - connServer := jsonrpc2.NewConn( - ctx, - jsonrpc2.NewBufferedStream(netConnServer, jsonrpc2.VSCodeObjectCodec{}), - jsonrpc2.HandlerWithError(serverHandler), - ) - - connClient := jsonrpc2.NewConn( - ctx, - jsonrpc2.NewBufferedStream(netConnClient, jsonrpc2.VSCodeObjectCodec{}), - jsonrpc2.HandlerWithError(clientHandler), - ) - - go func() { - <-ctx.Done() - // we need only close the pipe connections as the jsonrpc2.Conn accept - // the ctx - _ = netConnClient.Close() - _ = netConnServer.Close() - }() - - return connServer, connClient -} - -// NewTestLogger returns an io.Writer that logs to the given testing.T. -// This is helpful as it can be used to have the server log to the test logger -// in server tests. It is protected from being written to after the test is -// over. -func newTestLogger(t *testing.T) io.Writer { - t.Helper() - - tl := &testLogger{t: t, open: true} - - // using cleanup ensure that no goroutines attempt to write to the logger - // after the test has been cleaned up - t.Cleanup(func() { - tl.mu.Lock() - defer tl.mu.Unlock() - tl.open = false - }) - - return tl -} - -type testLogger struct { - t *testing.T - open bool - mu sync.RWMutex -} - -func (tl *testLogger) Write(p []byte) (n int, err error) { - tl.mu.RLock() - defer tl.mu.RUnlock() - - if !tl.open { - return 0, errors.New("cannot log, test is over") - } - - tl.t.Log(strings.TrimSpace(string(p))) - - return len(p), nil -} diff --git a/pkg/linter/linter.go b/pkg/linter/linter.go index 48876635..c6839550 100644 --- a/pkg/linter/linter.go +++ b/pkg/linter/linter.go @@ -53,12 +53,13 @@ type Linter struct { enable []string enableCategory []string ignoreFiles []string - additionalAggregates map[string][]report.Aggregate + overriddenAggregates map[string][]report.Aggregate + useCollectQuery bool debugMode bool + exportAggregates bool disableAll bool enableAll bool profiling bool - populateAggregates bool } //nolint:gochecknoglobals @@ -216,20 +217,28 @@ func (l Linter) WithRootDir(rootDir string) Linter { return l } -// WithAlwaysAggregate enables the population of aggregate data even when -// linting a single file. This is useful when needing to incrementally build +// WithExportAggregates enables the setting of intermediate aggregate data +// on the final report. 
This is useful when you want to collect and // aggregate state from multiple different linting runs. -func (l Linter) WithAlwaysAggregate(enabled bool) Linter { - l.populateAggregates = enabled +func (l Linter) WithExportAggregates(enabled bool) Linter { + l.exportAggregates = enabled return l } -// WithAggregates supplies additional aggregate data to a linter instance. +// WithCollectQuery forcibly enables the collect query even when there is +// only one file to lint. +func (l Linter) WithCollectQuery(enabled bool) Linter { + l.useCollectQuery = enabled + + return l +} + +// WithAggregates supplies aggregate data to a linter instance. // Likely generated in a previous run, and used to provide a global context to // a subsequent run of a single file lint. func (l Linter) WithAggregates(aggregates map[string][]report.Aggregate) Linter { - l.additionalAggregates = aggregates + l.overriddenAggregates = aggregates return l } @@ -240,7 +249,7 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) { finalReport := report.Report{} - if len(l.inputPaths) == 0 && l.inputModules == nil { + if len(l.inputPaths) == 0 && l.inputModules == nil && len(l.overriddenAggregates) == 0 { return report.Report{}, errors.New("nothing provided to lint") } @@ -335,12 +344,13 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) { } } - if len(input.FileNames) > 1 || len(l.additionalAggregates) > 0 { - allAggregates := make(map[string][]report.Aggregate) - for k, aggregates := range l.additionalAggregates { + allAggregates := make(map[string][]report.Aggregate) + + if len(l.overriddenAggregates) > 0 { + for k, aggregates := range l.overriddenAggregates { allAggregates[k] = append(allAggregates[k], aggregates...) } - + } else if len(input.FileNames) > 1 { for k, aggregates := range goReport.Aggregates { allAggregates[k] = append(allAggregates[k], aggregates...) } @@ -348,7 +358,9 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) { for k, aggregates := range regoReport.Aggregates { allAggregates[k] = append(allAggregates[k], aggregates...) } + } + if len(allAggregates) > 0 { aggregateReport, err := l.lintWithRegoAggregateRules(ctx, allAggregates, regoReport.IgnoreDirectives) if err != nil { return report.Report{}, fmt.Errorf("failed to lint using Rego aggregate rules: %w", err) @@ -364,7 +376,7 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) { NumViolations: len(finalReport.Violations), } - if l.populateAggregates { + if l.exportAggregates { finalReport.Aggregates = make(map[string][]report.Aggregate) for k, aggregates := range goReport.Aggregates { finalReport.Aggregates[k] = append(finalReport.Aggregates[k], aggregates...) @@ -390,9 +402,10 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) { return finalReport, nil } -// DetermineEnabledRules returns the list of rules that are enabled based on the supplied configuration. -// This makes use of the Rego and Go rule settings to produce a single list of the rules that are to be run -// on this linter instance. +// DetermineEnabledRules returns the list of rules that are enabled based on +// the supplied configuration. This makes use of the Rego and Go rule settings +// to produce a single list of the rules that are to be run on this linter +// instance. 
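+//
+// A minimal sketch of intended use (userConfig is assumed to be a
+// config.Config populated elsewhere):
+//
+//	linter := NewLinter().WithUserConfig(userConfig)
+//
+//	enabled, err := linter.DetermineEnabledRules(ctx)
+//	if err != nil {
+//		// handle the error
+//	}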
func (l Linter) DetermineEnabledRules(ctx context.Context) ([]string, error) {
 	enabledRules := make([]string, 0)
 
@@ -405,10 +418,91 @@ func (l Linter) DetermineEnabledRules(ctx context.Context) ([]string, error) {
 		enabledRules = append(enabledRules, rule.Name())
 	}
 
+	conf, err := l.GetConfig()
+	if err != nil {
+		return []string{}, fmt.Errorf("failed to merge config: %w", err)
+	}
+
+	l.dataBundle = &bundle.Bundle{
+		Manifest: bundle.Manifest{
+			Roots:    &[]string{"internal"},
+			Metadata: map[string]any{"name": "internal"},
+		},
+		Data: map[string]any{
+			"internal": map[string]any{
+				"combined_config": config.ToMap(*conf),
+				"capabilities":    rio.ToMap(config.CapabilitiesForThisVersion()),
+			},
+		},
+	}
+
+	queryStr := `[rule |
+		data.regal.rules[cat][rule]
+		data.regal.config.for_rule(cat, rule).level != "ignore"
+	]`
+
+	query := ast.MustParseBody(queryStr)
+
+	regoArgs, err := l.prepareRegoArgs(query)
+	if err != nil {
+		return nil, fmt.Errorf("failed preparing query %s: %w", queryStr, err)
+	}
+
+	rs, err := rego.New(regoArgs...).Eval(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed evaluating query %s: %w", queryStr, err)
+	}
+
+	if len(rs) != 1 || len(rs[0].Expressions) != 1 {
+		return nil, fmt.Errorf("expected exactly one result with one expression, got %d results", len(rs))
+	}
+
+	list, ok := rs[0].Expressions[0].Value.([]any)
+	if !ok {
+		return nil, fmt.Errorf("expected list, got %T", rs[0].Expressions[0].Value)
+	}
+
+	for _, item := range list {
+		rule, ok := item.(string)
+		if !ok {
+			return nil, fmt.Errorf("expected string, got %T", item)
+		}
+
+		enabledRules = append(enabledRules, rule)
+	}
+
+	slices.Sort(enabledRules)
+
+	return enabledRules, nil
+}
+
+// DetermineEnabledAggregateRules returns the list of aggregate rules that are
+// enabled based on the configuration. This does not include any Go rules.
+func (l Linter) DetermineEnabledAggregateRules(ctx context.Context) ([]string, error) {
+	enabledRules := make([]string, 0)
+
+	conf, err := l.GetConfig()
+	if err != nil {
+		return []string{}, fmt.Errorf("failed to merge config: %w", err)
+	}
+
+	l.dataBundle = &bundle.Bundle{
+		Manifest: bundle.Manifest{
+			Roots:    &[]string{"internal"},
+			Metadata: map[string]any{"name": "internal"},
+		},
+		Data: map[string]any{
+			"internal": map[string]any{
+				"combined_config": config.ToMap(*conf),
+				"capabilities":    rio.ToMap(config.CapabilitiesForThisVersion()),
+			},
+		},
+	}
+
 	queryStr := `[rule |
-		data.regal.rules[cat][rule]
-		data.regal.config.for_rule(cat, rule).level != "ignore"
-]`
+		data.regal.rules[cat][rule].aggregate
+		data.regal.config.for_rule(cat, rule).level != "ignore"
+	]`
 
 	query := ast.MustParseBody(queryStr)
 
@@ -454,7 +548,7 @@ func (l Linter) lintWithGoRules(ctx context.Context, input rules.Input) (report.
 		return report.Report{}, fmt.Errorf("failed to get configured Go rules: %w", err)
 	}
 
-	aggregate := report.Report{}
+	goReport := report.Report{}
 
 	for _, rule := range goRules {
 		inp, err := inputForRule(input, rule)
@@ -467,10 +561,10 @@ func (l Linter) lintWithGoRules(ctx context.Context, input rules.Input) (report.
 			return report.Report{}, fmt.Errorf("error encountered in Go rule evaluation: %w", err)
 		}
 
-		aggregate.Violations = append(aggregate.Violations, result.Violations...)
+		goReport.Violations = append(goReport.Violations, result.Violations...)
} - return aggregate, err + return goReport, err } func inputForRule(input rules.Input, rule rules.Rule) (rules.Input, error) { @@ -708,7 +802,7 @@ func (l Linter) lintWithRegoRules(ctx context.Context, input rules.Input) (repor defer cancel() var query ast.Body - if len(input.FileNames) > 1 || l.populateAggregates { + if len(input.FileNames) > 1 || l.useCollectQuery { query = lintAndCollectQuery } else { query = lintQuery @@ -724,9 +818,9 @@ func (l Linter) lintWithRegoRules(ctx context.Context, input rules.Input) (repor return report.Report{}, fmt.Errorf("failed preparing query for linting: %w", err) } - aggregate := report.Report{} - aggregate.Aggregates = make(map[string][]report.Aggregate) - aggregate.IgnoreDirectives = make(map[string]map[string][]string) + regoReport := report.Report{} + regoReport.Aggregates = make(map[string][]report.Aggregate) + regoReport.IgnoreDirectives = make(map[string]map[string][]string) var wg sync.WaitGroup @@ -789,19 +883,19 @@ func (l Linter) lintWithRegoRules(ctx context.Context, input rules.Input) (repor } mu.Lock() - aggregate.Violations = append(aggregate.Violations, result.Violations...) - aggregate.Notices = append(aggregate.Notices, result.Notices...) + regoReport.Violations = append(regoReport.Violations, result.Violations...) + regoReport.Notices = append(regoReport.Notices, result.Notices...) for k := range result.Aggregates { - aggregate.Aggregates[k] = append(aggregate.Aggregates[k], result.Aggregates[k]...) + regoReport.Aggregates[k] = append(regoReport.Aggregates[k], result.Aggregates[k]...) } for k := range result.IgnoreDirectives { - aggregate.IgnoreDirectives[k] = result.IgnoreDirectives[k] + regoReport.IgnoreDirectives[k] = result.IgnoreDirectives[k] } if l.profiling { - aggregate.AddProfileEntries(result.AggregateProfile) + regoReport.AddProfileEntries(result.AggregateProfile) } mu.Unlock() }(name) @@ -818,7 +912,7 @@ func (l Linter) lintWithRegoRules(ctx context.Context, input rules.Input) (repor case err := <-errCh: return report.Report{}, fmt.Errorf("error encountered in rule evaluation %w", err) case <-doneCh: - return aggregate, nil + return regoReport, nil } } diff --git a/pkg/linter/linter_test.go b/pkg/linter/linter_test.go index dd3527d7..f52ca499 100644 --- a/pkg/linter/linter_test.go +++ b/pkg/linter/linter_test.go @@ -4,10 +4,12 @@ import ( "bytes" "context" "embed" - "os" "path/filepath" + "slices" "testing" + "gopkg.in/yaml.v2" + "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/topdown" @@ -15,6 +17,7 @@ import ( "github.com/styrainc/regal/internal/test" "github.com/styrainc/regal/internal/testutil" "github.com/styrainc/regal/pkg/config" + "github.com/styrainc/regal/pkg/report" "github.com/styrainc/regal/pkg/rules" ) @@ -598,7 +601,79 @@ func TestEnabledRules(t *testing.T) { } } -func TestLintWithPopulateAggregates(t *testing.T) { +func TestEnabledRulesWithConfig(t *testing.T) { + t.Parallel() + + configFileBs := []byte(` +rules: + style: + opa-fmt: + level: ignore # go rule + imports: + unresolved-import: # agg rule + level: ignore + idiomatic: + directory-package-mismatch: # non agg rule + level: ignore +`) + + var userConfig config.Config + + if err := yaml.Unmarshal(configFileBs, &userConfig); err != nil { + t.Fatalf("failed to load config: %s", err) + } + + linter := NewLinter().WithUserConfig(userConfig) + + enabledRules, err := linter.DetermineEnabledRules(context.Background()) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + enabledAggRules, err := 
linter.DetermineEnabledAggregateRules(context.Background())
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+
+	if len(enabledRules) == 0 {
+		t.Fatalf("expected enabledRules, got none")
+	}
+
+	if slices.Contains(enabledRules, "directory-package-mismatch") {
+		t.Errorf("did not expect directory-package-mismatch to be in enabled rules")
+	}
+
+	if slices.Contains(enabledRules, "opa-fmt") {
+		t.Errorf("did not expect opa-fmt to be in enabled rules")
+	}
+
+	if slices.Contains(enabledAggRules, "unresolved-import") {
+		t.Errorf("did not expect unresolved-import to be in enabled aggregate rules")
+	}
+}
+
+func TestEnabledAggregateRules(t *testing.T) {
+	t.Parallel()
+
+	linter := NewLinter().
+		WithDisableAll(true).
+		WithEnabledRules("opa-fmt", "unresolved-import", "use-assignment-operator")
+
+	enabledRules, err := linter.DetermineEnabledAggregateRules(context.Background())
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+
+	if len(enabledRules) != 1 {
+		t.Fatalf("expected 1 enabled rule, got %d", len(enabledRules))
+	}
+
+	if enabledRules[0] != "unresolved-import" {
+		t.Errorf("expected first enabled rule to be 'unresolved-import', got %q", enabledRules[0])
+	}
+}
+
+func TestLintWithCollectQuery(t *testing.T) {
 	t.Parallel()
 
 	input := test.InputPolicy("p.rego", `package p
@@ -609,8 +684,8 @@ import data.foo.bar.unresolved
 	linter := NewLinter().
 		WithDisableAll(true).
 		WithEnabledRules("unresolved-import").
-		WithPrintHook(topdown.NewPrintHook(os.Stderr)).
-		WithAlwaysAggregate(true).
+		WithCollectQuery(true).     // needed since we have a single file input
+		WithExportAggregates(true). // needed to be able to test the aggregates are set
 		WithInputModules(&input)
 
 	result := testutil.Must(linter.Lint(context.Background()))(t)
@@ -624,41 +699,64 @@ import data.foo.bar.unresolved
 	}
 }
 
-func TestLintWithAggregates(t *testing.T) {
+func TestLintWithCollectQueryAndAggregates(t *testing.T) {
 	t.Parallel()
 
-	contents := `package p
+	files := map[string]string{
+		"foo.rego": `package foo
 
-import data.foo.bar.unresolved
-`
+import data.unresolved`,
+		"bar.rego": `package foo
 
-	input := test.InputPolicy("p.rego", contents)
+import data.unresolved`,
+		"baz.rego": `package foo
 
-	linter := NewLinter().
-		WithDisableAll(true).
-		WithEnabledRules("unresolved-import").
-		WithPrintHook(topdown.NewPrintHook(os.Stderr)).
-		WithAlwaysAggregate(true).
-		WithInputModules(&input)
+import data.unresolved`,
+	}
+
+	allAggregates := make(map[string][]report.Aggregate)
+
+	for file, content := range files {
+		input := test.InputPolicy(file, content)
+
+		linter := NewLinter().
+			WithDisableAll(true).
+			WithEnabledRules("unresolved-import").
+			WithCollectQuery(true). // runs collect for a single file input
+			WithExportAggregates(true).
+			WithInputModules(&input)
 
-	result1 := testutil.Must(linter.Lint(context.Background()))(t)
+		result := testutil.Must(linter.Lint(context.Background()))(t)
 
-	linter2 := NewLinter().
+		for k, aggs := range result.Aggregates {
+			allAggregates[k] = append(allAggregates[k], aggs...)
+		}
+	}
+
+	linter := NewLinter().
 		WithDisableAll(true).
 		WithEnabledRules("unresolved-import").
-		WithPrintHook(topdown.NewPrintHook(os.Stderr)).
-		WithAggregates(result1.Aggregates).
-		WithInputModules(&input)
+		WithAggregates(allAggregates)
 
-	result := testutil.Must(linter2.Lint(context.Background()))(t)
+	result := testutil.Must(linter.Lint(context.Background()))(t)
 
-	if len(result.Violations) != 1 {
+	if len(result.Violations) != 3 {
-		t.Fatalf("expected one violation, got %d", len(result.Violations))
+		t.Fatalf("expected three violations, got %d", len(result.Violations))
 	}
 
-	violation := result.Violations[0]
+	foundFiles := []string{}
+
+	for _, v := range result.Violations {
+		if v.Title != "unresolved-import" {
+			t.Errorf("unexpected title: %s", v.Title)
+		}
+
+		foundFiles = append(foundFiles, v.Location.File)
+	}
+
+	slices.Sort(foundFiles)
 
-	if violation.Title != "unresolved-import" {
-		t.Errorf("expected violation to be 'unresolved-import', got %q", violation.Title)
+	if exp, got := []string{"bar.rego", "baz.rego", "foo.rego"}, foundFiles; !slices.Equal(exp, got) {
+		t.Fatalf("expected files %v, got %v", exp, got)
 	}
 }
diff --git a/pkg/report/report.go b/pkg/report/report.go
index 292a2502..dcede9d9 100644
--- a/pkg/report/report.go
+++ b/pkg/report/report.go
@@ -61,6 +61,45 @@ type Notice struct {
 // while working with large Rego code repositories.
 type Aggregate map[string]any
 
+// SourceFile returns the file that the aggregate was collected from, or an
+// empty string if the source is not set.
+func (a Aggregate) SourceFile() string {
+	source, ok := a["aggregate_source"].(map[string]any)
+	if !ok {
+		return ""
+	}
+
+	file, ok := source["file"].(string)
+	if !ok {
+		return ""
+	}
+
+	return file
+}
+
+// IndexKey returns the category/title of the rule that generated the
+// aggregate. This key is generated in Rego during linting; this function
+// replicates that functionality in Go for use in the cache when indexing
+// aggregates.
+func (a Aggregate) IndexKey() string {
+	rule, ok := a["rule"].(map[string]any)
+	if !ok {
+		return ""
+	}
+
+	cat, ok := rule["category"].(string)
+	if !ok {
+		return ""
+	}
+
+	title, ok := rule["title"].(string)
+	if !ok {
+		return ""
+	}
+
+	return fmt.Sprintf("%s/%s", cat, title)
+}
+
 type Summary struct {
 	FilesScanned int `json:"files_scanned"`
 	FilesFailed  int `json:"files_failed"`
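
Taken together, the new builder methods support a two-phase workflow: lint each
file on its own while collecting aggregate data, then lint the combined
aggregates in a separate run. A minimal sketch of that flow, distilled from
TestLintWithCollectQueryAndAggregates above (ctx and the files map are assumed
to be in scope, and rule selection is omitted for brevity):

	allAggregates := make(map[string][]report.Aggregate)

	// Phase 1: lint each file individually. WithCollectQuery forces the
	// collect query for the single-file input, and WithExportAggregates
	// exposes the collected data on the returned report.
	for file, content := range files {
		input := test.InputPolicy(file, content)

		result, err := NewLinter().
			WithCollectQuery(true).
			WithExportAggregates(true).
			WithInputModules(&input).
			Lint(ctx)
		if err != nil {
			// handle the error
		}

		for k, aggs := range result.Aggregates {
			allAggregates[k] = append(allAggregates[k], aggs...)
		}
	}

	// Phase 2: run only the aggregate rules over the combined data. No file
	// input is needed, as the supplied aggregates are linted directly.
	finalReport, err := NewLinter().
		WithAggregates(allAggregates).
		Lint(ctx)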