From 930c21b00eee444aa1ee70409467c2ecfba03113 Mon Sep 17 00:00:00 2001 From: dobarx <111326505+dobarx@users.noreply.github.com> Date: Tue, 5 Mar 2024 21:55:43 +0200 Subject: [PATCH] Plugin download & locking (#118) --- .gitignore | 4 +- .golangci.yaml | 4 +- .goreleaser.yaml | 251 ++++++++++--- .mockery.yaml | 11 +- buf.gen.yaml | 5 +- cmd/data.go | 7 +- cmd/evaluator.go | 91 +++-- cmd/install.go | 42 +++ cmd/render.go | 8 +- cmd/root.go | 21 +- examples/templates/hackerone/example.fabric | 7 +- examples/templates/openai/example.fabric | 10 +- examples/templates/stixview/data.csv | 4 - examples/templates/stixview/example.fabric | 6 +- examples/templates/virustotal/example.fabric | 7 +- parser/definitions/global_config.go | 108 +----- parser/definitions/global_config_test.go | 108 ------ parser/parser.go | 7 +- plugin/pluginapi/v1/client.go | 38 +- plugin/pluginapi/v1/cty_type_decoder.go | 6 +- plugin/pluginapi/v1/cty_type_encoder.go | 14 +- plugin/pluginapi/v1/cty_value_decoder.go | 24 +- plugin/pluginapi/v1/hclspec_decoder.go | 3 - plugin/pluginapi/v1/plugin.go | 12 +- plugin/registry/.gitkeep | 0 plugin/resolver/checksum.go | 116 ++++++ plugin/resolver/checksum_test.go | 327 +++++++++++++++++ plugin/resolver/lockfile.go | 108 ++++++ plugin/resolver/lockfile_test.go | 91 +++++ plugin/resolver/mock_source_test.go | 156 ++++++++ plugin/resolver/name.go | 67 ++++ plugin/resolver/name_test.go | 148 ++++++++ plugin/resolver/options.go | 34 ++ plugin/resolver/resolver.go | 226 ++++++++++++ plugin/resolver/resolver_test.go | 183 ++++++++++ plugin/resolver/source.go | 70 ++++ plugin/resolver/source_local.go | 131 +++++++ plugin/resolver/source_local_test.go | 363 +++++++++++++++++++ plugin/resolver/source_remote.go | 363 +++++++++++++++++++ plugin/resolver/source_remote_test.go | 202 +++++++++++ plugin/resolver/version.go | 58 +++ plugin/resolver/version_test.go | 148 ++++++++ plugin/runner/loader.go | 92 ++--- plugin/runner/options.go | 33 -- plugin/runner/resolver.go | 96 ----- plugin/runner/runner.go | 84 +---- test/e2e/data_test.go | 7 +- test/e2e/render_test.go | 7 +- tools/pluginmeta/main.go | 217 +++++++++++ tools/pluginmeta/metadata.go | 18 + tools/pluginmeta/releaser_config.go | 23 ++ 51 files changed, 3531 insertions(+), 635 deletions(-) create mode 100644 cmd/install.go delete mode 100644 examples/templates/stixview/data.csv delete mode 100644 parser/definitions/global_config_test.go delete mode 100644 plugin/registry/.gitkeep create mode 100644 plugin/resolver/checksum.go create mode 100644 plugin/resolver/checksum_test.go create mode 100644 plugin/resolver/lockfile.go create mode 100644 plugin/resolver/lockfile_test.go create mode 100644 plugin/resolver/mock_source_test.go create mode 100644 plugin/resolver/name.go create mode 100644 plugin/resolver/name_test.go create mode 100644 plugin/resolver/options.go create mode 100644 plugin/resolver/resolver.go create mode 100644 plugin/resolver/resolver_test.go create mode 100644 plugin/resolver/source.go create mode 100644 plugin/resolver/source_local.go create mode 100644 plugin/resolver/source_local_test.go create mode 100644 plugin/resolver/source_remote.go create mode 100644 plugin/resolver/source_remote_test.go create mode 100644 plugin/resolver/version.go create mode 100644 plugin/resolver/version_test.go delete mode 100644 plugin/runner/options.go delete mode 100644 plugin/runner/resolver.go create mode 100644 tools/pluginmeta/main.go create mode 100644 tools/pluginmeta/metadata.go create mode 100644 
tools/pluginmeta/releaser_config.go diff --git a/.gitignore b/.gitignore index 112066fb..a0df6f17 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,6 @@ bin/* *.dll *.so *.dylib - .DS_Store # Vale will install modules in `.vale/style` and those should not be in version control @@ -30,3 +29,6 @@ vendor/ go.work dist/ +.tmp +.fabric +.fabric-lock.json \ No newline at end of file diff --git a/.golangci.yaml b/.golangci.yaml index 600588cd..8b7313a1 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -86,7 +86,9 @@ linters-settings: nolintlint: require-explanation: false # don't require an explanation for nolint directives require-specific: false # don't require nolint directives to be specific about which linter is being skipped - + # gosec: + # excludes: + # - G110 linters: disable-all: true enable: diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 0b67100a..a3dcc651 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -6,6 +6,11 @@ project_name: fabric env: - CGO_ENABLED=0 +before: + hooks: + - go mod tidy + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} + builds: # CLI @@ -22,110 +27,147 @@ builds: - darwin # Plugins + # TODO: generate this list with custom script or use Premium goreleaser to template it - - id: elasticsearch + - id: plugin_elasticsearch main: ./internal/elasticsearch/cmd - binary: "plugins/blackstork/elasticsearch@{{ .Version }}" + binary: "elasticsearch@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - - id: github + + - id: plugin_github main: ./internal/github/cmd - binary: "plugins/blackstork/github@{{ .Version }}" + binary: "github@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - id: graphql + - id: plugin_graphql main: ./internal/graphql/cmd - binary: "plugins/blackstork/graphql@{{ .Version }}" + binary: "graphql@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - id: openai + - id: plugin_openai main: ./internal/openai/cmd - binary: "plugins/blackstork/openai@{{ .Version }}" + binary: "openai@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - id: opencti + - id: plugin_opencti main: ./internal/opencti/cmd - binary: "plugins/blackstork/opencti@{{ .Version }}" + binary: "opencti@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - id: postgresql + - id: plugin_postgresql main: ./internal/postgresql/cmd - binary: "plugins/blackstork/postgresql@{{ .Version }}" + binary: "postgresql@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - id: sqlite + - id: plugin_sqlite main: ./internal/sqlite/cmd - binary: "plugins/blackstork/sqlite@{{ 
.Version }}" + binary: "sqlite@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - id: terraform + - id: plugin_terraform main: ./internal/terraform/cmd - binary: "plugins/blackstork/terraform@{{ .Version }}" + binary: "terraform@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - id: hackerone + - id: plugin_hackerone main: ./internal/hackerone/cmd - binary: "plugins/blackstork/hackerone@{{ .Version }}" + binary: "hackerone@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - id: virustotal + - id: plugin_virustotal main: ./internal/virustotal/cmd - binary: "plugins/blackstork/virustotal@{{ .Version }}" + binary: "virustotal@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - id: splunk + - id: plugin_splunk main: ./internal/splunk/cmd - binary: "plugins/blackstork/splunk@{{ .Version }}" + binary: "splunk@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows - darwin - - id: stixview + - id: plugin_stixview main: ./internal/stixview/cmd - binary: "plugins/blackstork/stixview@{{ .Version }}" + binary: "stixview@{{ .Version }}" flags: "-trimpath" + hooks: + post: + - go run ./tools/pluginmeta --namespace blackstork --version {{.Version}} patch --plugin {{.Path}} --os {{.Os}} --arch {{.Arch}} goos: - linux - windows @@ -147,31 +189,152 @@ archives: - goos: windows format: zip - - id: plugins + # Plugins + # TODO: generate this list with custom script or use Premium goreleaser to template it + + - id: plugin_elasticsearch format: tar.gz builds: - - elasticsearch - - github - - graphql - - openai - - opencti - - postgresql - - sqlite - - terraform - - hackerone - - virustotal - - splunk - - stixview + - plugin_elasticsearch name_template: >- - plugins_ + plugin_elasticsearch_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_github + format: tar.gz + builds: + - plugin_github + name_template: >- + plugin_github_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_graphql + format: tar.gz + builds: + - plugin_graphql + name_template: >- + plugin_graphql_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_openai + format: tar.gz + builds: + - plugin_openai + name_template: >- + plugin_openai_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_opencti + format: tar.gz + builds: + - plugin_opencti + 
name_template: >- + plugin_opencti_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_postgresql + format: tar.gz + builds: + - plugin_postgresql + name_template: >- + plugin_postgresql_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_sqlite + format: tar.gz + builds: + - plugin_sqlite + name_template: >- + plugin_sqlite_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_terraform + format: tar.gz + builds: + - plugin_terraform + name_template: >- + plugin_terraform_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_hackerone + format: tar.gz + builds: + - plugin_hackerone + name_template: >- + plugin_hackerone_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_virustotal + format: tar.gz + builds: + - plugin_virustotal + name_template: >- + plugin_virustotal_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_splunk + format: tar.gz + builds: + - plugin_splunk + name_template: >- + plugin_splunk_ + {{- .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + + - id: plugin_stixview + format: tar.gz + builds: + - plugin_stixview + name_template: >- + plugin_stixview_ {{- .Os }}_ {{- if eq .Arch "amd64" }}x86_64 {{- else if eq .Arch "386" }}i386 {{- else }}{{ .Arch }}{{ end }} {{- if .Arm }}v{{ .Arm }}{{ end }} - format_overrides: - - goos: windows - format: zip changelog: sort: asc @@ -179,3 +342,7 @@ changelog: exclude: - "^docs:" - "^test:" +release: + extra_files: + - glob: ./.tmp/plugins.json + prerelease: auto \ No newline at end of file diff --git a/.mockery.yaml b/.mockery.yaml index 51c97f6d..e2df577c 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -28,4 +28,13 @@ packages: github.com/blackstork-io/fabric/internal/splunk/client: config: interfaces: - Client: \ No newline at end of file + Client: + github.com/blackstork-io/fabric/plugin/resolver: + config: + inpackage: true + dir: "./plugin/resolver" + mockname: "mock{{.InterfaceName}}" + outpkg: "{{.PackageName}}" + filename: "mock_{{.InterfaceName | snakecase}}_test.go" + interfaces: + Source: \ No newline at end of file diff --git a/buf.gen.yaml b/buf.gen.yaml index fb330997..f15052e0 100644 --- a/buf.gen.yaml +++ b/buf.gen.yaml @@ -9,7 +9,4 @@ plugins: opt: module=github.com/blackstork-io/fabric - plugin: buf.build/grpc/go out: . - opt: module=github.com/blackstork-io/fabric - # - plugin: buf.build/connectrpc/go - # out: . 
- # opt: module=github.com/knobz-io/knobz \ No newline at end of file + opt: module=github.com/blackstork-io/fabric \ No newline at end of file diff --git a/cmd/data.go b/cmd/data.go index 845c1970..7a910f4b 100644 --- a/cmd/data.go +++ b/cmd/data.go @@ -97,7 +97,7 @@ var dataCmd = &cobra.Command{ Long: `Execute the data block and print out prettified JSON to stdout`, RunE: func(cmd *cobra.Command, args []string) (err error) { var diags diagnostics.Diag - eval := NewEvaluator(cliArgs.pluginsDir) + eval := NewEvaluator() defer func() { err = eval.Cleanup(diags) }() @@ -105,7 +105,10 @@ var dataCmd = &cobra.Command{ if diags.HasErrors() { return } - if diags.Extend(eval.LoadRunner()) { + if diags.Extend(eval.LoadPluginResolver(false)) { + return + } + if diags.Extend(eval.LoadPluginRunner(cmd.Context())) { return } diff --git a/cmd/evaluator.go b/cmd/evaluator.go index f63dd0b7..31b4398f 100644 --- a/cmd/evaluator.go +++ b/cmd/evaluator.go @@ -1,27 +1,45 @@ package cmd import ( + "context" + "fmt" "io/fs" + "log/slog" "os" + "path/filepath" "github.com/hashicorp/hcl/v2" "github.com/blackstork-io/fabric/internal/builtin" "github.com/blackstork-io/fabric/parser" + "github.com/blackstork-io/fabric/parser/definitions" "github.com/blackstork-io/fabric/pkg/diagnostics" + "github.com/blackstork-io/fabric/plugin/resolver" "github.com/blackstork-io/fabric/plugin/runner" ) +const ( + defaultLockFile = ".fabric-lock.json" +) + type Evaluator struct { - PluginsDir string - Blocks *parser.DefinedBlocks - Runner *runner.Runner - FileMap map[string]*hcl.File + Config *definitions.GlobalConfig + Blocks *parser.DefinedBlocks + Runner *runner.Runner + LockFile *resolver.LockFile + Resolver *resolver.Resolver + FileMap map[string]*hcl.File } -func NewEvaluator(pluginsDir string) *Evaluator { +func NewEvaluator() *Evaluator { return &Evaluator{ - PluginsDir: pluginsDir, + Config: &definitions.GlobalConfig{ + PluginRegistry: &definitions.PluginRegistry{ + BaseURL: "https://registry.blackstork.io", + MirrorDir: "", + }, + CacheDir: ".fabric", + }, } } @@ -44,28 +62,59 @@ func (e *Evaluator) ParseFabricFiles(sourceDir fs.FS) (diags diagnostics.Diag) { if diags.HasErrors() { return } - if e.PluginsDir == "" && e.Blocks.GlobalConfig != nil && e.Blocks.GlobalConfig.PluginRegistry != nil { - // use pluginsDir from config, unless overridden by cli arg - e.PluginsDir = e.Blocks.GlobalConfig.PluginRegistry.MirrorDir + if e.Blocks.GlobalConfig != nil { + e.Config.Merge(e.Blocks.GlobalConfig) } return } -func (e *Evaluator) LoadRunner() diagnostics.Diag { - var pluginVersions runner.VersionMap - if e.Blocks.GlobalConfig != nil { - pluginVersions = e.Blocks.GlobalConfig.PluginVersions +func (e *Evaluator) LoadPluginRunner(ctx context.Context) diagnostics.Diag { + var diag diagnostics.Diag + binaryMap, diags := e.Resolver.Resolve(ctx, e.LockFile) + if diag.ExtendHcl(diags) { + return diag } - var stdDiag hcl.Diagnostics + e.Runner, diags = runner.Load(binaryMap, builtin.Plugin(version), slog.Default()) + diag.ExtendHcl(diags) + return diag +} - e.Runner, stdDiag = runner.Load( - runner.WithBuiltIn( - builtin.Plugin(version), - ), - runner.WithPluginDir(e.PluginsDir), - runner.WithPluginVersions(pluginVersions), +func (e *Evaluator) LoadPluginResolver(includeRemote bool) diagnostics.Diag { + pluginDir := filepath.Join(e.Config.CacheDir, "plugins") + sources := []resolver.Source{ + resolver.LocalSource{ + Path: pluginDir, + }, + } + if e.Config.PluginRegistry != nil { + if e.Config.PluginRegistry.MirrorDir != "" { + sources = 
append(sources, resolver.LocalSource{ + Path: e.Config.PluginRegistry.MirrorDir, + }) + } + if includeRemote && e.Config.PluginRegistry.BaseURL != "" { + sources = append(sources, resolver.RemoteSource{ + BaseURL: e.Config.PluginRegistry.BaseURL, + DownloadDir: pluginDir, + UserAgent: fmt.Sprintf("fabric/%s", version), + }) + } + } + var err error + e.LockFile, err = resolver.ReadLockFileFrom(defaultLockFile) + if err != nil { + return diagnostics.Diag{{ + Severity: hcl.DiagError, + Summary: "Failed to read lock file", + Detail: err.Error(), + }} + } + var diags hcl.Diagnostics + e.Resolver, diags = resolver.NewResolver(e.Config.PluginVersions, + resolver.WithLogger(slog.Default()), + resolver.WithSources(sources...), ) - return diagnostics.Diag(stdDiag) + return diagnostics.Diag(diags) } func (e *Evaluator) PluginCaller() *parser.Caller { diff --git a/cmd/install.go b/cmd/install.go new file mode 100644 index 00000000..0d4cbdee --- /dev/null +++ b/cmd/install.go @@ -0,0 +1,42 @@ +package cmd + +import ( + "os" + + "github.com/spf13/cobra" + + "github.com/blackstork-io/fabric/pkg/diagnostics" + "github.com/blackstork-io/fabric/plugin/resolver" +) + +var installUpgrade bool + +var installCmd = &cobra.Command{ + Use: "install", + Short: "Install plugins", + Long: "Install Fabric plugins", + RunE: func(cmd *cobra.Command, args []string) (err error) { + var diags diagnostics.Diag + eval := NewEvaluator() + defer func() { + err = eval.Cleanup(diags) + }() + diags = eval.ParseFabricFiles(os.DirFS(cliArgs.sourceDir)) + if diags.HasErrors() { + return + } + if diags.Extend(eval.LoadPluginResolver(true)) { + return + } + lockFile, stdDiags := eval.Resolver.Install(cmd.Context(), eval.LockFile, installUpgrade) + if diags.ExtendHcl(stdDiags) { + return + } + return resolver.SaveLockFileTo(defaultLockFile, lockFile) + }, +} + +func init() { + rootCmd.AddCommand(installCmd) + installCmd.Flags().BoolVarP(&installUpgrade, "upgrade", "u", false, "Upgrade plugin versions") +} diff --git a/cmd/render.go b/cmd/render.go index 53cd4f02..90b66148 100644 --- a/cmd/render.go +++ b/cmd/render.go @@ -86,7 +86,7 @@ var renderCmd = &cobra.Command{ } var diags diagnostics.Diag - eval := NewEvaluator(cliArgs.pluginsDir) + eval := NewEvaluator() defer func() { err = eval.Cleanup(diags) }() @@ -94,8 +94,10 @@ var renderCmd = &cobra.Command{ if diags.HasErrors() { return } - diag := eval.LoadRunner() - if diags.Extend(diag) { + if diags.Extend(eval.LoadPluginResolver(false)) { + return + } + if diags.Extend(eval.LoadPluginRunner(cmd.Context())) { return } res, diag := Render(cmd.Context(), eval.Blocks, eval.PluginCaller(), target) diff --git a/cmd/root.go b/cmd/root.go index 8a960720..4ba747a2 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -87,8 +87,6 @@ var rootCmd = &cobra.Command{ } cliArgs.sourceDir = rawArgs.sourceDir - cliArgs.pluginsDir = rawArgs.pluginsDir - cliArgs.colorize = rawArgs.colorize && term.IsTerminal(int(os.Stderr.Fd())) var level slog.Level @@ -157,18 +155,16 @@ func Execute() { } var cliArgs = struct { - sourceDir string - pluginsDir string - colorize bool + sourceDir string + colorize bool }{} var rawArgs = struct { - sourceDir string - logOutput string - logLevel string - pluginsDir string - verbose bool - colorize bool + sourceDir string + logOutput string + logLevel string + verbose bool + colorize bool }{} func init() { @@ -180,7 +176,4 @@ func init() { ) rootCmd.PersistentFlags().BoolVar(&rawArgs.colorize, "color", true, "enables colorizing the logs and diagnostics (if supported by the terminal 
and log format)") rootCmd.PersistentFlags().BoolVarP(&rawArgs.verbose, "verbose", "v", false, "a shortcut to --log-level debug") - rootCmd.PersistentFlags().StringVar( - &rawArgs.pluginsDir, "plugins-dir", "", "override for plugins dir from fabric configuration", - ) } diff --git a/examples/templates/hackerone/example.fabric b/examples/templates/hackerone/example.fabric index dd6ec118..0972afa9 100644 --- a/examples/templates/hackerone/example.fabric +++ b/examples/templates/hackerone/example.fabric @@ -1,10 +1,6 @@ fabric { - cache_dir = "./.fabric" - plugin_registry { - mirror_dir = "dist/plugins" - } plugin_versions = { - "blackstork/hackerone" = "0.0.0-dev" + "blackstork/hackerone" = ">= 0.4 < 1.0 || 0.4.0-rev0" } } @@ -15,7 +11,6 @@ config data hackerone_reports { document "example" { title = "Using hackerone plugin" - data hackerone_reports "my_reports" { program = [from_env_variable("HACKERONE_PROGRAM")] } diff --git a/examples/templates/openai/example.fabric b/examples/templates/openai/example.fabric index 825a9d56..64c6c2b7 100644 --- a/examples/templates/openai/example.fabric +++ b/examples/templates/openai/example.fabric @@ -1,20 +1,14 @@ fabric { - cache_dir = "./.fabric" - plugin_registry { - mirror_dir = "dist/plugins" - } plugin_versions = { - "blackstork/openai" = "0.0.0-dev" + "blackstork/openai" = ">= 0.4 < 1.0 || 0.4.0-rev0" } } -config data csv {} - document "example" { title = "Testing plugins" data csv "csv_file" { - path = "./examples/templates/openai/data.csv" + path = "./data.csv" } content text { text = "Values from the CSV file" diff --git a/examples/templates/stixview/data.csv b/examples/templates/stixview/data.csv deleted file mode 100644 index 94ecff74..00000000 --- a/examples/templates/stixview/data.csv +++ /dev/null @@ -1,4 +0,0 @@ -id,active,name,age,height -b8fa4bb0-6dd4-45ba-96e0-9a182b2b932e,true,Stacey,26,1.98 -b0086c49-bcd8-4aae-9f88-4f46b128e709,false,Myriam,33,1.81 -a12d2a8c-eebc-42b3-be52-1ab0a2969a81,true,Oralee,31,2.23 \ No newline at end of file diff --git a/examples/templates/stixview/example.fabric b/examples/templates/stixview/example.fabric index 4815c695..0da07ea3 100644 --- a/examples/templates/stixview/example.fabric +++ b/examples/templates/stixview/example.fabric @@ -1,10 +1,6 @@ fabric { - cache_dir = "./.fabric" - plugin_registry { - mirror_dir = "dist/plugins" - } plugin_versions = { - "blackstork/stixview" = "0.0.0-dev" + "blackstork/stixview" = ">= 0.4 < 1.0 || 0.4.0-rev0" } } diff --git a/examples/templates/virustotal/example.fabric b/examples/templates/virustotal/example.fabric index 087dd3c3..5d4a6653 100644 --- a/examples/templates/virustotal/example.fabric +++ b/examples/templates/virustotal/example.fabric @@ -1,10 +1,6 @@ fabric { - cache_dir = "./.fabric" - plugin_registry { - mirror_dir = "dist/plugins" - } plugin_versions = { - "blackstork/virustotal" = "0.0.0-dev" + "blackstork/virustotal" = ">= 0.4 < 1.0 || 0.4.0-rev0" } } @@ -14,7 +10,6 @@ config data virustotal_api_usage { document "example" { title = "Using virustotal plugin" - data virustotal_api_usage "my_usage" { user_id = from_env_variable("VIRUSTOTAL_USER_ID") start_date = "20240201" diff --git a/parser/definitions/global_config.go b/parser/definitions/global_config.go index 0ac1a5d1..87fd1d44 100644 --- a/parser/definitions/global_config.go +++ b/parser/definitions/global_config.go @@ -1,105 +1,31 @@ package definitions -import ( - "fmt" - - "github.com/hashicorp/hcl/v2" - "github.com/hashicorp/hcl/v2/hcldec" - "github.com/hashicorp/hcl/v2/hclsyntax" - 
"github.com/zclconf/go-cty/cty" - "github.com/zclconf/go-cty/cty/convert" - - "github.com/blackstork-io/fabric/pkg/diagnostics" -) - -var globalConfigSpec = &hcldec.ObjectSpec{ - "cache_dir": &hcldec.AttrSpec{ - Name: "cache_dir", - Type: cty.String, - Required: false, - }, - "plugin_registry": &hcldec.BlockSpec{ - TypeName: "plugin_registry", - Nested: hcldec.ObjectSpec{ - "mirror_dir": &hcldec.AttrSpec{ - Name: "mirror_dir", - Type: cty.String, - Required: false, - }, - }, - }, - "plugin_versions": &hcldec.AttrSpec{ - Name: "plugin_versions", - Type: cty.Map(cty.String), - Required: false, - }, -} - type GlobalConfig struct { - block *hclsyntax.Block - CacheDir string - PluginRegistry *PluginRegistry - PluginVersions map[string]string + CacheDir string `hcl:"cache_dir,optional"` + PluginRegistry *PluginRegistry `hcl:"plugin_registry,block"` + PluginVersions map[string]string `hcl:"plugin_versions,optional"` } type PluginRegistry struct { - MirrorDir string + BaseURL string `hcl:"base_url,optional"` + MirrorDir string `hcl:"mirror_dir,optional"` } -func DefineGlobalConfig(block *hclsyntax.Block) (cfg *GlobalConfig, diags diagnostics.Diag) { - if len(block.Labels) > 0 { - return nil, diagnostics.Diag{{ - Severity: hcl.DiagError, - Summary: "Invalid global config", - Detail: "Global config should not have labels", - }} - } - value, hclDiags := hcldec.Decode(block.Body, globalConfigSpec, nil) - if diags.ExtendHcl(hclDiags) { - return +func (g *GlobalConfig) Merge(other *GlobalConfig) { + if other.CacheDir != "" { + g.CacheDir = other.CacheDir } - typ := hcldec.ImpliedType(globalConfigSpec) - errs := value.Type().TestConformance(typ) - if len(errs) > 0 { - var err error - value, err = convert.Convert(value, typ) - if err != nil { - diags.AppendErr(err, "Error while serializing global config") - return - } - } - cfg = &GlobalConfig{ - block: block, - CacheDir: "./.fabric", - PluginVersions: make(map[string]string), - } - cacheDir := value.GetAttr("cache_dir") - if !cacheDir.IsNull() && cacheDir.AsString() != "" { - cfg.CacheDir = cacheDir.AsString() - } - pluginRegistry := value.GetAttr("plugin_registry") - if !pluginRegistry.IsNull() { - mirrorDir := pluginRegistry.GetAttr("mirror_dir") - if !mirrorDir.IsNull() || mirrorDir.AsString() != "" { - cfg.PluginRegistry = &PluginRegistry{ - MirrorDir: mirrorDir.AsString(), + if other.PluginRegistry != nil { + if g.PluginRegistry == nil { + g.PluginRegistry = other.PluginRegistry + } else { + if other.PluginRegistry.BaseURL != "" { + g.PluginRegistry.BaseURL = other.PluginRegistry.BaseURL } - } - } - pluginVersions := value.GetAttr("plugin_versions") - if !pluginVersions.IsNull() { - versionMap := pluginVersions.AsValueMap() - for k, v := range versionMap { - if v.Type() != cty.String { - diags.Append(&hcl.Diagnostic{ - Severity: hcl.DiagError, - Summary: "Invalid plugin version", - Detail: fmt.Sprintf("Version of plugin '%s' should be a string", k), - }) - continue + if other.PluginRegistry.MirrorDir != "" { + g.PluginRegistry.MirrorDir = other.PluginRegistry.MirrorDir } - cfg.PluginVersions[k] = v.AsString() } } - return cfg, nil + g.PluginVersions = other.PluginVersions } diff --git a/parser/definitions/global_config_test.go b/parser/definitions/global_config_test.go deleted file mode 100644 index 9796dfc5..00000000 --- a/parser/definitions/global_config_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package definitions - -import ( - "testing" - - "github.com/hashicorp/hcl/v2" - "github.com/hashicorp/hcl/v2/hclsyntax" - 
"github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDefineGlobalConfig(t *testing.T) { - t.Parallel() - tt := []struct { - name string - block string - want *GlobalConfig - }{ - { - name: "empty", - block: `fabric {}`, - want: &GlobalConfig{ - CacheDir: "./.fabric", - PluginVersions: map[string]string{}, - }, - }, - { - name: "with_cache_dir", - block: `fabric { - cache_dir = "./.other_cache" - }`, - want: &GlobalConfig{ - CacheDir: "./.other_cache", - PluginVersions: map[string]string{}, - }, - }, - { - name: "with_plugin_registry", - block: `fabric { - plugin_registry { - mirror_dir = "./.other_mirror" - } - }`, - want: &GlobalConfig{ - CacheDir: "./.fabric", - PluginRegistry: &PluginRegistry{ - MirrorDir: "./.other_mirror", - }, - PluginVersions: map[string]string{}, - }, - }, - { - name: "with_plugin_versions", - block: `fabric { - plugin_versions = { - "namespace/plugin1" = "1.0.0" - "namespace/plugin2" = "2.0.0" - } - }`, - want: &GlobalConfig{ - CacheDir: "./.fabric", - PluginVersions: map[string]string{ - "namespace/plugin1": "1.0.0", - "namespace/plugin2": "2.0.0", - }, - }, - }, - { - name: "with_all", - block: `fabric { - cache_dir = "./.other_cache" - plugin_registry { - mirror_dir = "./.other_mirror" - } - plugin_versions = { - "namespace/plugin1" = "1.0.0" - "namespace/plugin2" = "2.0.0" - } - }`, - want: &GlobalConfig{ - CacheDir: "./.other_cache", - PluginRegistry: &PluginRegistry{ - MirrorDir: "./.other_mirror", - }, - PluginVersions: map[string]string{ - "namespace/plugin1": "1.0.0", - "namespace/plugin2": "2.0.0", - }, - }, - }, - } - - for _, tc := range tt { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - f, hcldiags := hclsyntax.ParseConfig([]byte(tc.block), "", hcl.Pos{}) - require.Len(t, hcldiags, 0) - body, ok := f.Body.(*hclsyntax.Body) - require.True(t, ok) - block := body.Blocks[0] - got, diags := DefineGlobalConfig(block) - assert.Len(t, diags, 0) - tc.want.block = block - assert.Equal(t, tc.want, got) - }) - } -} diff --git a/parser/parser.go b/parser/parser.go index 17ceb3c2..521692f8 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/gohcl" "github.com/hashicorp/hcl/v2/hclsyntax" "github.com/blackstork-io/fabric/parser/definitions" @@ -213,14 +214,14 @@ func parseBlockDefinitions(body *hclsyntax.Body) (res *DefinedBlocks, diags diag } diags.Append(AddIfMissing(res.Config, *key, cfg)) case definitions.BlockKindGlobalConfig: - globalCfg, dgs := definitions.DefineGlobalConfig(block) - if diags.Extend(dgs) { + var globalCfg definitions.GlobalConfig + if diags.ExtendHcl(gohcl.DecodeBody(block.Body, nil, &globalCfg)) { continue } if res.GlobalConfig != nil { diags.Add("Global config declared multiple times", "") } - res.GlobalConfig = globalCfg + res.GlobalConfig = &globalCfg default: diags.Append(definitions.NewNestingDiag( "Top level of fabric document", diff --git a/plugin/pluginapi/v1/client.go b/plugin/pluginapi/v1/client.go index 2b4c19a5..83f7a253 100644 --- a/plugin/pluginapi/v1/client.go +++ b/plugin/pluginapi/v1/client.go @@ -4,8 +4,6 @@ import ( "fmt" "log/slog" "os/exec" - "path/filepath" - "strings" goplugin "github.com/hashicorp/go-plugin" "google.golang.org/grpc" @@ -14,41 +12,21 @@ import ( "github.com/blackstork-io/fabric/plugin" ) -func parsePluginInfo(path string) (name, version string, err error) { - nameVer := filepath.Base(path) - ext := filepath.Ext(path) - - parts := strings.SplitN( - 
nameVer[:len(nameVer)-len(ext)], - "@", 2, - ) - if len(parts) != 2 { - err = fmt.Errorf("plugin at '%s' must have a file name '<name>@<version>[.exe]'", path) - return - } - name = parts[0] - version = parts[1] - return -} - -func NewClient(loc string) (p *plugin.Schema, closefn func() error, err error) { - pluginName, _, err := parsePluginInfo(loc) - if err != nil { - return - } - slog.Info("Loading plugin", "filename", loc) +func NewClient(name, binaryPath string, logger *slog.Logger) (p *plugin.Schema, closefn func() error, err error) { client := goplugin.NewClient(&goplugin.ClientConfig{ HandshakeConfig: handshake, Plugins: map[string]goplugin.Plugin{ - pluginName: &grpcPlugin{}, + name: &grpcPlugin{ + logger: logger, + }, }, - Cmd: exec.Command(loc), + Cmd: exec.Command(binaryPath), AllowedProtocols: []goplugin.Protocol{ goplugin.ProtocolGRPC, }, Logger: sloghclog.Adapt( - slog.Default(), - sloghclog.Name("plugin."+pluginName), + logger, + sloghclog.Name("plugin."+name), // disable code location reporting, it's always going to be incorrect // for remote plugin logs sloghclog.AddSource(false), @@ -65,7 +43,7 @@ func NewClient(loc string) (p *plugin.Schema, closefn func() error, err error) { if err != nil { return nil, nil, fmt.Errorf("failed to create plugin client: %w", err) } - raw, err := rpcClient.Dispense(pluginName) + raw, err := rpcClient.Dispense(name) if err != nil { rpcClient.Close() return nil, nil, fmt.Errorf("failed to dispense plugin: %w", err) diff --git a/plugin/pluginapi/v1/cty_type_decoder.go b/plugin/pluginapi/v1/cty_type_decoder.go index 7f2f2314..d20523e1 100644 --- a/plugin/pluginapi/v1/cty_type_decoder.go +++ b/plugin/pluginapi/v1/cty_type_decoder.go @@ -21,16 +21,12 @@ func decodeCtyType(src *CtyType) (cty.Type, error) { case *CtyType_Tuple: return decodeCtyTupleType(src.Tuple) case *CtyType_DynamicPseudo: - return decodeCtyDynamicPseudoType(src.DynamicPseudo) + return cty.DynamicPseudoType, nil default: return cty.NilType, fmt.Errorf("unsupported cty type: %T", src) } } -func decodeCtyDynamicPseudoType(src *CtyDynamicPseudoType) (cty.Type, error) { - return cty.DynamicPseudoType, nil -} - func decodeCtyPrimitiveType(src *CtyPrimitiveType) (cty.Type, error) { switch src.GetKind() { case CtyPrimitiveKind_CTY_PRIMITIVE_KIND_BOOL: diff --git a/plugin/pluginapi/v1/cty_type_encoder.go b/plugin/pluginapi/v1/cty_type_encoder.go index 0499e1a3..661c465a 100644 --- a/plugin/pluginapi/v1/cty_type_encoder.go +++ b/plugin/pluginapi/v1/cty_type_encoder.go @@ -21,20 +21,16 @@ func encodeCtyType(src cty.Type) (*CtyType, error) { case src.IsTupleType(): return encodeCtyTupleType(src) case src.Equals(cty.DynamicPseudoType): - return encodeCtyDynamicPseudoType(src) + return &CtyType{ + Data: &CtyType_DynamicPseudo{ + DynamicPseudo: &CtyDynamicPseudoType{}, + }, + }, nil default: return nil, fmt.Errorf("unsupported cty type: %s", src.FriendlyName()) } } -func encodeCtyDynamicPseudoType(src cty.Type) (*CtyType, error) { - return &CtyType{ - Data: &CtyType_DynamicPseudo{ - DynamicPseudo: &CtyDynamicPseudoType{}, - }, - }, nil -} - func encodeCtyPrimitiveType(src cty.Type) (*CtyType, error) { kind := CtyPrimitiveKind_CTY_PRIMITIVE_KIND_UNSPECIFIED switch src { diff --git a/plugin/pluginapi/v1/cty_value_decoder.go b/plugin/pluginapi/v1/cty_value_decoder.go index 7486bad1..52693074 100644 --- a/plugin/pluginapi/v1/cty_value_decoder.go +++ b/plugin/pluginapi/v1/cty_value_decoder.go @@ -16,23 +16,23 @@ func decodeCtyValue(src *CtyValue) (cty.Value, error) { } switch { case t.IsPrimitiveType() &&
src.GetPrimitive() != nil: - return decodeCtyPrimitiveValue(t, src.GetPrimitive()) + return decodeCtyPrimitiveValue(src.GetPrimitive()) case t.IsListType() && src.GetList() != nil: - return decodeCtyListValue(t, src.GetList()) + return decodeCtyListValue(src.GetList()) case t.IsMapType() && src.GetMap() != nil: - return decodeCtyMapValue(t, src.GetMap()) + return decodeCtyMapValue(src.GetMap()) case t.IsSetType() && src.GetSet() != nil: - return decodeCtySetValue(t, src.GetSet()) + return decodeCtySetValue(src.GetSet()) case t.IsObjectType() && src.GetObject() != nil: - return decodeCtyObjectValue(t, src.GetObject()) + return decodeCtyObjectValue(src.GetObject()) case t.IsTupleType() && src.GetTuple() != nil: - return decodeCtyTupleValue(t, src.GetTuple()) + return decodeCtyTupleValue(src.GetTuple()) default: return cty.NullVal(t), nil } } -func decodeCtyTupleValue(t cty.Type, src *CtyTupleValue) (cty.Value, error) { +func decodeCtyTupleValue(src *CtyTupleValue) (cty.Value, error) { elements := make([]cty.Value, len(src.GetElements())) var err error for i, elem := range src.GetElements() { @@ -44,7 +44,7 @@ func decodeCtyTupleValue(t cty.Type, src *CtyTupleValue) (cty.Value, error) { return cty.TupleVal(elements), nil } -func decodeCtyObjectValue(t cty.Type, src *CtyObjectValue) (cty.Value, error) { +func decodeCtyObjectValue(src *CtyObjectValue) (cty.Value, error) { attrs := make(map[string]cty.Value, len(src.GetAttrs())) var err error for k, v := range src.GetAttrs() { @@ -56,7 +56,7 @@ func decodeCtyObjectValue(t cty.Type, src *CtyObjectValue) (cty.Value, error) { return cty.ObjectVal(attrs), nil } -func decodeCtySetValue(t cty.Type, src *CtySetValue) (cty.Value, error) { +func decodeCtySetValue(src *CtySetValue) (cty.Value, error) { elements := make([]cty.Value, len(src.GetElements())) var err error for i, elem := range src.GetElements() { @@ -68,7 +68,7 @@ func decodeCtySetValue(t cty.Type, src *CtySetValue) (cty.Value, error) { return cty.SetVal(elements), nil } -func decodeCtyMapValue(t cty.Type, src *CtyMapValue) (cty.Value, error) { +func decodeCtyMapValue(src *CtyMapValue) (cty.Value, error) { elements := make(map[string]cty.Value, len(src.GetElements())) var err error for k, v := range src.GetElements() { @@ -80,7 +80,7 @@ func decodeCtyMapValue(t cty.Type, src *CtyMapValue) (cty.Value, error) { return cty.MapVal(elements), nil } -func decodeCtyListValue(t cty.Type, src *CtyListValue) (cty.Value, error) { +func decodeCtyListValue(src *CtyListValue) (cty.Value, error) { elements := make([]cty.Value, len(src.GetElements())) var err error for i, elem := range src.GetElements() { @@ -92,7 +92,7 @@ func decodeCtyListValue(t cty.Type, src *CtyListValue) (cty.Value, error) { return cty.ListVal(elements), nil } -func decodeCtyPrimitiveValue(t cty.Type, src *CtyPrimitiveValue) (cty.Value, error) { +func decodeCtyPrimitiveValue(src *CtyPrimitiveValue) (cty.Value, error) { switch data := src.GetData().(type) { case *CtyPrimitiveValue_Bln: return cty.BoolVal(data.Bln), nil diff --git a/plugin/pluginapi/v1/hclspec_decoder.go b/plugin/pluginapi/v1/hclspec_decoder.go index 14826ad9..f15ca3c0 100644 --- a/plugin/pluginapi/v1/hclspec_decoder.go +++ b/plugin/pluginapi/v1/hclspec_decoder.go @@ -7,9 +7,6 @@ import ( ) func decodeHclSpec(src *HclSpec) (hcldec.Spec, error) { - if src == nil { - return nil, nil - } switch { case src == nil || src.GetData() == nil: return nil, nil diff --git a/plugin/pluginapi/v1/plugin.go b/plugin/pluginapi/v1/plugin.go index cfe49690..cf4a6c4a 100644 --- 
a/plugin/pluginapi/v1/plugin.go +++ b/plugin/pluginapi/v1/plugin.go @@ -3,6 +3,8 @@ package pluginapiv1 import ( context "context" "fmt" + "log/slog" + "time" goplugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/hcl/v2" @@ -21,6 +23,7 @@ var handshake = goplugin.HandshakeConfig{ type grpcPlugin struct { goplugin.Plugin + logger *slog.Logger schema *plugin.Schema } @@ -33,7 +36,6 @@ func (p *grpcPlugin) GRPCServer(broker *goplugin.GRPCBroker, s *grpc.Server) err func (p *grpcPlugin) GRPCClient(ctx context.Context, broker *goplugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { client := NewPluginServiceClient(c) - res, err := client.GetSchema(ctx, &GetSchemaRequest{}) if err != nil { return nil, err @@ -66,6 +68,10 @@ func (p *grpcPlugin) callOptions() []grpc.CallOption { func (p *grpcPlugin) clientGenerateFunc(name string, client PluginServiceClient) plugin.ProvideContentFunc { return func(ctx context.Context, params *plugin.ProvideContentParams) (*plugin.Content, hcl.Diagnostics) { + p.logger.Debug("Calling content provider", "name", name) + defer func(start time.Time) { + p.logger.Debug("Called content provider", "name", name, "took", time.Since(start)) + }(time.Now()) if params == nil { return nil, hcl.Diagnostics{{ Severity: hcl.DiagError, @@ -110,6 +116,10 @@ func (p *grpcPlugin) clientGenerateFunc(name string, client PluginServiceClient) func (p *grpcPlugin) clientDataFunc(name string, client PluginServiceClient) plugin.RetrieveDataFunc { return func(ctx context.Context, params *plugin.RetrieveDataParams) (plugin.Data, hcl.Diagnostics) { + p.logger.Debug("Calling data source", "name", name) + defer func(start time.Time) { + p.logger.Debug("Called data source", "name", name, "took", time.Since(start)) + }(time.Now()) if params == nil { return nil, hcl.Diagnostics{{ Severity: hcl.DiagError, diff --git a/plugin/registry/.gitkeep b/plugin/registry/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/plugin/resolver/checksum.go b/plugin/resolver/checksum.go new file mode 100644 index 00000000..db0a11ed --- /dev/null +++ b/plugin/resolver/checksum.go @@ -0,0 +1,116 @@ +package resolver + +import ( + "bufio" + "bytes" + "encoding/base64" + "fmt" + "io" + "strconv" + "strings" +) + +// Checksum for plugin binaries and archives. +// It contains the object name, os, arch, and the hash sum value. +// +// Format in string: '<object>:<os>:<arch>:<sum>'. +// +// Example: 'archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8='.
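+//
+// A minimal round-trip sketch (an editorial illustration, using only the
+// exported API defined below; the values are taken from the example above):
+//
+//	var c Checksum
+//	err := c.UnmarshalText([]byte("archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="))
+//	// err == nil; c.Object == "archive", c.OS == "darwin", c.Arch == "arm64"
+//	// c.String() reproduces the input string, and c.Match([]Checksum{c}) == true.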
+type Checksum struct { + Object string + OS string + Arch string + Sum []byte +} + +func encodeChecksums(w io.Writer, checksums []Checksum) error { + writer := bufio.NewWriter(w) + for _, c := range checksums { + if _, err := writer.WriteString(c.String() + "\n"); err != nil { + return err + } + } + return writer.Flush() +} + +func decodeChecksums(r io.Reader) ([]Checksum, error) { + var checksums []Checksum + scanner := bufio.NewScanner(r) + for scanner.Scan() { + var c Checksum + if err := c.UnmarshalText(scanner.Bytes()); err != nil { + return nil, err + } + checksums = append(checksums, c) + } + return checksums, nil +} + +func (c Checksum) Compare(other Checksum) int { + cmp := strings.Compare(c.Object, other.Object) + if cmp != 0 { + return cmp + } + cmp = strings.Compare(c.OS, other.OS) + if cmp != 0 { + return cmp + } + cmp = strings.Compare(c.Arch, other.Arch) + if cmp != 0 { + return cmp + } + return bytes.Compare(c.Sum, other.Sum) +} + +func (c Checksum) Match(list []Checksum) bool { + for _, other := range list { + if c.Compare(other) == 0 { + return true + } + } + return false +} + +func (c Checksum) String() string { + return strings.Join([]string{ + c.Object, + c.OS, + c.Arch, + base64.StdEncoding.EncodeToString(c.Sum), + }, ":") +} + +func (c Checksum) MarshalText() ([]byte, error) { + return []byte(c.String()), nil +} + +func (c *Checksum) UnmarshalText(data []byte) error { + raw := string(data) + parts := strings.Split(raw, ":") + if len(parts) != 4 { + return fmt.Errorf("invalid checksum format: %s", raw) + } + sum, err := base64.StdEncoding.DecodeString(parts[3]) + if err != nil { + return fmt.Errorf("failed to decode checksum: %w", err) + } + *c = Checksum{ + Object: parts[0], + OS: parts[1], + Arch: parts[2], + Sum: sum, + } + return nil +} + +func (c Checksum) MarshalJSON() ([]byte, error) { + return []byte(strconv.Quote(c.String())), nil +} + +func (c *Checksum) UnmarshalJSON(data []byte) error { + raw, err := strconv.Unquote(string(data)) + if err != nil { + return fmt.Errorf("failed to unquote checksum: %w", err) + } + return c.UnmarshalText([]byte(raw)) +} diff --git a/plugin/resolver/checksum_test.go b/plugin/resolver/checksum_test.go new file mode 100644 index 00000000..4a24a9c9 --- /dev/null +++ b/plugin/resolver/checksum_test.go @@ -0,0 +1,327 @@ +package resolver + +import ( + "bytes" + "encoding/base64" + "io" + "reflect" + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mustBase64(t *testing.T, s string) []byte { + sum, err := base64.StdEncoding.DecodeString(s) + require.NoError(t, err) + return sum +} + +func TestPluginChecksum_UnmarshalJSON(t *testing.T) { + str := `"archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="` + var dst Checksum + err := dst.UnmarshalJSON([]byte(str)) + require.NoError(t, err) + require.Equal(t, "archive", dst.Object) + require.Equal(t, "darwin", dst.OS) + require.Equal(t, "arm64", dst.Arch) + require.Equal(t, "lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", base64.StdEncoding.EncodeToString(dst.Sum)) +} + +func TestPluginChecksum_MarshalJSON(t *testing.T) { + sum, _ := base64.StdEncoding.DecodeString("lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=") + src := Checksum{ + Object: "archive", + OS: "darwin", + Arch: "arm64", + Sum: sum, + } + data, err := src.MarshalJSON() + require.NoError(t, err) + require.Equal(t, `"archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="`, string(data)) +} + +func Test_encodeChecksums(t *testing.T) { + type 
args struct { + checksums []Checksum + } + tests := []struct { + name string + args args + wantW string + wantErr bool + }{ + { + name: "empty", + args: args{ + checksums: nil, + }, + wantW: "", + wantErr: false, + }, + { + name: "single", + args: args{ + checksums: []Checksum{ + { + Object: "archive", + OS: "darwin", + Arch: "arm64", + Sum: mustBase64(t, "lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, + }, + wantW: "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=\n", + wantErr: false, + }, + { + name: "multiple", + args: args{ + checksums: []Checksum{ + { + Object: "archive", + OS: "darwin", + Arch: "arm64", + Sum: mustBase64(t, "lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + { + Object: "binary", + OS: "linux", + Arch: "amd64", + Sum: mustBase64(t, "lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, + }, + wantW: "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=\n" + + "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=\n", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + if err := encodeChecksums(w, tt.args.checksums); (err != nil) != tt.wantErr { + t.Errorf("encodeChecksums() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotW := w.String(); gotW != tt.wantW { + t.Errorf("encodeChecksums() = %v, want %v", gotW, tt.wantW) + } + }) + } +} + +func Test_decodeChecksums(t *testing.T) { + type args struct { + r io.Reader + } + tests := []struct { + name string + args args + want []Checksum + wantErr bool + }{ + { + name: "empty", + args: args{ + r: bytes.NewBufferString(""), + }, + want: nil, + wantErr: false, + }, + { + name: "single", + args: args{ + r: bytes.NewBufferString("archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=\n"), + }, + want: []Checksum{ + { + Object: "archive", + OS: "darwin", + Arch: "arm64", + Sum: mustBase64(t, "lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, + wantErr: false, + }, + { + name: "multiple", + args: args{ + r: bytes.NewBufferString("archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=\n" + + "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=\n"), + }, + want: []Checksum{ + { + Object: "archive", + OS: "darwin", + Arch: "arm64", + Sum: mustBase64(t, "lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + { + Object: "binary", + OS: "linux", + Arch: "amd64", + Sum: mustBase64(t, "lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, + wantErr: false, + }, + { + name: "invalid", + args: args{ + r: bytes.NewBufferString("invalid\n"), + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := decodeChecksums(tt.args.r) + if (err != nil) != tt.wantErr { + t.Errorf("decodeChecksums() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("decodeChecksums() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestChecksum_Compare(t *testing.T) { + strList := []string{ + "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:windows:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:darwin:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:linux:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "archive:windows:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + 
"archive:linux:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "archive:darwin:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "archive:windows:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:windows:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + } + checksums := make([]Checksum, len(strList)) + for i, str := range strList { + checksums[i].UnmarshalText([]byte(str)) + } + slices.SortFunc(checksums, func(a, b Checksum) int { + return a.Compare(b) + }) + + gotStrList := make([]string, len(checksums)) + for i, c := range checksums { + gotStrList[i] = c.String() + } + + assert.Exactly(t, []string{ + "archive:darwin:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "archive:linux:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "archive:windows:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "archive:windows:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:darwin:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:linux:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:windows:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + "binary:windows:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", + }, gotStrList) +} + +func mustChecksum(t *testing.T, str string) Checksum { + t.Helper() + var c Checksum + err := c.UnmarshalText([]byte(str)) + require.NoError(t, err) + return c +} + +func TestChecksum_Match(t *testing.T) { + tests := []struct { + name string + sum Checksum + args []Checksum + want bool + }{ + { + name: "match", + sum: mustChecksum(t, "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + args: []Checksum{ + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + want: true, + }, + { + name: "empty", + sum: mustChecksum(t, "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + args: nil, + want: false, + }, + { + name: "mismatch", + sum: mustChecksum(t, "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + args: []Checksum{ + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + want: false, + }, + { + name: "mismatch_os", + sum: mustChecksum(t, "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + args: []Checksum{ + mustChecksum(t, "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + want: false, + }, + { + name: "mismatch_arch", + sum: mustChecksum(t, "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + args: []Checksum{ + mustChecksum(t, "archive:darwin:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + want: false, + }, + { + name: "mismatch_sum", + sum: 
mustChecksum(t, "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + args: []Checksum{ + mustChecksum(t, "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDa8="), + mustChecksum(t, "binary:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + want: false, + }, + { + name: "empty_sum", + sum: mustChecksum(t, ":::"), + args: []Checksum{ + mustChecksum(t, "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ9="), + mustChecksum(t, "binary:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + want: false, + }, + { + name: "empty_args", + sum: mustChecksum(t, "archive:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ9="), + args: nil, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.sum.Match(tt.args); got != tt.want { + t.Errorf("Checksum.Match() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/plugin/resolver/lockfile.go b/plugin/resolver/lockfile.go new file mode 100644 index 00000000..91c578e3 --- /dev/null +++ b/plugin/resolver/lockfile.go @@ -0,0 +1,108 @@ +package resolver + +import ( + "encoding/json" + "fmt" + "io" + "os" + "slices" +) + +// LockFile is a plugin lock configuration. +type LockFile struct { + Plugins []PluginLock `json:"plugins"` +} + +// PluginLock is a lock for a single plugin. +type PluginLock struct { + Name Name `json:"name"` + Version Version `json:"version"` + Checksums []Checksum `json:"checksums"` +} + +// LockCheckResult is the result of a lock check. +type LockCheckResult struct { + Missing ConstraintMap + Mismatch ConstraintMap + Removed map[Name]Version +} + +// IsInstallRequired returns true if the lock check result requires an install. +func (result LockCheckResult) IsInstallRequired() bool { + return len(result.Missing) > 0 || len(result.Mismatch) > 0 +} + +// Check the lock configuration against the given constraints. +func (file *LockFile) Check(constraints ConstraintMap) LockCheckResult { + missing := ConstraintMap{} + mismatch := ConstraintMap{} + removed := map[Name]Version{} + for name, constraint := range constraints { + idx := slices.IndexFunc(file.Plugins, func(lock PluginLock) bool { + return lock.Name == name + }) + if idx == -1 { + missing[name] = constraint + continue + } + lock := file.Plugins[idx] + if !constraint.Check(lock.Version.Version) { + mismatch[name] = constraint + } + } + for _, lock := range file.Plugins { + if _, ok := constraints[lock.Name]; !ok { + removed[lock.Name] = lock.Version + } + } + return LockCheckResult{ + Missing: missing, + Mismatch: mismatch, + Removed: removed, + } +} + +// ReadLockFile parses a lock configuration from a reader. +func ReadLockFile(r io.Reader) (*LockFile, error) { + var lockFile LockFile + err := json.NewDecoder(r).Decode(&lockFile) + if err != nil { + return nil, err + } + return &lockFile, nil +} + +// ReadLockFileFrom parses a lock configuration from a local file. +func ReadLockFileFrom(path string) (*LockFile, error) { + file, err := os.Open(path) + if os.IsNotExist(err) { + return &LockFile{}, nil + } else if err != nil { + return nil, err + } + defer file.Close() + return ReadLockFile(file) +} + +// SaveLockFile saves a lock configuration to a writer. +func SaveLockFile(w io.Writer, lockFile *LockFile) error { + if lockFile == nil { + return fmt.Errorf("plugin lock file is nil") + } + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + return enc.Encode(lockFile) +} + +// SaveLockFileTo saves a lock configuration to a local file.
+func SaveLockFileTo(path string, lockFile *LockFile) error { + if lockFile == nil { + return fmt.Errorf("plugin lock file is nil") + } + file, err := os.Create(path) + if err != nil { + return err + } + defer file.Close() + return SaveLockFile(file, lockFile) +} diff --git a/plugin/resolver/lockfile_test.go b/plugin/resolver/lockfile_test.go new file mode 100644 index 00000000..ba2f354e --- /dev/null +++ b/plugin/resolver/lockfile_test.go @@ -0,0 +1,91 @@ +package resolver + +import ( + "bytes" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLockCheckResult_IsInstallRequired(t *testing.T) { + locks := LockFile{ + Plugins: []PluginLock{ + { + Name: mustName(t, "ns/name"), + Version: mustVersion(t, "1.0.0"), + }, + { + Name: mustName(t, "ns/name2"), + Version: mustVersion(t, "2.0.0"), + }, + { + Name: mustName(t, "ns/name3"), + Version: mustVersion(t, "3.0.0"), + }, + }, + } + constraints := ConstraintMap{ + mustName(t, "ns/name"): mustConstraint(t, ">1.0.0"), + mustName(t, "ns/name2"): mustConstraint(t, "<=2.0.0"), + mustName(t, "ns/name4"): mustConstraint(t, "3.0.0"), + } + result := locks.Check(constraints) + assert.True(t, result.IsInstallRequired(), "expected install required") + assert.Len(t, result.Missing, 1, "expected 1 missing") + assert.Len(t, result.Mismatch, 1, "expected 1 mismatch") + assert.Len(t, result.Removed, 1, "expected 1 removed") +} + +func TestReadLockFile(t *testing.T) { + buf := bytes.NewBufferString(`{ + "plugins": [ + { + "name": "ns/name", + "version": "1.0.0", + "checksums": ["binary:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="] + } + ] + }`) + locks, err := ReadLockFile(buf) + assert.NoError(t, err) + assert.Len(t, locks.Plugins, 1) + assert.Equal(t, "ns/name", locks.Plugins[0].Name.String()) + assert.Equal(t, "1.0.0", locks.Plugins[0].Version.String()) + assert.Len(t, locks.Plugins[0].Checksums, 1) + assert.Equal(t, "binary", locks.Plugins[0].Checksums[0].Object) + assert.Equal(t, "darwin", locks.Plugins[0].Checksums[0].OS) + assert.Equal(t, "arm64", locks.Plugins[0].Checksums[0].Arch) + assert.Equal(t, "lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8=", base64.StdEncoding.EncodeToString(locks.Plugins[0].Checksums[0].Sum)) +} + +func TestSaveLockFile(t *testing.T) { + locks := &LockFile{ + Plugins: []PluginLock{ + { + Name: mustName(t, "ns/name"), + Version: mustVersion(t, "1.0.0"), + Checksums: []Checksum{ + { + Object: "binary", + OS: "darwin", + Arch: "arm64", + Sum: mustBase64(t, "lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, + }, + }, + } + buf := bytes.NewBuffer(nil) + err := SaveLockFile(buf, locks) + assert.NoError(t, err) + assert.JSONEq(t, `{ + "plugins": [ + { + "name": "ns/name", + "version": "1.0.0", + "checksums": ["binary:darwin:arm64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="] + } + ] + }`, buf.String()) +} diff --git a/plugin/resolver/mock_source_test.go b/plugin/resolver/mock_source_test.go new file mode 100644 index 00000000..bb7839df --- /dev/null +++ b/plugin/resolver/mock_source_test.go @@ -0,0 +1,156 @@ +// Code generated by mockery v2.42.0. DO NOT EDIT. 
+ +package resolver + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// mockSource is an autogenerated mock type for the Source type +type mockSource struct { + mock.Mock +} + +type mockSource_Expecter struct { + mock *mock.Mock +} + +func (_m *mockSource) EXPECT() *mockSource_Expecter { + return &mockSource_Expecter{mock: &_m.Mock} +} + +// Lookup provides a mock function with given fields: ctx, name +func (_m *mockSource) Lookup(ctx context.Context, name Name) ([]Version, error) { + ret := _m.Called(ctx, name) + + if len(ret) == 0 { + panic("no return value specified for Lookup") + } + + var r0 []Version + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, Name) ([]Version, error)); ok { + return rf(ctx, name) + } + if rf, ok := ret.Get(0).(func(context.Context, Name) []Version); ok { + r0 = rf(ctx, name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]Version) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, Name) error); ok { + r1 = rf(ctx, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// mockSource_Lookup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Lookup' +type mockSource_Lookup_Call struct { + *mock.Call +} + +// Lookup is a helper method to define mock.On call +// - ctx context.Context +// - name Name +func (_e *mockSource_Expecter) Lookup(ctx interface{}, name interface{}) *mockSource_Lookup_Call { + return &mockSource_Lookup_Call{Call: _e.mock.On("Lookup", ctx, name)} +} + +func (_c *mockSource_Lookup_Call) Run(run func(ctx context.Context, name Name)) *mockSource_Lookup_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(Name)) + }) + return _c +} + +func (_c *mockSource_Lookup_Call) Return(_a0 []Version, _a1 error) *mockSource_Lookup_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *mockSource_Lookup_Call) RunAndReturn(run func(context.Context, Name) ([]Version, error)) *mockSource_Lookup_Call { + _c.Call.Return(run) + return _c +} + +// Resolve provides a mock function with given fields: ctx, name, version, checksums +func (_m *mockSource) Resolve(ctx context.Context, name Name, version Version, checksums []Checksum) (*ResolvedPlugin, error) { + ret := _m.Called(ctx, name, version, checksums) + + if len(ret) == 0 { + panic("no return value specified for Resolve") + } + + var r0 *ResolvedPlugin + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, Name, Version, []Checksum) (*ResolvedPlugin, error)); ok { + return rf(ctx, name, version, checksums) + } + if rf, ok := ret.Get(0).(func(context.Context, Name, Version, []Checksum) *ResolvedPlugin); ok { + r0 = rf(ctx, name, version, checksums) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*ResolvedPlugin) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, Name, Version, []Checksum) error); ok { + r1 = rf(ctx, name, version, checksums) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// mockSource_Resolve_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Resolve' +type mockSource_Resolve_Call struct { + *mock.Call +} + +// Resolve is a helper method to define mock.On call +// - ctx context.Context +// - name Name +// - version Version +// - checksums []Checksum +func (_e *mockSource_Expecter) Resolve(ctx interface{}, name interface{}, version interface{}, checksums interface{}) *mockSource_Resolve_Call { + return &mockSource_Resolve_Call{Call: _e.mock.On("Resolve", ctx, name, 
version, checksums)} + +func (_c *mockSource_Resolve_Call) Run(run func(ctx context.Context, name Name, version Version, checksums []Checksum)) *mockSource_Resolve_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(Name), args[2].(Version), args[3].([]Checksum)) + }) + return _c +} + +func (_c *mockSource_Resolve_Call) Return(_a0 *ResolvedPlugin, _a1 error) *mockSource_Resolve_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *mockSource_Resolve_Call) RunAndReturn(run func(context.Context, Name, Version, []Checksum) (*ResolvedPlugin, error)) *mockSource_Resolve_Call { + _c.Call.Return(run) + return _c +} + +// newMockSource creates a new instance of mockSource. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockSource(t interface { + mock.TestingT + Cleanup(func()) +}) *mockSource { + mock := &mockSource{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/plugin/resolver/name.go b/plugin/resolver/name.go new file mode 100644 index 00000000..1d856abe --- /dev/null +++ b/plugin/resolver/name.go @@ -0,0 +1,67 @@ +package resolver + +import ( + "fmt" + "strconv" + "strings" +) + +const ( + namespaceIdx = 0 + shortNameIdx = 1 +) + +// Name of a plugin, structured as '<namespace>/<short-name>'. +type Name [2]string + +// Namespace returns the namespace part of the plugin name. +func (n Name) Namespace() string { + return n[namespaceIdx] +} + +// Short returns the short name part of the plugin name. +func (n Name) Short() string { + return n[shortNameIdx] +} + +// String returns the full plugin name in the form '<namespace>/<short-name>'. +func (n Name) String() string { + return fmt.Sprintf("%s/%s", n[namespaceIdx], n[shortNameIdx]) +} + +// Compare compares the plugin name with another plugin name. +func (n Name) Compare(other Name) int { + cmp := strings.Compare(n.Namespace(), other.Namespace()) + if cmp != 0 { + return cmp + } + return strings.Compare(n.Short(), other.Short()) +} + +// MarshalJSON returns the JSON representation of the plugin name in '<namespace>/<short-name>' format. +func (n Name) MarshalJSON() ([]byte, error) { + return []byte(strconv.Quote(n.String())), nil +} + +// UnmarshalJSON parses a JSON string into a Name from '<namespace>/<short-name>' format. +func (n *Name) UnmarshalJSON(data []byte) error { + raw, err := strconv.Unquote(string(data)) + if err != nil { + return fmt.Errorf("failed to unquote plugin name: %w", err) + } + name, err := ParseName(raw) + if err != nil { + return err + } + *n = Name(name) + return nil +} + +// ParseName parses a plugin name from '<namespace>/<short-name>' format.
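+// For example, ParseName("blackstork/sqlite") yields Name{"blackstork", "sqlite"}; both parts must be non-empty.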
+func ParseName(name string) (Name, error) { + parts := strings.Split(name, "/") + if len(parts) != 2 || len(parts[namespaceIdx]) == 0 || len(parts[shortNameIdx]) == 0 { + return [2]string{}, fmt.Errorf("plugin name '%s' is not in the form '<namespace>/<short-name>'", name) + } + return Name(parts), nil +} diff --git a/plugin/resolver/name_test.go b/plugin/resolver/name_test.go new file mode 100644 index 00000000..3ad7fc54 --- /dev/null +++ b/plugin/resolver/name_test.go @@ -0,0 +1,148 @@ +package resolver + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func mustName(t *testing.T, str string) Name { + t.Helper() + name, err := ParseName(str) + assert.NoError(t, err) + return name +} + +func TestName_String(t *testing.T) { + name := Name{"namespace", "short"} + assert.Equal(t, "namespace/short", name.String()) + assert.Equal(t, "namespace", name.Namespace()) + assert.Equal(t, "short", name.Short()) +} + +func TestName_Compare(t *testing.T) { + tests := []struct { + name string + name1 Name + name2 Name + want int + }{ + { + name: "equal", + name1: Name{"namespace", "short"}, + name2: Name{"namespace", "short"}, + want: 0, + }, + { + name: "namespace", + name1: Name{"namespace1", "short"}, + name2: Name{"namespace2", "short"}, + want: -1, + }, + { + name: "short", + name1: Name{"namespace", "short1"}, + name2: Name{"namespace", "short2"}, + want: -1, + }, + { + name: "both", + name1: Name{"namespace1", "short1"}, + name2: Name{"namespace2", "short2"}, + want: -1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.name1.Compare(tt.name2); got != tt.want { + t.Errorf("Name.Compare() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestName_JSON(t *testing.T) { + name := Name{"ns", "name"} + data, err := name.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, `"ns/name"`, string(data)) + + var name2 Name + err = name2.UnmarshalJSON(data) + assert.NoError(t, err) + assert.Equal(t, name, name2) +} + +func TestParseName(t *testing.T) { + type args struct { + name string + } + tests := []struct { + name string + args args + want Name + wantErr bool + }{ + { + name: "simple", + args: args{ + name: "namespace/short", + }, + want: Name{"namespace", "short"}, + wantErr: false, + }, + { + name: "empty", + args: args{ + name: "", + }, + want: Name{}, + wantErr: true, + }, + { + name: "no_slash", + args: args{ + name: "namespace", + }, + want: Name{}, + wantErr: true, + }, + { + name: "no_short", + args: args{ + name: "namespace/", + }, + want: Name{}, + wantErr: true, + }, + { + name: "no_namespace", + args: args{ + name: "/short", + }, + want: Name{}, + wantErr: true, + }, + { + name: "only_slash", + args: args{ + name: "/", + }, + want: Name{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseName(tt.args.name) + if (err != nil) != tt.wantErr { + t.Errorf("ParseName() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseName() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/plugin/resolver/options.go b/plugin/resolver/options.go new file mode 100644 index 00000000..7bbe506e --- /dev/null +++ b/plugin/resolver/options.go @@ -0,0 +1,34 @@ +package resolver + +import ( + "io" + "log/slog" +) + +// options for the resolver.
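+// Options are applied with the functional-options pattern, e.g. NewResolver(versions, WithLogger(logger), WithSources(LocalSource{Path: dir})); the names 'versions', 'logger', and 'dir' are illustrative.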
+type options struct { + logger *slog.Logger + sources []Source +} + +var defaultOptions = options{ + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + sources: []Source{}, +} + +// Option is a functional option for the resolver. +type Option func(*options) + +// WithLogger sets the logger for the resolver. +func WithLogger(logger *slog.Logger) Option { + return func(o *options) { + o.logger = logger + } +} + +// WithSources sets the sources for the resolver. +func WithSources(sources ...Source) Option { + return func(o *options) { + o.sources = sources + } +} diff --git a/plugin/resolver/resolver.go b/plugin/resolver/resolver.go new file mode 100644 index 00000000..3b3e0625 --- /dev/null +++ b/plugin/resolver/resolver.go @@ -0,0 +1,226 @@ +package resolver + +import ( + "context" + "fmt" + "maps" + "slices" + + "github.com/hashicorp/hcl/v2" +) + +// Resolver resolves and installs plugins. +type Resolver struct { + constraints ConstraintMap + options +} + +// NewResolver creates a new plugin resolver. +func NewResolver(constraints map[string]string, opts ...Option) (*Resolver, hcl.Diagnostics) { + parsedVersions, err := ParseConstraintMap(constraints) + if err != nil { + return nil, hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: "Failed to parse plugin versions", + Detail: err.Error(), + }} + } + res := &Resolver{ + constraints: parsedVersions, + options: defaultOptions, + } + for _, opt := range opts { + opt(&res.options) + } + return res, nil +} + +// Install all plugins based on the version constraints and return an updated lock file. +func (r *Resolver) Install(ctx context.Context, lockFile *LockFile, upgrade bool) (*LockFile, hcl.Diagnostics) { + check := lockFile.Check(r.constraints) + locks := []PluginLock{} + // lookupMap is a map of plugins that we look up based on the constraints + lookupMap := make(ConstraintMap) + if upgrade { + // if upgrade is enabled we install all plugins based on the constraints + maps.Copy(lookupMap, r.constraints) + } else { + // if upgrade is disabled we only install the missing and mismatched plugins based on the constraints + maps.Copy(lookupMap, check.Missing) + maps.Copy(lookupMap, check.Mismatch) + } + chain := makeSourceChain(r.sources...)
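+ // The chain queries the sources in order: Lookup results are merged and sorted, while Resolve returns the first source that succeeds (see makeSourceChain).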
+ // resolve the plugins by the latest version that matches the constraints + for name, constraint := range lookupMap { + r.logger.Info("Searching plugin", "name", name.String(), "constraints", constraint.String()) + list, err := chain.Lookup(ctx, name) + if err != nil { + return nil, hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: fmt.Sprintf("Failed to lookup plugin '%s'", name), + Detail: err.Error(), + }} + } + if len(list) == 0 { + return nil, hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: fmt.Sprintf("Plugin '%s' not found", name), + Detail: "Could not find a version for the current platform", + }} + } + // filter out the versions that do not match the constraint + matches := slices.DeleteFunc(list, func(v Version) bool { + return !constraint.Check(v.Version) + }) + if len(matches) == 0 { + return nil, hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: fmt.Sprintf("Plugin '%s' not found", name), + Detail: fmt.Sprintf("No version of '%s' matches the constraint '%s'", name, constraint), + }} + } + max := slices.MaxFunc(matches, func(a, b Version) int { + return a.Compare(b) + }) + r.logger.Info("Installing plugin", "name", name.String(), "version", max.String()) + var checksums []Checksum + // check if the plugin with the same version is already in the lock file + lockIdx := slices.IndexFunc(lockFile.Plugins, func(lock PluginLock) bool { + return lock.Name == name && lock.Version.Equal(max.Version) + }) + if lockIdx > -1 { + // use the checksums from the lock file + checksums = lockFile.Plugins[lockIdx].Checksums + } + res, err := chain.Resolve(ctx, name, max, checksums) + if err != nil { + return nil, hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: fmt.Sprintf("Failed to resolve plugin '%s@%s'", name, max), + Detail: err.Error(), + }} + } + // sort the checksums + slices.SortFunc(res.Checksums, func(a, b Checksum) int { + return a.Compare(b) + }) + locks = append(locks, PluginLock{ + Name: name, + Version: max, + Checksums: res.Checksums, + }) + // check if context is cancelled + if ctx.Err() != nil { + return nil, hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: "Cancelled by context", + Detail: ctx.Err().Error(), + }} + } + } + // resolve the rest of the plugins based on their locked versions + for _, lock := range lockFile.Plugins { + // skip plugins that are already resolved + if _, ok := lookupMap[lock.Name]; ok { + continue + } + // skip plugins that are removed from the version constraints + if _, ok := check.Removed[lock.Name]; ok { + continue + } + r.logger.Info("Installing plugin", "name", lock.Name.String(), "version", lock.Version.String()) + _, err := chain.Resolve(ctx, lock.Name, lock.Version, lock.Checksums) + if err != nil { + return nil, hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: fmt.Sprintf("Failed to resolve plugin '%s@%s'", lock.Name, lock.Version), + Detail: err.Error(), + }} + } + locks = append(locks, lock) + // check if context is cancelled + if ctx.Err() != nil { + return nil, hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: "Cancelled by context", + Detail: ctx.Err().Error(), + }} + } + } + // sort plugin locks by name + slices.SortFunc(locks, func(a, b PluginLock) int { + return a.Name.Compare(b.Name) + }) + return &LockFile{ + Plugins: locks, + }, nil +} + +// Resolve all plugins based on the constraints and return a map of plugin names to binary paths. +// If the lock file is not satisfied, an error is returned.
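+// A typical call site looks like the following sketch (r, ctx, and lockFile assumed in scope; the path is illustrative): +// +// binaries, diags := r.Resolve(ctx, lockFile) +// // binaries["blackstork/sqlite"] == "/plugins/blackstork/sqlite@1.0.2"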
+func (r *Resolver) Resolve(ctx context.Context, lockFile *LockFile) (map[string]string, hcl.Diagnostics) { + var diags hcl.Diagnostics + // check if the lock file is satisfied by version constraints + check := lockFile.Check(r.constraints) + for name, lock := range check.Removed { + // warn about plugins that are removed from the version constraints + diags = diags.Extend(hcl.Diagnostics{{ + Severity: hcl.DiagWarning, + Summary: fmt.Sprintf("Plugin '%s' is not used", name), + Detail: fmt.Sprintf("Version '%s' is no longer used. Run install to update lock file", lock), + }}) + } + if check.IsInstallRequired() { + // error out about missing & mismatched plugins + for name := range check.Missing { + diags = diags.Extend(hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: fmt.Sprintf("Plugin '%s' is not locked", name), + Detail: "Run install to resolve missing plugins.", + }}) + } + for name, constraint := range check.Mismatch { + pluginIdx := slices.IndexFunc(lockFile.Plugins, func(lock PluginLock) bool { + return lock.Name == name + }) + if pluginIdx == -1 { + continue + } + detailFormat := "Version locked at '%s' does not match the new constraint '%s'\nRun install to update lock file." + diags = diags.Extend(hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: fmt.Sprintf("Plugin '%s' version mismatch", name), + Detail: fmt.Sprintf(detailFormat, lockFile.Plugins[pluginIdx].Version, constraint), + }}) + } + return nil, diags + } + // chain the sources together + chain := makeSourceChain(r.sources...) + // resolve the plugins + binaryMap := make(map[string]string) + for _, lock := range lockFile.Plugins { + // skip plugins that are removed from the version constraints + if _, ok := check.Removed[lock.Name]; ok { + continue + } + plugin, err := chain.Resolve(ctx, lock.Name, lock.Version, lock.Checksums) + if err != nil { + return nil, diags.Extend(hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: fmt.Sprintf("Failed to resolve plugin '%s@%s'", lock.Name, lock.Version), + Detail: err.Error(), + }}) + } + binaryMap[lock.Name.String()] = plugin.BinaryPath + // check if context is cancelled + if ctx.Err() != nil { + return nil, diags.Extend(hcl.Diagnostics{{ + Severity: hcl.DiagError, + Summary: "Context cancelled", + Detail: ctx.Err().Error(), + }}) + } + } + return binaryMap, diags +} diff --git a/plugin/resolver/resolver_test.go b/plugin/resolver/resolver_test.go new file mode 100644 index 00000000..41982844 --- /dev/null +++ b/plugin/resolver/resolver_test.go @@ -0,0 +1,183 @@ +package resolver + +import ( + "context" + "testing" + + "github.com/hashicorp/hcl/v2" + "github.com/stretchr/testify/require" +) + +func TestResolver_Install(t *testing.T) { + source := newMockSource(t) + resolver, diags := NewResolver(map[string]string{ + "blackstork/sqlite": ">= 1.0 < 2.0", + }, WithSources(source)) + require.Len(t, diags, 0) + require.NotNil(t, resolver) + source.EXPECT().Lookup(context.Background(), Name{"blackstork", "sqlite"}).Return([]Version{ + mustVersion(t, "1.0.0"), + mustVersion(t, "1.0.1"), + mustVersion(t, "1.0.2"), + }, nil) + source.EXPECT().Resolve(context.Background(), Name{"blackstork", "sqlite"}, mustVersion(t, "1.0.2"), []Checksum(nil)).Return(&ResolvedPlugin{ + BinaryPath: "/blackstork/sqlite/1.0.2", + Checksums: []Checksum{ + mustChecksum(t, "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, nil) + lockFile, diags := 
resolver.Install(context.Background(), &LockFile{}, false) + require.Len(t, diags, 0) + require.Equal(t, &LockFile{ + Plugins: []PluginLock{ + { + Name: Name{"blackstork", "sqlite"}, + Version: mustVersion(t, "1.0.2"), + Checksums: []Checksum{ + mustChecksum(t, "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, + }, + }, lockFile) + source.AssertExpectations(t) + // again with checksums + source.EXPECT().Lookup(context.Background(), Name{"blackstork", "sqlite"}).Return([]Version{ + mustVersion(t, "1.0.0"), + mustVersion(t, "1.0.1"), + mustVersion(t, "1.0.2"), + }, nil) + source.EXPECT().Resolve(context.Background(), Name{"blackstork", "sqlite"}, mustVersion(t, "1.0.2"), lockFile.Plugins[0].Checksums).Return(&ResolvedPlugin{ + BinaryPath: "/blackstork/sqlite/1.0.2", + Checksums: []Checksum{ + mustChecksum(t, "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, nil) + lockFile, diags = resolver.Install(context.Background(), lockFile, false) + require.Len(t, diags, 0) + require.Equal(t, &LockFile{ + Plugins: []PluginLock{ + { + Name: Name{"blackstork", "sqlite"}, + Version: mustVersion(t, "1.0.2"), + Checksums: []Checksum{ + mustChecksum(t, "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, + }, + }, lockFile) + source.AssertExpectations(t) +} + +func TestResolver_InstallError(t *testing.T) { + source := newMockSource(t) + resolver, diags := NewResolver(map[string]string{ + "blackstork/sqlite": ">= 1.0 < 2.0", + }, WithSources(source)) + require.Len(t, diags, 0) + require.NotNil(t, resolver) + // missing plugin + source.EXPECT().Lookup(context.Background(), Name{"blackstork", "sqlite"}).Return([]Version{}, nil) + lockFile, diags := resolver.Install(context.Background(), &LockFile{}, false) + require.Len(t, diags, 1) + require.Nil(t, lockFile) + source.AssertExpectations(t) +} + +func TestResolver_Resolve(t *testing.T) { + source := newMockSource(t) + resolver, diags := NewResolver(map[string]string{ + "blackstork/sqlite": ">= 1.0 < 2.0", + }, WithSources(source)) + require.Len(t, diags, 0) + require.NotNil(t, resolver) + source.EXPECT().Resolve(context.Background(), Name{"blackstork", "sqlite"}, mustVersion(t, "1.0.2"), []Checksum{ + mustChecksum(t, "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }).Return(&ResolvedPlugin{ + BinaryPath: "/blackstork/sqlite@1.0.2", + Checksums: []Checksum{ + mustChecksum(t, "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, nil) + binMap, diags := resolver.Resolve(context.Background(), &LockFile{ + Plugins: []PluginLock{ + { + Name: Name{"blackstork", "sqlite"}, + Version: mustVersion(t, "1.0.2"), + Checksums: []Checksum{ + mustChecksum(t, "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, + }, + }) + require.Len(t, diags, 0) + require.Equal(t, map[string]string{ + "blackstork/sqlite": "/blackstork/sqlite@1.0.2", + }, binMap) + source.AssertExpectations(t) +} + +func 
TestResolver_ResolveBadLockFile(t *testing.T) { + source := newMockSource(t) + resolver, diags := NewResolver(map[string]string{ + "blackstork/sqlite": ">= 1.0 < 2.0", + }, WithSources(source)) + require.Len(t, diags, 0) + require.NotNil(t, resolver) + // missing plugin + binMap, diags := resolver.Resolve(context.Background(), &LockFile{}) + require.Len(t, diags, 1) + require.Nil(t, binMap) + source.AssertExpectations(t) +} + +func TestResolver_ResolveMissmatchLockFile(t *testing.T) { + source := newMockSource(t) + resolver, diags := NewResolver(map[string]string{ + "blackstork/sqlite": ">= 1.0 < 2.0", + }, WithSources(source)) + require.Len(t, diags, 0) + require.NotNil(t, resolver) + binMap, diags := resolver.Resolve(context.Background(), &LockFile{ + Plugins: []PluginLock{ + { + Name: Name{"blackstork", "sqlite"}, + Version: mustVersion(t, "3.0.0"), + Checksums: []Checksum{ + mustChecksum(t, "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, + }, + }) + require.Len(t, diags, 1) + require.Nil(t, binMap) +} + +func TestResolver_ResolveWarn(t *testing.T) { + source := newMockSource(t) + resolver, diags := NewResolver(map[string]string{}, WithSources(source)) + require.Len(t, diags, 0) + require.NotNil(t, resolver) + binMap, diags := resolver.Resolve(context.Background(), &LockFile{ + Plugins: []PluginLock{ + { + Name: Name{"blackstork", "sqlite"}, + Version: mustVersion(t, "1.0.2"), + Checksums: []Checksum{ + mustChecksum(t, "archive:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + mustChecksum(t, "binary:linux:amd64:lgNgp5LO81yt1boBsiaNsJCzLWD9r5ovW+el5k/dDZ8="), + }, + }, + }, + }) + require.Len(t, diags, 1) + require.Equal(t, hcl.DiagWarning, diags[0].Severity) + require.Empty(t, binMap) +} diff --git a/plugin/resolver/source.go b/plugin/resolver/source.go new file mode 100644 index 00000000..41c309ed --- /dev/null +++ b/plugin/resolver/source.go @@ -0,0 +1,70 @@ +package resolver + +import ( + "context" + "fmt" + "slices" +) + +// ErrPluginNotFound is returned when a plugin is not found in the source. +var ErrPluginNotFound = fmt.Errorf("plugin not found") + +// ResolvedPlugin contains the binary path and checksum for a plugin. +type ResolvedPlugin struct { + // BinaryPath for the current platform. + BinaryPath string + // Checksums is a list of checksums for the plugin including all supported platforms. + Checksums []Checksum +} + +// Source is the interface for plugin sources. +// A source is responsible for listing, looking up, and resolving plugins. +// The source may use a local directory, a registry, or any other source. +// If the source is unable to find a plugin, it should return ErrPluginNotFound. +type Source interface { + // Lookup returns a list of available versions for the given plugin. + Lookup(ctx context.Context, name Name) ([]Version, error) + // Resolve returns the binary path and checksums for the given plugin version. + Resolve(ctx context.Context, name Name, version Version, checksums []Checksum) (*ResolvedPlugin, error) +} + +// makeSourceChain returns a source that chains the given sources together. +// When looking up a plugin, the sources are queried in order and the results are concatenated and sorted. +// When resolving a plugin, the sources are queried in order and the first result is returned. +// If a source returns an error other than ErrPluginNotFound, the chain is interrupted and the error is returned. 
+// If all sources return ErrPluginNotFound, then ErrPluginNotFound is returned. +func makeSourceChain(sources ...Source) Source { + return &sourceChain{sources} +} + +type sourceChain struct { + sources []Source +} + +func (source *sourceChain) Lookup(ctx context.Context, name Name) ([]Version, error) { + var matches []Version + for _, s := range source.sources { + found, err := s.Lookup(ctx, name) + if err != nil { + return nil, err + } + matches = append(matches, found...) + } + slices.SortFunc(matches, func(a, b Version) int { + return a.Compare(b) + }) + // Version wraps *semver.Version, so plain equality compares pointers; compare by value to deduplicate versions reported by multiple sources. + return slices.CompactFunc(matches, func(a, b Version) bool { + return a.Compare(b) == 0 + }), nil +} + +func (source *sourceChain) Resolve(ctx context.Context, name Name, version Version, checksums []Checksum) (*ResolvedPlugin, error) { + for _, s := range source.sources { + info, err := s.Resolve(ctx, name, version, checksums) + if err == nil { + return info, nil + } + if err != ErrPluginNotFound { + return nil, err + } + } + return nil, ErrPluginNotFound +} diff --git a/plugin/resolver/source_local.go b/plugin/resolver/source_local.go new file mode 100644 index 00000000..76c6f061 --- /dev/null +++ b/plugin/resolver/source_local.go @@ -0,0 +1,131 @@ +package resolver + +import ( + "context" + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "strings" + + "github.com/Masterminds/semver/v3" +) + +// LocalSource is a plugin source that looks up plugins from a local directory. +// The directory structure should be: +// +// "<namespace>/<short-name>@<version>" +// +// For example, with the path ".fabric/plugins", plugin name "blackstork/sqlite", and version "1.0.0": +// +// ".fabric/plugins/blackstork/sqlite@1.0.0" +// +// File checksums can be provided in a file with the same name as the plugin binary but with a "_checksums.txt" suffix. +// The file should contain a list of checksums for all supported platforms. +type LocalSource struct { + // Path is the root directory to look up plugins. + Path string +} + +// Lookup returns the versions found of the plugin with the given name. +func (source LocalSource) Lookup(ctx context.Context, name Name) ([]Version, error) { + if source.Path == "" { + return nil, fmt.Errorf("no path provided for local source") + } + pluginDir := filepath.Join(source.Path, name.Namespace()) + entries, err := os.ReadDir(pluginDir) + if os.IsNotExist(err) { + return nil, nil + } else if err != nil { + return nil, fmt.Errorf("failed to read plugins from local dir '%s': %w", source.Path, err) + } + var matches []Version + for _, entry := range entries { + if entry.IsDir() { + continue + } + parts := strings.Split(entry.Name(), "@") + if len(parts) != 2 { + continue + } + if parts[0] != name.Short() { + continue + } + parts[1] = strings.TrimSuffix(parts[1], ".exe") + version, err := semver.NewVersion(parts[1]) + if err != nil { + continue + } + matches = append(matches, Version{version}) + } + return matches, nil +} + +// Resolve returns the binary path and checksum for the given plugin version.
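+// Binaries may carry an '.exe' suffix, so the same directory layout works on Windows.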
+func (source LocalSource) Resolve(ctx context.Context, name Name, version Version, checksums []Checksum) (*ResolvedPlugin, error) { + pluginDir := filepath.Join(source.Path, name.Namespace()) + pluginPath := filepath.Join(pluginDir, fmt.Sprintf("%s@%s", name.Short(), version.String())) + checksumPath := pluginPath + "_checksums.txt" + info, err := os.Stat(pluginPath) + if os.IsNotExist(err) { + info, err = os.Stat(pluginPath + ".exe") + if os.IsNotExist(err) { + return nil, ErrPluginNotFound + } else if err != nil { + return nil, fmt.Errorf("failed to stat plugin file: %w", err) + } + pluginPath += ".exe" + } + if info.IsDir() { + return nil, fmt.Errorf("plugin file is a directory") + } + // calculate the checksum of the plugin binary + h := sha256.New() + file, err := os.Open(pluginPath) + if err != nil { + return nil, fmt.Errorf("failed to open plugin file: %w", err) + } + defer file.Close() + if _, err := io.Copy(h, file); err != nil { + return nil, fmt.Errorf("failed to hash plugin file: %w", err) + } + checksum := Checksum{ + Object: "binary", + OS: runtime.GOOS, + Arch: runtime.GOARCH, + Sum: h.Sum(nil), + } + // If the checksums are not provided, then we assume the checksums are the same as the binary. + if len(checksums) == 0 { + // If the checksums metadata file exists then we use the checksums from the file. + // This file is created by RemoteSource when downloading plugins. + // This is useful when the checksums are different for different platforms. + if f, err := os.Open(checksumPath); err == nil { + defer f.Close() + checksums, err = decodeChecksums(f) + if err != nil { + return nil, fmt.Errorf("failed to decode checksums from local source: %w", err) + } + // Additionally, we check that the checksums match the binary. + if !checksum.Match(checksums) { + return nil, fmt.Errorf("invalid plugin binary checksum: '%s'", checksum) + } + } else if os.IsNotExist(err) { + // If the checksums file does not exist, then we assume the checksums are the same as the binary. + checksums = []Checksum{checksum} + } else { + // If there is an error opening the checksums file, then we return the error. + // This is useful for debugging. + return nil, fmt.Errorf("failed to open checksums file at local source: %w", err) + } + } else if !checksum.Match(checksums) { + // If the checksums are provided, then we check that the checksums match the binary. 
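+ // A mismatch here means the binary on disk differs from the one recorded in the lock file.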
+ return nil, fmt.Errorf("invalid plugin binary checksum: '%s'", checksum) + } + return &ResolvedPlugin{ + BinaryPath: pluginPath, + Checksums: checksums, + }, nil +} diff --git a/plugin/resolver/source_local_test.go b/plugin/resolver/source_local_test.go new file mode 100644 index 00000000..5208c93d --- /dev/null +++ b/plugin/resolver/source_local_test.go @@ -0,0 +1,363 @@ +package resolver + +import ( + "bytes" + "context" + "os" + "path/filepath" + "reflect" + "runtime" + "testing" + "text/template" + + "github.com/stretchr/testify/require" +) + +type mockFile struct { + path string + content string + isDir bool +} + +func mockFileDir(t *testing.T, files []mockFile) string { + t.Helper() + tmpDir := t.TempDir() + for _, file := range files { + if file.isDir { + err := os.MkdirAll(filepath.Join(tmpDir, file.path), 0755) + require.NoError(t, err) + continue + } + err := os.MkdirAll(filepath.Dir(filepath.Join(tmpDir, file.path)), 0755) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, file.path), []byte(file.content), 0644) + require.NoError(t, err) + } + return tmpDir +} + +func TestLocalSource_Lookup(t *testing.T) { + type fields struct { + Path string + } + type args struct { + ctx context.Context + name Name + } + tests := []struct { + name string + fields fields + args args + want []Version + wantErr bool + }{ + { + name: "no path", + fields: fields{ + Path: "", + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + }, + want: nil, + wantErr: true, + }, + { + name: "no version", + fields: fields{ + Path: t.TempDir(), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + }, + want: nil, + wantErr: false, + }, + { + name: "one version", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0", content: "plugin"}, + }), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + }, + want: []Version{mustVersion(t, "1.0.0")}, + wantErr: false, + }, + { + name: "one version with exe", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0.exe", content: "plugin"}, + }), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + }, + want: []Version{mustVersion(t, "1.0.0")}, + wantErr: false, + }, + { + name: "multiple version", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0", content: "plugin"}, + {path: "blackstork/sqlite@1.0.1", content: "plugin"}, + }), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + }, + want: []Version{mustVersion(t, "1.0.0"), mustVersion(t, "1.0.1")}, + wantErr: false, + }, + { + name: "multiple version with exe", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0.exe", content: "plugin"}, + {path: "blackstork/sqlite@1.0.1.exe", content: "plugin"}, + }), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + }, + want: []Version{mustVersion(t, "1.0.0"), mustVersion(t, "1.0.1")}, + wantErr: false, + }, + { + name: "skip invalid version", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0", content: "plugin"}, + {path: "blackstork/sqlite@invalid", content: "plugin"}, + {path: "blackstork/@", content: "plugin"}, + {path: "blackstork/@1.0.0", content: "plugin"}, + {path: "blackstork/sqlite@", content: "plugin"}, + {path: "blackstork/@@", content: "plugin"}, + }), + }, + args: args{ + 
ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + }, + want: []Version{mustVersion(t, "1.0.0")}, + wantErr: false, + }, + { + name: "skip non-matching name", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0", content: "plugin"}, + {path: "blackstork/other@1.0.1", content: "plugin"}, + }), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + }, + want: []Version{mustVersion(t, "1.0.0")}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + source := LocalSource{ + Path: tt.fields.Path, + } + got, err := source.Lookup(tt.args.ctx, tt.args.name) + if (err != nil) != tt.wantErr { + t.Errorf("LocalSource.Lookup() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("LocalSource.Lookup() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLocalSource_Resolve(t *testing.T) { + type fields struct { + Path string + } + type args struct { + ctx context.Context + name Name + version Version + checksums []Checksum + } + tests := []struct { + name string + fields fields + args args + want *ResolvedPlugin + wantErr bool + }{ + { + name: "no path", + fields: fields{ + Path: "", + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + version: mustVersion(t, "1.0.0"), + checksums: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "no version", + fields: fields{ + Path: t.TempDir(), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + version: mustVersion(t, "1.0.0"), + checksums: nil, + }, + want: nil, + wantErr: true, + }, + { + name: "just binary", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0", content: "plugin"}, + }), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + version: mustVersion(t, "1.0.0"), + checksums: nil, + }, + want: &ResolvedPlugin{ + BinaryPath: "{{.tempDir}}/blackstork/sqlite@1.0.0", + Checksums: []Checksum{mustChecksum(t, "binary:"+runtime.GOOS+":"+runtime.GOARCH+":XmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA=")}, + }, + wantErr: false, + }, + { + name: "just binary with exe", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0.exe", content: "plugin"}, + }), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + version: mustVersion(t, "1.0.0"), + checksums: nil, + }, + want: &ResolvedPlugin{ + BinaryPath: "{{.tempDir}}/blackstork/sqlite@1.0.0.exe", + Checksums: []Checksum{mustChecksum(t, "binary:"+runtime.GOOS+":"+runtime.GOARCH+":XmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA=")}, + }, + wantErr: false, + }, + { + name: "binary and checksum", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0", content: "plugin"}, + { + path: "blackstork/sqlite@1.0.0_checksums.txt", + content: "archive:darwin:amd64:XmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA=\n" + + "archive:linux:amd64:YmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA=\n" + + "archive:windows:amd64:ZmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA=\n" + + "binary:" + runtime.GOOS + ":" + runtime.GOARCH + ":XmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA=", + }, + }), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + version: mustVersion(t, "1.0.0"), + checksums: nil, + }, + want: &ResolvedPlugin{ + BinaryPath: "{{.tempDir}}/blackstork/sqlite@1.0.0", + Checksums: 
[]Checksum{ + mustChecksum(t, "archive:darwin:amd64:XmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA="), + mustChecksum(t, "archive:linux:amd64:YmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA="), + mustChecksum(t, "archive:windows:amd64:ZmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA="), + mustChecksum(t, "binary:"+runtime.GOOS+":"+runtime.GOARCH+":XmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA="), + }, + }, + wantErr: false, + }, + { + name: "binary checksum does not match with input", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0", content: "plugin"}, + }), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + version: mustVersion(t, "1.0.0"), + checksums: []Checksum{mustChecksum(t, "binary:"+runtime.GOOS+":"+runtime.GOARCH+":XmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj1vvUWA=")}, + }, + want: nil, + wantErr: true, + }, + { + name: "binary checksum does match with input", + fields: fields{ + Path: mockFileDir(t, []mockFile{ + {path: "blackstork/sqlite@1.0.0", content: "plugin"}, + }), + }, + args: args{ + ctx: context.Background(), + name: Name{"blackstork", "sqlite"}, + version: mustVersion(t, "1.0.0"), + checksums: []Checksum{mustChecksum(t, "binary:"+runtime.GOOS+":"+runtime.GOARCH+":XmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA=")}, + }, + want: &ResolvedPlugin{ + BinaryPath: "{{.tempDir}}/blackstork/sqlite@1.0.0", + Checksums: []Checksum{mustChecksum(t, "binary:"+runtime.GOOS+":"+runtime.GOARCH+":XmieKwFnK/M5luddXjcv9gxTbOFZmhRY6GfNj0vvUWA=")}, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + source := LocalSource{ + Path: tt.fields.Path, + } + got, err := source.Resolve(tt.args.ctx, tt.args.name, tt.args.version, tt.args.checksums) + if (err != nil) != tt.wantErr { + t.Errorf("LocalSource.Resolve() error = %v, wantErr %v", err, tt.wantErr) + return + } + if want := tt.want; want != nil { + tmpl, err := template.New("test").Parse(tt.want.BinaryPath) + require.NoError(t, err) + var buf bytes.Buffer + err = tmpl.Execute(&buf, map[string]interface{}{ + "tempDir": tt.fields.Path, + }) + require.NoError(t, err) + want.BinaryPath = buf.String() + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("LocalSource.Resolve() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/plugin/resolver/source_remote.go b/plugin/resolver/source_remote.go new file mode 100644 index 00000000..d0df24fe --- /dev/null +++ b/plugin/resolver/source_remote.go @@ -0,0 +1,363 @@ +package resolver + +import ( + "archive/tar" + "compress/gzip" + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "runtime" + "slices" + "strings" + "time" +) + +const ( + maxDownloadSize = 50 * 1024 * 1024 // 50MB + downloadTimeout = 5 * time.Minute + regAPITimeout = 10 * time.Second +) + +// RemoteSource is a plugin source that looks up plugins from a remote registry. +// The registry should implement the Fabric Registry API. +type RemoteSource struct { + // BaseURL is the base URL of the registry. + BaseURL string + // DownloadDir is the directory where the plugins are downloaded. + DownloadDir string + // UserAgent is the http user agent to use for the requests. + // Useful for debugging and statistics on the registry side. 
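+ // The tests use "test/0.1"; any descriptive token works.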
+ UserAgent string +} + +// regVersion represents a version of a plugin in the registry +type regVersion struct { + Version Version `json:"version"` + Platforms []regPlatform `json:"platforms"` +} + +// regPlatform represents an available platform for a plugin version in the registry +type regPlatform struct { + OS string `json:"os"` + Arch string `json:"arch"` +} + +// regDownloadInfo represents the download info for a specific platform of a plugin version in the registry +type regDownloadInfo struct { + OS string `json:"os"` + Arch string `json:"arch"` + DownloadURL string `json:"download_url"` +} + +// regError represents an error response. +type regError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// Error implements the error interface. +func (err regError) Error() string { + return fmt.Sprintf("[code=%s]: %s", err.Code, err.Message) +} + +// Lookup returns the versions found of the plugin with the given name. +func (source RemoteSource) Lookup(ctx context.Context, name Name) ([]Version, error) { + versions, err := source.fetchVersions(ctx, name) + if err != nil { + return nil, fmt.Errorf("failed to lookup plugin versions in the registry: %w", err) + } + var matches []Version + for _, version := range versions { + hasPlatform := slices.ContainsFunc(version.Platforms, func(platform regPlatform) bool { + return platform.OS == runtime.GOOS && platform.Arch == runtime.GOARCH + }) + if hasPlatform { + matches = append(matches, version.Version) + } + } + return matches, nil +} + +// call makes an HTTP request to the registry with the given timeout. +func (source RemoteSource) call(req *http.Request, timeout time.Duration) (*http.Response, error) { + if source.UserAgent != "" { + req.Header.Set("User-Agent", source.UserAgent) + } + client := &http.Client{ + Timeout: timeout, + } + return client.Do(req) +} + +// decodeBody decodes the HTTP response from the registry into the provided value. +func (source RemoteSource) decodeBody(resp *http.Response, v interface{}) error { + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + var errResp struct { + Error regError `json:"error"` + } + if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + return errResp.Error + } + return json.NewDecoder(resp.Body).Decode(v) +} + +// fetchVersions looks up the plugin versions in the registry. +func (source RemoteSource) fetchVersions(ctx context.Context, name Name) ([]regVersion, error) { + url := fmt.Sprintf("%s/v1/plugins/%s/%s/versions", source.BaseURL, name.Namespace(), name.Short()) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + resp, err := source.call(req, regAPITimeout) + if err != nil { + return nil, err + } + defer resp.Body.Close() + var respData struct { + Versions []regVersion `json:"versions"` + } + if err := source.decodeBody(resp, &respData); err != nil { + return nil, err + } + return respData.Versions, nil +} + +// Resolve returns the binary path and checksum for the given plugin version.
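+// Resolution makes up to two registry API calls (download info, plus checksums when the version is not yet locked) before downloading and extracting the archive.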
+func (source RemoteSource) Resolve(ctx context.Context, name Name, version Version, checksums []Checksum) (*ResolvedPlugin, error) { + downloadInfo, err := source.fetchDownloadInfo(ctx, name, version) + if err != nil { + return nil, fmt.Errorf("failed to get plugin from the registry: %w", err) + } + return source.download(ctx, name, version, downloadInfo, checksums) +} + +// fetchDownloadInfo resolves the download info for the given plugin version from the registry. +func (source RemoteSource) fetchDownloadInfo(ctx context.Context, name Name, version Version) (*regDownloadInfo, error) { + url := fmt.Sprintf("%s/v1/plugins/%s/%s/%s/download/%s/%s", + source.BaseURL, + name.Namespace(), + name.Short(), + version.String(), + runtime.GOOS, + runtime.GOARCH, + ) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + resp, err := source.call(req, regAPITimeout) + if err != nil { + return nil, err + } + defer resp.Body.Close() + var info regDownloadInfo + if err := source.decodeBody(resp, &info); err != nil { + return nil, err + } + return &info, nil +} + +// fetchChecksums fetches the plugin checksums from the registry. +func (source RemoteSource) fetchChecksums(ctx context.Context, name Name, version Version) ([]Checksum, error) { + url := fmt.Sprintf("%s/v1/plugins/%s/%s/%s/checksums", source.BaseURL, name.Namespace(), name.Short(), version.String()) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + resp, err := source.call(req, regAPITimeout) + if err != nil { + return nil, err + } + defer resp.Body.Close() + var respData struct { + Checksums []Checksum `json:"checksums"` + } + if err := source.decodeBody(resp, &respData); err != nil { + return nil, err + } + return respData.Checksums, nil +} + +// download downloads the plugin from the registry and returns the binary path and checksum. +func (source RemoteSource) download(ctx context.Context, name Name, version Version, info *regDownloadInfo, checksums []Checksum) (_ *ResolvedPlugin, err error) { + // If the checksums are not provided, it means the plugin version is not locked and we need to fetch the checksums from the registry.
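+ // Locked installs reuse the checksums recorded in the lock file, so a tampered registry download fails the verification below.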
+ if len(checksums) == 0 { + var err error + checksums, err = source.fetchChecksums(ctx, name, version) + if err != nil { + return nil, fmt.Errorf("failed to fetch plugin checksums: %w", err) + } + } + // make an HTTP request to download the plugin + req, err := http.NewRequestWithContext(ctx, "GET", info.DownloadURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create download request: %w", err) + } + req.Header.Set("Accept", "application/octet-stream") + resp, err := source.call(req, downloadTimeout) + if err != nil { + return nil, fmt.Errorf("failed to download plugin: %w", err) + } + defer resp.Body.Close() + // verify download response headers + if err = source.verifyDownloadHeaders(resp); err != nil { + return nil, err + } + // calculate the checksum of the downloaded archive while streaming the response body + h := sha256.New() + buf := io.TeeReader(resp.Body, h) + // extract the plugin while downloading, without saving the archive to disk + binaryPath, checksumPath, err := source.extract(name, version, buf, checksums) + if err != nil { + return nil, fmt.Errorf("failed to extract plugin: %w", err) + } + // cleanup extracted files if there is an error during checksum verification + defer func() { + if err == nil { + return + } + // if there is an error, remove extracted binary file + os.Remove(binaryPath) + os.Remove(checksumPath) + // remove directory if it is empty + entries, err := os.ReadDir(filepath.Dir(binaryPath)) + if err == nil && len(entries) == 0 { + os.Remove(filepath.Dir(binaryPath)) + } + }() + // read remaining data from the response body to verify the checksum of the downloaded archive + if _, err := io.Copy(io.Discard, buf); err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + // verify checksum of the downloaded archive + sum := Checksum{ + Object: "archive", + OS: runtime.GOOS, + Arch: runtime.GOARCH, + Sum: h.Sum(nil), + } + if !sum.Match(checksums) { + return nil, fmt.Errorf("invalid plugin archive checksum: '%s'", sum) + } + return &ResolvedPlugin{ + BinaryPath: binaryPath, + Checksums: checksums, + }, nil +} + +// verifyDownloadHeaders verifies the download response headers. +func (source RemoteSource) verifyDownloadHeaders(res *http.Response) error { + // verify the download size + if res.ContentLength > maxDownloadSize { + return fmt.Errorf("plugin download size exceeds the limit, got = %d, expect < %d", res.ContentLength, maxDownloadSize) + } + disposition, params, err := mime.ParseMediaType(res.Header.Get("Content-Disposition")) + if err != nil { + return fmt.Errorf("failed to parse content disposition: %w", err) + } + if disposition != "attachment" { + return fmt.Errorf("unsupported content disposition: %s", disposition) + } + fn := params["filename"] + if fn == "" { + return fmt.Errorf("missing filename in content disposition") + } + if !strings.HasSuffix(fn, ".tar.gz") { + return fmt.Errorf("unsupported archive type: %s", fn) + } + return nil +} + +// extract the plugin from the tar.gz file and return the binary and checksum file paths.
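+// The archive is expected to contain a single regular file named '<short-name>@<version>' (optionally with an '.exe' suffix); other entries are skipped.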
+func (source RemoteSource) extract(name Name, version Version, archive io.Reader, checksums []Checksum) (binPath, sumPath string, err error) { + read, err := gzip.NewReader(archive) + if err != nil { + return "", "", fmt.Errorf("failed to create gzip reader: %w", err) + } + defer read.Close() + reader := tar.NewReader(read) + var found *tar.Header + for { + header, err := reader.Next() + if err == io.EOF { + break + } + if err != nil { + // a non-EOF read error would otherwise leave header nil and panic below + return "", "", fmt.Errorf("failed to read plugin archive: %w", err) + } + if header.Typeflag != tar.TypeReg { + continue + } + if header.Name != fmt.Sprintf("%s@%s", name.Short(), version.String()) && + header.Name != fmt.Sprintf("%s@%s.exe", name.Short(), version.String()) { + continue + } + found = header + break + } + if found == nil { + return "", "", fmt.Errorf("plugin binary not found in tar.gz file") + } + binaryPath := filepath.Join(source.DownloadDir, name.Namespace(), filepath.Base(found.Name)) + checksumPath := strings.TrimSuffix(binaryPath, ".exe") + "_checksums.txt" + if err := os.MkdirAll(filepath.Dir(binaryPath), 0o755); err != nil { + return "", "", fmt.Errorf("failed to create plugin directory: %w", err) + } + binaryFile, err := os.OpenFile(binaryPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755) + if err != nil { + return "", "", fmt.Errorf("failed to create plugin file: %w", err) + } + // cleanup the downloaded binary on error + defer func() { + binaryFile.Close() + if err != nil { + // if there is an error, remove extracted binary file and checksum file + os.Remove(binaryPath) + // remove directory if it is empty + entries, err := os.ReadDir(filepath.Dir(binaryPath)) + if err == nil && len(entries) == 0 { + os.Remove(filepath.Dir(binaryPath)) + } + } + }() + // calculate checksum of the plugin binary while writing to the file + h := sha256.New() + buf := io.MultiWriter(h, binaryFile) + // write the plugin binary + if _, err := io.Copy(buf, reader); err != nil { //nolint:gosec // lint issue not possible here + return "", "", fmt.Errorf("failed to write plugin file: %w", err) + } + sum := Checksum{ + Object: "binary", + OS: runtime.GOOS, + Arch: runtime.GOARCH, + Sum: h.Sum(nil), + } + if !sum.Match(checksums) { + return "", "", fmt.Errorf("invalid plugin binary checksum: '%s'", sum) + } + // Create the checksums file to be used by subsequent installs when the plugin is resolved from the local source.
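+ // The checksums file uses the same 'object:os:arch:base64' line format that decodeChecksums parses.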
+ checksumFile, err := os.Create(checksumPath) + if err != nil { + return "", "", fmt.Errorf("failed to create plugin meta file: %w", err) + } + // cleanup checksum file operation + defer func() { + checksumFile.Close() + if err != nil { // if there is an error, remove checksum file + os.Remove(checksumPath) + } + }() + if err := encodeChecksums(checksumFile, checksums); err != nil { + return "", "", fmt.Errorf("failed to write plugin meta file: %w", err) + } + return binaryPath, checksumPath, nil +} diff --git a/plugin/resolver/source_remote_test.go b/plugin/resolver/source_remote_test.go new file mode 100644 index 00000000..1218ca36 --- /dev/null +++ b/plugin/resolver/source_remote_test.go @@ -0,0 +1,202 @@ +package resolver + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "crypto/sha256" + "io" + "net/http" + "net/http/httptest" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRemoteSourceLookup(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + assert.Equal(t, "/v1/plugins/blackstork/sqlite/versions", r.URL.Path) + assert.Equal(t, "application/json", r.Header.Get("Accept")) + assert.Equal(t, "test/0.1", r.Header.Get("User-Agent")) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "versions": [ + { + "version": "1.0.0", + "platforms": [ + { + "os": "` + runtime.GOOS + `", + "arch": "` + runtime.GOARCH + `" + } + ] + }, + { + "version": "1.0.1", + "platforms": [ + { + "os": "` + runtime.GOOS + `", + "arch": "` + runtime.GOARCH + `" + } + ] + } + ] + }`)) + })) + defer srv.Close() + source := RemoteSource{ + BaseURL: srv.URL, + UserAgent: "test/0.1", + } + versions, err := source.Lookup(context.Background(), Name{"blackstork", "sqlite"}) + assert.NoError(t, err) + assert.Equal(t, []Version{ + mustVersion(t, "1.0.0"), + mustVersion(t, "1.0.1"), + }, versions) +} + +func TestRemoteSourceLookupError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{ + "error": { + "code": "not_found", + "message": "plugin not found" + } + }`)) + })) + defer srv.Close() + source := RemoteSource{ + BaseURL: srv.URL, + } + versions, err := source.Lookup(context.Background(), Name{"blackstork", "sqlite"}) + assert.EqualError(t, err, "failed to lookup plugin versions in the registry: [code=not_found]: plugin not found") + assert.Nil(t, versions) +} + +func mockTarGz(t *testing.T, files map[string]string) ([]byte, []Checksum) { + t.Helper() + checksums := []Checksum{} + buf := bytes.NewBuffer(nil) + gz := gzip.NewWriter(buf) + w := tar.NewWriter(gz) + for name, content := range files { + hdr := &tar.Header{ + Name: name, + Size: int64(len(content)), + } + if err := w.WriteHeader(hdr); err != nil { + t.Fatal(err) + } + if _, err := w.Write([]byte(content)); err != nil { + t.Fatal(err) + } + h := sha256.New() + h.Write([]byte(content)) + checksums = append(checksums, Checksum{ + OS: runtime.GOOS, + Arch: runtime.GOARCH, + Object: "binary", + Sum: h.Sum(nil), + }) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + gz.Flush() + gz.Close() + h := sha256.New() + if _, err := io.Copy(h, bytes.NewReader(buf.Bytes())); err != nil { + t.Fatal(err) + } + checksums = append(checksums, Checksum{ + OS: runtime.GOOS, + Arch: runtime.GOARCH, + Object: "archive", + Sum: h.Sum(nil), + }) 
+ return buf.Bytes(), checksums +} + +func TestRemoteSourceResolve(t *testing.T) { + archive, checksums := mockTarGz(t, map[string]string{ + "sqlite@1.0.0": "plugin-binary", + }) + var srv *httptest.Server + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/plugins/blackstork/sqlite/1.0.0/checksums": + assert.Equal(t, "GET", r.Method) + assert.Equal(t, "application/json", r.Header.Get("Accept")) + assert.Equal(t, "test/0.1", r.Header.Get("User-Agent")) + w.Header().Add("Content-Type", "application/json") + checksumsJSON := []string{} + for _, c := range checksums { + raw, err := c.MarshalJSON() + assert.NoError(t, err) + checksumsJSON = append(checksumsJSON, string(raw)) + } + w.Write([]byte(`{ + "checksums": [ + ` + strings.Join(checksumsJSON, ",") + ` + ] + }`)) + case "/v1/plugins/blackstork/sqlite/1.0.0/download/" + runtime.GOOS + "/" + runtime.GOARCH: + assert.Equal(t, "GET", r.Method) + assert.Equal(t, "application/json", r.Header.Get("Accept")) + assert.Equal(t, "test/0.1", r.Header.Get("User-Agent")) + w.Header().Add("Content-Type", "application/json") + w.Write([]byte(`{ + "os": "` + runtime.GOOS + `", + "arch": "` + runtime.GOARCH + `", + "download_url": "` + srv.URL + `/download" + }`)) + case "/download": + assert.Equal(t, "GET", r.Method) + w.Header().Add("Content-Type", "application/octet-stream") + w.Header().Add("Content-Disposition", "attachment; filename=plugin.tar.gz") + w.Write(archive) + default: + t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) + } + })) + defer srv.Close() + source := RemoteSource{ + BaseURL: srv.URL, + UserAgent: "test/0.1", + DownloadDir: t.TempDir(), + } + // without checksums input + plugin, err := source.Resolve(context.Background(), Name{"blackstork", "sqlite"}, mustVersion(t, "1.0.0"), nil) + require.NoError(t, err) + assert.Equal(t, checksums, plugin.Checksums) + assert.Equal(t, filepath.Join(source.DownloadDir, "blackstork/sqlite@1.0.0"), plugin.BinaryPath) + // pass with valid checksums input + plugin, err = source.Resolve(context.Background(), Name{"blackstork", "sqlite"}, mustVersion(t, "1.0.0"), checksums) + require.NoError(t, err) + assert.Equal(t, checksums, plugin.Checksums) + assert.Equal(t, filepath.Join(source.DownloadDir, "blackstork/sqlite@1.0.0"), plugin.BinaryPath) + // fail with invalid checksums input + plugin, err = source.Resolve(context.Background(), Name{"blackstork", "sqlite"}, mustVersion(t, "1.0.0"), []Checksum{ + { + OS: runtime.GOOS, + Arch: runtime.GOARCH, + Object: "archive", + Sum: []byte("other"), + }, + { + OS: runtime.GOOS, + Arch: runtime.GOARCH, + Object: "binary", + Sum: []byte("other"), + }, + }) + require.Nil(t, plugin) + require.Error(t, err, "failed to resolve plugin: checksum mismatch") +} diff --git a/plugin/resolver/version.go b/plugin/resolver/version.go new file mode 100644 index 00000000..2754586e --- /dev/null +++ b/plugin/resolver/version.go @@ -0,0 +1,58 @@ +package resolver + +import ( + "fmt" + "strconv" + + "github.com/Masterminds/semver/v3" +) + +// Version is a version of a plugin. It is a wrapper around semver.Version with strict parsing. +type Version struct { + *semver.Version +} + +// UnmarshalJSON parses a JSON string into a Version using strict semver parsing.
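+// Strict parsing accepts "1.0.0" but rejects shorthand or prefixed forms such as "v1.0.0", "1.0", and "1".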
+func (v *Version) UnmarshalJSON(data []byte) error {
+	raw, err := strconv.Unquote(string(data))
+	if err != nil {
+		return fmt.Errorf("failed to unquote version: %w", err)
+	}
+	ver, err := semver.StrictNewVersion(raw)
+	if err != nil {
+		return err
+	}
+	*v = Version{ver}
+	return nil
+}
+
+// Compare compares the version with another version.
+func (v Version) Compare(other Version) int {
+	return v.Version.Compare(other.Version)
+}
+
+// ConstraintMap is a map of plugin names to version constraints.
+type ConstraintMap map[Name]*semver.Constraints
+
+// ParseConstraintMap parses a string map into a ConstraintMap.
+func ParseConstraintMap(src map[string]string) (ConstraintMap, error) {
+	if src == nil {
+		return nil, nil
+	}
+	parsed := make(ConstraintMap)
+	for name, version := range src {
+		if version == "" {
+			return nil, fmt.Errorf("missing plugin version constraint for '%s'", name)
+		}
+		parsedName, err := ParseName(name)
+		if err != nil {
+			return nil, err
+		}
+		constraints, err := semver.NewConstraint(version)
+		if err != nil {
+			return nil, err
+		}
+		parsed[parsedName] = constraints
+	}
+	return parsed, nil
+}
diff --git a/plugin/resolver/version_test.go b/plugin/resolver/version_test.go
new file mode 100644
index 00000000..22f55ec3
--- /dev/null
+++ b/plugin/resolver/version_test.go
@@ -0,0 +1,148 @@
+package resolver
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/Masterminds/semver/v3"
+	"github.com/stretchr/testify/require"
+)
+
+func mustVersion(t *testing.T, str string) Version {
+	t.Helper()
+	v, err := semver.NewVersion(str)
+	require.NoError(t, err)
+	return Version{v}
+}
+
+func TestParseConstraintMap(t *testing.T) {
+	type args struct {
+		src map[string]string
+	}
+	tests := []struct {
+		name    string
+		args    args
+		want    ConstraintMap
+		wantErr bool
+	}{
+		{
+			name: "nil",
+			args: args{
+				src: nil,
+			},
+			want:    nil,
+			wantErr: false,
+		},
+		{
+			name: "empty",
+			args: args{
+				src: map[string]string{},
+			},
+			want:    ConstraintMap{},
+			wantErr: false,
+		},
+		{
+			name: "single",
+			args: args{
+				src: map[string]string{
+					"ns/name": "1.0.0",
+				},
+			},
+			want: ConstraintMap{
+				Name{"ns", "name"}: mustConstraint(t, "1.0.0"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "with_v_prefix",
+			args: args{
+				src: map[string]string{
+					"ns/name": "v1.0.0",
+				},
+			},
+			want: ConstraintMap{
+				Name{"ns", "name"}: mustConstraint(t, "v1.0.0"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "multiple",
+			args: args{
+				src: map[string]string{
+					"ns/name1": "1.0.0",
+					"ns/name2": "2.0.0",
+					"ns/name3": "3.0.0",
+				},
+			},
+			want: ConstraintMap{
+				Name{"ns", "name1"}: mustConstraint(t, "1.0.0"),
+				Name{"ns", "name2"}: mustConstraint(t, "2.0.0"),
+				Name{"ns", "name3"}: mustConstraint(t, "3.0.0"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "invalid_name",
+			args: args{
+				src: map[string]string{
+					"ns": "1.0.0",
+				},
+			},
+			want:    nil,
+			wantErr: true,
+		},
+		{
+			name: "invalid_version",
+			args: args{
+				src: map[string]string{
+					"ns/name": "",
+				},
+			},
+			want:    nil,
+			wantErr: true,
+		},
+		{
+			name: "invalid_version_constraint",
+			args: args{
+				src: map[string]string{
+					"ns/name": "1.0.0+",
+				},
+			},
+			want:    nil,
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := ParseConstraintMap(tt.args.src)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("ParseConstraintMap() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("ParseConstraintMap() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func mustConstraint(t *testing.T, str string) *semver.Constraints {
+	t.Helper()
+	c, err := semver.NewConstraint(str)
+	require.NoError(t, err)
+	return c
+}
+
+func TestVersion_UnmarshalJSON(t *testing.T) {
+	v := new(Version)
+	err := v.UnmarshalJSON([]byte(`"1.0.0"`))
+	require.NoError(t, err)
+	require.Equal(t, mustVersion(t, "1.0.0"), *v)
+	v = new(Version)
+	err = v.UnmarshalJSON([]byte(`"v1.0.0"`))
+	require.Error(t, err)
+	err = v.UnmarshalJSON([]byte(`"1.0"`))
+	require.Error(t, err)
+	err = v.UnmarshalJSON([]byte(`"1"`))
+	require.Error(t, err)
+}
diff --git a/plugin/runner/loader.go b/plugin/runner/loader.go
index 0a5429a4..9461205b 100644
--- a/plugin/runner/loader.go
+++ b/plugin/runner/loader.go
@@ -2,6 +2,7 @@ package runner
 
 import (
 	"fmt"
+	"log/slog"
 	"os"
 
 	"github.com/hashicorp/hcl/v2"
@@ -26,18 +27,18 @@ type loadedContentProvider struct {
 }
 
 type loader struct {
-	resolver   *resolver
-	versionMap VersionMap
-	builtin    []*plugin.Schema
+	logger     *slog.Logger
+	binaryMap  map[string]string
+	builtin    *plugin.Schema
 	pluginMap  map[string]loadedPlugin
 	dataMap    map[string]loadedDataSource
 	contentMap map[string]loadedContentProvider
 }
 
-func makeLoader(mirrorDir string, builtin []*plugin.Schema, pluginMap VersionMap) *loader {
+func makeLoader(binaryMap map[string]string, builtin *plugin.Schema, logger *slog.Logger) *loader {
 	return &loader{
-		resolver:   makeResolver(mirrorDir),
-		versionMap: pluginMap,
+		logger:    logger,
+		binaryMap: binaryMap,
 		builtin:    builtin,
 		pluginMap:  make(map[string]loadedPlugin),
 		dataMap:    make(map[string]loadedDataSource),
@@ -50,14 +51,12 @@ func nopCloser() error {
 }
 
 func (l *loader) loadAll() hcl.Diagnostics {
-	for _, p := range l.builtin {
-		if diags := l.registerPlugin(p, nopCloser); diags.HasErrors() {
-			diags = append(diags, l.closeAll()...)
-			return diags
-		}
+	if diags := l.registerPlugin(l.builtin, nopCloser); diags.HasErrors() {
+		diags = append(diags, l.closeAll()...)
+		return diags
 	}
-	for name, version := range l.versionMap {
-		if diags := l.loadBinary(name, version); diags.HasErrors() {
+	for name, binaryPath := range l.binaryMap {
+		if diags := l.loadBinary(name, binaryPath); diags.HasErrors() {
 			diags = append(diags, l.closeAll()...)
 			return diags
 		}
@@ -79,92 +78,95 @@ func (l *loader) closeAll() hcl.Diagnostics {
 	return diags
 }
 
-func (l *loader) registerDataSource(name string, p *plugin.Schema, ds *plugin.DataSource) hcl.Diagnostics {
+func (l *loader) registerDataSource(name string, schema *plugin.Schema, ds *plugin.DataSource) hcl.Diagnostics {
 	if found, has := l.dataMap[name]; has {
 		return hcl.Diagnostics{{
 			Severity: hcl.DiagError,
 			Summary:  "Duplicate data source",
-			Detail:   fmt.Sprintf("Data source %s provided by plugin %s@%s and %s@%s", name, p.Name, p.Version, found.plugin.Name, found.plugin.Version),
+			Detail:   fmt.Sprintf("Data source %s provided by plugin %s@%s and %s@%s", name, schema.Name, schema.Version, found.plugin.Name, found.plugin.Version),
 		}}
 	}
-	l.dataMap[name] = loadedDataSource{p, ds}
+	l.dataMap[name] = loadedDataSource{
+		plugin:     schema,
+		DataSource: ds,
+	}
 	return nil
 }
 
-func (l *loader) registerContentProvider(name string, p *plugin.Schema, cp *plugin.ContentProvider) hcl.Diagnostics {
+func (l *loader) registerContentProvider(name string, schema *plugin.Schema, cp *plugin.ContentProvider) hcl.Diagnostics {
 	if found, has := l.contentMap[name]; has {
 		return hcl.Diagnostics{{
 			Severity: hcl.DiagError,
 			Summary:  "Duplicate content provider",
-			Detail:   fmt.Sprintf("Content provider %s provided by plugin %s@%s and %s@%s", name, p.Name, p.Version, found.plugin.Name, found.plugin.Version),
+			Detail:   fmt.Sprintf("Content provider %s provided by plugin %s@%s and %s@%s", name, schema.Name, schema.Version, found.plugin.Name, found.plugin.Version),
 		}}
 	}
-	l.contentMap[name] = loadedContentProvider{p, cp}
+	l.contentMap[name] = loadedContentProvider{
+		plugin:          schema,
+		ContentProvider: cp,
+	}
 	return nil
 }
 
-func (l *loader) registerPlugin(p *plugin.Schema, closefn func() error) hcl.Diagnostics {
-	if diags := p.Validate(); diags.HasErrors() {
+func (l *loader) registerPlugin(schema *plugin.Schema, closefn func() error) hcl.Diagnostics {
+	if diags := schema.Validate(); diags.HasErrors() {
 		return diags
 	}
-	if found, has := l.pluginMap[p.Name]; has {
+	if found, has := l.pluginMap[schema.Name]; has {
 		diags := hcl.Diagnostics{{
 			Severity: hcl.DiagError,
-			Summary:  "Plugin conflict",
-			Detail:   fmt.Sprintf("Plugin %s@%s and %s@%s have the same name", p.Name, p.Version, found.Name, found.Version),
+			Summary:  fmt.Sprintf("Plugin %s conflict", schema.Name),
+			Detail:   fmt.Sprintf("%s@%s and %s@%s have the same schema name", schema.Name, schema.Version, found.Name, found.Version),
 		}}
 		err := found.closefn()
 		if err != nil {
 			diags = append(diags, &hcl.Diagnostic{
 				Severity: hcl.DiagError,
-				Summary:  "Failed to close plugin",
-				Detail:   fmt.Sprintf("Failed to close plugin %s@%s: %v", found.Name, found.Version, err),
+				Summary:  fmt.Sprintf("Failed to close plugin %s@%s", found.Name, found.Version),
+				Detail:   err.Error(),
 			})
 		}
 		return diags
 	}
 	plugin := loadedPlugin{
 		closefn: closefn,
-		Schema:  p,
+		Schema:  schema,
 	}
-	l.pluginMap[p.Name] = plugin
-	for name, source := range p.DataSources {
-		if diags := l.registerDataSource(name, p, source); diags.HasErrors() {
+	l.pluginMap[schema.Name] = plugin
+	for name, source := range schema.DataSources {
+		if diags := l.registerDataSource(name, schema, source); diags.HasErrors() {
 			return diags
 		}
 	}
-	for name, provider := range p.ContentProviders {
-		if diags := l.registerContentProvider(name, p, provider); diags.HasErrors() {
+	for name, provider := range schema.ContentProviders {
+		if diags := l.registerContentProvider(name, schema, provider); diags.HasErrors() {
 			return diags
 		}
 	}
 	return nil
 }
 
-func (l *loader) loadBinary(name, version string) hcl.Diagnostics {
-	loc, diags := l.resolver.resolve(name, version)
-	if diags.HasErrors() {
-		return diags
-	}
-	if info, err := os.Stat(loc); os.IsNotExist(err) {
+func (l *loader) loadBinary(name, binaryPath string) hcl.Diagnostics {
+	if info, err := os.Stat(binaryPath); os.IsNotExist(err) {
 		return hcl.Diagnostics{{
 			Severity: hcl.DiagError,
-			Summary:  "Plugin not found",
-			Detail:   fmt.Sprintf("Plugin %s@%s not found at: %s", name, version, loc),
+			Summary:  fmt.Sprintf("Plugin %s binary not found", name),
+			Detail:   fmt.Sprintf("Executable not found at: %s", binaryPath),
 		}}
 	} else if info.IsDir() {
 		return hcl.Diagnostics{{
 			Severity: hcl.DiagError,
-			Summary:  "Plugin is a directory",
-			Detail:   fmt.Sprintf("Plugin %s@%s is a directory at: %s", name, version, loc),
+			Summary:  fmt.Sprintf("Plugin %s binary path is a directory", name),
+			Detail:   fmt.Sprintf("Path %s is a directory", binaryPath),
 		}}
 	}
-	p, close, err := pluginapiv1.NewClient(loc)
+	l.logger.Info("Loading plugin", "name", name, "path", binaryPath)
+	p, close, err := pluginapiv1.NewClient(name, binaryPath, l.logger)
 	if err != nil {
 		return hcl.Diagnostics{{
 			Severity: hcl.DiagError,
-			Summary:  "Failed to create plugin client",
-			Detail:   fmt.Sprintf("Failed to create plugin client for %s@%s: %v", name, version, err),
+			Summary:  fmt.Sprintf("Failed to load plugin %s", name),
+			Detail:   err.Error(),
 		}}
 	}
 	return l.registerPlugin(p, close)
diff --git a/plugin/runner/options.go b/plugin/runner/options.go
deleted file mode 100644
index 97a85c0b..00000000
--- a/plugin/runner/options.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package runner
-
-import "github.com/blackstork-io/fabric/plugin"
-
-type options struct {
-	pluginDir  string
-	versionMap VersionMap
-	builtin    []*plugin.Schema
-}
-
-var defaultOptions = options{
-	pluginDir: "./plugins",
-}
-
-type Option func(*options)
-
-func WithPluginDir(dir string) Option {
-	return func(o *options) {
-		o.pluginDir = dir
-	}
-}
-
-func WithPluginVersions(m VersionMap) Option {
-	return func(o *options) {
-		o.versionMap = m
-	}
-}
-
-func WithBuiltIn(builtin ...*plugin.Schema) Option {
-	return func(o *options) {
-		o.builtin = builtin
-	}
-}
diff --git a/plugin/runner/resolver.go b/plugin/runner/resolver.go
deleted file mode 100644
index ccf9294d..00000000
--- a/plugin/runner/resolver.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package runner
-
-import (
-	"fmt"
-	"os"
-	"path/filepath"
-	"strings"
-
-	"github.com/Masterminds/semver/v3"
-	"github.com/hashicorp/hcl/v2"
-)
-
-type resolver struct {
-	mirrorDir string
-}
-
-func makeResolver(mirrorDir string) *resolver {
-	return &resolver{
-		mirrorDir: mirrorDir,
-	}
-}
-
-func (r *resolver) resolve(name, version string) (loc string, diags hcl.Diagnostics) {
-	nameSpace, pluginName, err := r.parseName(name)
-	if err != nil {
-		return "", hcl.Diagnostics{{
-			Severity: hcl.DiagError,
-			Summary:  "Invalid plugin name",
-			Detail:   fmt.Sprintf("Invalid plugin name '%s': %s", name, err),
-		}}
-	}
-	constraint, err := semver.NewConstraint(version)
-	if err != nil {
-		return "", hcl.Diagnostics{{
-			Severity: hcl.DiagError,
-			Summary:  "Failed to resolve plugin",
-			Detail:   fmt.Sprintf("Invalid version constraint: %s", err),
-		}}
-	}
-	entry, err := os.ReadDir(filepath.Join(r.mirrorDir, nameSpace))
-	if err != nil {
-		return "", hcl.Diagnostics{{
-			Severity: hcl.DiagError,
-			Summary:  "Failed to resolve plugin",
-			Detail:   fmt.Sprintf("Failed to read directory for namespace '%s': %s", nameSpace, err),
-		}}
-	}
-	matched := make(map[string]*semver.Version)
-	for _, e := range entry {
-		if e.IsDir() {
-			continue
-		}
-		parts := strings.SplitN(e.Name(), "@", 2)
-		if len(parts) != 2 || parts[0] != pluginName {
-			continue
-		}
-		v, err := semver.NewVersion(parts[1])
-		if err != nil {
-			continue
-		}
-		if !constraint.Check(v) {
-			continue
-		}
-		matched[parts[1]] = v
-	}
-	if len(matched) == 0 {
-		return "", hcl.Diagnostics{{
-			Severity: hcl.DiagError,
-			Summary:  "Failed to resolve plugin binary",
-			Detail:   fmt.Sprintf("No plugin matches version constraint for %s@%s", name, version),
-		}}
-	}
-	// find latest version that matches version constraint
-	var latestVerStr string
-	var latestVer *semver.Version
-	for str, ver := range matched {
-		if latestVer == nil {
-			latestVerStr = str
-			latestVer = ver
-			continue
-		}
-		if ver.Compare(latestVer) > 0 {
-			latestVerStr = str
-			latestVer = ver
-		}
-	}
-	return filepath.Join(r.mirrorDir, nameSpace, fmt.Sprintf("%s@%s", pluginName, latestVerStr)), nil
-}
-
-func (r *resolver) parseName(name string) (string, string, error) {
-	parts := strings.SplitN(name, "/", 2)
-	if len(parts) != 2 {
-		return "", "", fmt.Errorf("plugin name '%s' is not in the form '<namespace>/<name>'", name)
-	}
-	return parts[0], parts[1], nil
-}
diff --git a/plugin/runner/runner.go b/plugin/runner/runner.go
index 18b35b79..530fa089 100644
--- a/plugin/runner/runner.go
+++ b/plugin/runner/runner.go
@@ -2,91 +2,25 @@ package runner
 
 import (
 	"fmt"
-	"strings"
-	"unicode"
+	"log/slog"
 
-	"github.com/Masterminds/semver/v3"
 	"github.com/hashicorp/hcl/v2"
 
 	"github.com/blackstork-io/fabric/plugin"
 )
 
-type VersionMap map[string]string
-
 type Runner struct {
-	cacheDir   string
 	pluginMap  map[string]loadedPlugin
 	dataMap    map[string]loadedDataSource
 	contentMap map[string]loadedContentProvider
}
 
-func validatePluginName(name string) hcl.Diagnostics {
-	parts := strings.Split(name, "/")
-	if len(parts) != 2 {
-		return hcl.Diagnostics{{
-			Severity: hcl.DiagError,
-			Summary:  "Invalid plugin name",
-			Detail:   fmt.Sprintf("plugin name '%s' is not in the form '<namespace>/<name>'", name),
-		}}
-	}
-	for _, r := range parts[0] {
-		if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '-' {
-			return hcl.Diagnostics{{
-				Severity: hcl.DiagError,
-				Summary:  "Invalid plugin name",
-				Detail:   fmt.Sprintf("plugin name '%s' contains invalid character: '%c'", name, r),
-			}}
-		}
-	}
-	for _, r := range parts[1] {
-		if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '-' {
-			return hcl.Diagnostics{{
-				Severity: hcl.DiagError,
-				Summary:  "Invalid plugin name",
-				Detail:   fmt.Sprintf("plugin name '%s' contains invalid character: '%c'", name, r),
-			}}
-		}
-	}
-	return nil
-}
-
-func validatePluginVersionMap(versionMap VersionMap) (diags hcl.Diagnostics) {
-	for name, version := range versionMap {
-		diags = validatePluginName(name).Extend(diags)
-		if version == "" {
-			diags = diags.Append(&hcl.Diagnostic{
-				Severity: hcl.DiagError,
-				Summary:  "Missing plugin version",
-				Detail:   fmt.Sprintf("Missing plugin version for '%s'", name),
-			})
-			continue
-		}
-		_, err := semver.NewConstraint(version)
-		if err != nil {
-			diags = diags.Append(&hcl.Diagnostic{
-				Severity: hcl.DiagError,
-				Summary:  "Invalid plugin version",
-				Detail:   fmt.Sprintf("Invalid version constraint for '%s': %s", name, err),
-			})
-		}
-	}
-	return diags
-}
-
-func Load(o ...Option) (*Runner, hcl.Diagnostics) {
-	opts := defaultOptions
-	for _, opt := range o {
-		opt(&opts)
-	}
-	if diags := validatePluginVersionMap(opts.versionMap); diags.HasErrors() {
-		return nil, diags
-	}
-	loader := makeLoader(opts.pluginDir, opts.builtin, opts.versionMap)
+func Load(binaryMap map[string]string, builtin *plugin.Schema, logger *slog.Logger) (*Runner, hcl.Diagnostics) {
+	loader := makeLoader(binaryMap, builtin, logger)
 	if diags := loader.loadAll(); diags.HasErrors() {
 		return nil, diags
 	}
 	return &Runner{
-		cacheDir:   opts.pluginDir,
 		pluginMap:  loader.pluginMap,
 		dataMap:    loader.dataMap,
 		contentMap: loader.contentMap,
@@ -98,8 +32,8 @@ func (m *Runner) DataSource(name string) (*plugin.DataSource, hcl.Diagnostics) {
 	if !has {
 		return nil, hcl.Diagnostics{{
 			Severity: hcl.DiagError,
-			Summary:  "Data source not found",
-			Detail:   fmt.Sprintf("data source '%s' not found in any plugin", name),
+			Summary:  fmt.Sprintf("Missing data source '%s'", name),
+			Detail:   fmt.Sprintf("'%s' not found in any plugin", name),
 		}}
 	}
 	return source.DataSource, nil
@@ -110,8 +44,8 @@ func (m *Runner) ContentProvider(name string) (*plugin.ContentProvider, hcl.Diag
 	if !has {
 		return nil, hcl.Diagnostics{{
 			Severity: hcl.DiagError,
-			Summary:  "Content provider not found",
-			Detail:   fmt.Sprintf("content provider '%s' not found in any plugin", name),
+			Summary:  fmt.Sprintf("Missing content provider '%s'", name),
+			Detail:   fmt.Sprintf("'%s' not found in any plugin", name),
 		}}
 	}
 	return provider.ContentProvider, nil
@@ -123,8 +57,8 @@ func (m *Runner) Close() hcl.Diagnostics {
 		if err := p.closefn(); err != nil {
 			diags = append(diags, &hcl.Diagnostic{
 				Severity: hcl.DiagWarning,
-				Summary:  "Failed to close plugin",
-				Detail:   fmt.Sprintf("failed to close plugin %s@%s: %v", p.Name, p.Version, err),
+				Summary:  fmt.Sprintf("Failed to close plugin '%s'", p.Name),
+				Detail:   err.Error(),
 			})
 		}
 	}
diff --git a/test/e2e/data_test.go b/test/e2e/data_test.go
index 75338309..2a02ed82 100644
--- a/test/e2e/data_test.go
+++ b/test/e2e/data_test.go
@@ -29,17 +29,18 @@ func dataTest(t *testing.T, testName string, files []string, target string, expe
 			Mode: 0o777,
 		}
 	}
-	eval := cmd.NewEvaluator("")
+	eval := cmd.NewEvaluator()
 	defer func() {
 		eval.Cleanup(nil)
 	}()
 	var res plugin.Data
 	diags := eval.ParseFabricFiles(sourceDir)
+	ctx := context.Background()
 	if !diags.HasErrors() {
-		if !diags.Extend(eval.LoadRunner()) {
+		if !diags.Extend(eval.LoadPluginResolver(false)) && !diags.Extend(eval.LoadPluginRunner(ctx)) {
 			var diag diagnostics.Diag
-			res, diag = cmd.Data(context.Background(), eval.Blocks, eval.PluginCaller(), target)
+			res, diag = cmd.Data(ctx, eval.Blocks, eval.PluginCaller(), target)
 			diags.Extend(diag)
 		}
 	}
diff --git a/test/e2e/render_test.go b/test/e2e/render_test.go
index c90736cf..a0016383 100644
--- a/test/e2e/render_test.go
+++ b/test/e2e/render_test.go
@@ -28,17 +28,18 @@ func renderTest(t *testing.T, testName string, files []string, docName string, e
 			Mode: 0o777,
 		}
 	}
-	eval := cmd.NewEvaluator("")
+	eval := cmd.NewEvaluator()
 	defer func() {
 		eval.Cleanup(nil)
 	}()
 	var res []string
 	diags := eval.ParseFabricFiles(sourceDir)
+	ctx := context.Background()
 	if !diags.HasErrors() {
-		if !diags.Extend(eval.LoadRunner()) {
+		if !diags.Extend(eval.LoadPluginResolver(false)) && !diags.Extend(eval.LoadPluginRunner(ctx)) {
 			var diag diagnostics.Diag
-			res, diag = cmd.Render(context.Background(), eval.Blocks, eval.PluginCaller(), docName)
+			res, diag = cmd.Render(ctx, eval.Blocks, eval.PluginCaller(), docName)
 			diags.Extend(diag)
 		}
 	}
diff --git a/tools/pluginmeta/main.go b/tools/pluginmeta/main.go
new file mode 100644
index 00000000..86adf0c3
--- /dev/null
+++ b/tools/pluginmeta/main.go
@@ -0,0 +1,217 @@
+package main
+
+import (
+	"crypto/sha256"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+	"slices"
+	"strings"
+	"text/template"
+
+	"github.com/spf13/pflag"
+	"gopkg.in/yaml.v3"
+)
+
+var (
+	version    string
+	namespace  string
+	output     string
+	configFile string
+	osName     string
+	archName   string
+	plugin     string
+)
+
+// pluginmeta builds the plugin metadata index for a release; its "patch"
+// subcommand records a plugin binary's SHA-256 checksum in that index.
+func main() {
+	flags := pflag.NewFlagSet("pluginmeta", pflag.ExitOnError)
+	flags.StringVar(&namespace, "namespace", "blackstork", "namespace for plugins")
+	flags.StringVar(&configFile, "config", ".goreleaser.yaml", "path to goreleaser config")
+	flags.StringVar(&output, "output", ".tmp/plugins.json", "path to output plugins.json")
+	flags.StringVar(&version, "version", "0.0.0", "version for plugins")
+	flags.StringVar(&osName, "os", "", "os for patch")
+	flags.StringVar(&archName, "arch", "", "arch for patch")
+	flags.StringVar(&plugin, "plugin", "", "plugin for patch")
+	flags.Parse(os.Args[1:])
+	args := flags.Args()
+	if len(args) == 1 && args[0] == "patch" {
+		// Patch metadata
+		meta, err := readMeta()
+		if err != nil {
+			panic(err)
+		}
+		err = patchMeta(meta, plugin, osName, archName)
+		if err != nil {
+			panic(err)
+		}
+		return
+	}
+	// Read and parse config
+	cfg, err := readConfig()
+	if err != nil {
+		panic(err)
+	}
+	meta, err := parseConfig(cfg)
+	if err != nil {
+		panic(err)
+	}
+	// Write metadata
+	err = os.MkdirAll(filepath.Dir(output), 0o755)
+	if err != nil {
+		panic(err)
+	}
+	err = writeMetadata(meta)
+	if err != nil {
+		panic(err)
+	}
+}
+
+func patchMeta(meta *Metadata, plugin, osName, archName string) error {
+	split := strings.Split(filepath.Base(plugin), "@")
+	if len(split) != 2 {
+		return fmt.Errorf("invalid plugin name")
+	}
+	name := split[0]
+	var archive *PluginArchiveMetadata
+	for _, p := range meta.Plugins {
+		if p.Name != fmt.Sprintf("%s/%s", namespace, name) {
+			continue
+		}
+		for _, a := range p.Archives {
+			if a.OS != osName || a.Arch != archName {
+				continue
+			}
+			archive = a
+			break
+		}
+		break
+	}
+	if archive == nil {
+		return fmt.Errorf("archive not found")
+	}
+	f, err := os.Open(plugin)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+	h := sha256.New()
+	_, err = io.Copy(h, f)
+	if err != nil {
+		return err
+	}
+	archive.BinaryChecksum = base64.StdEncoding.EncodeToString(h.Sum(nil))
+	return writeMetadata(meta)
+}
+
+func readMeta() (*Metadata, error) {
+	f, err := os.Open(output)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+	var meta Metadata
+	err = json.NewDecoder(f).Decode(&meta)
+	if err != nil {
+		return nil, err
+	}
+	return &meta, nil
+}
+
+func readConfig() (*ReleaserConfig, error) {
+	f, err := os.Open(configFile)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+	var config ReleaserConfig
+	err = yaml.NewDecoder(f).Decode(&config)
+	if err != nil {
+		return nil, err
+	}
+	return &config, nil
+}
+
+// parseConfig creates metadata for plugins from the given goreleaser configuration.
+func parseConfig(cfg *ReleaserConfig) (*Metadata, error) {
+	plugins := make([]*PluginMetadata, 0)
+	for _, artifact := range cfg.Archives {
+		if !strings.HasPrefix(artifact.ID, "plugin_") {
+			continue
+		}
+		if len(artifact.Builds) != 1 {
+			return nil, fmt.Errorf("plugin artifacts must have exactly one build")
+		}
+		buildIdx := slices.IndexFunc(cfg.Builds, func(b ReleaserBuild) bool {
+			return b.ID == artifact.Builds[0]
+		})
+		if buildIdx == -1 {
+			return nil, fmt.Errorf("build not found")
+		}
+		build := cfg.Builds[buildIdx]
+		if len(build.GOOS) == 0 {
+			return nil, fmt.Errorf("build must have at least one GOOS")
+		}
+		plugin := &PluginMetadata{
+			Name:     namespace + "/" + strings.TrimPrefix(artifact.ID, "plugin_"),
+			Version:  version,
+			Archives: make([]*PluginArchiveMetadata, 0),
+		}
+		tmpl := template.Must(template.New("name").Parse(artifact.NameTemplate))
+		for _, goos := range build.GOOS {
+			archList := osArchList(goos)
+			ext := artifact.Format
+			for _, arch := range archList {
+				args := map[string]any{
+					"Os":   goos,
+					"Arch": arch,
+					"Arm":  nil,
+				}
+				var filename strings.Builder
+				err := tmpl.Execute(&filename, args)
+				if err != nil {
+					return nil, err
+				}
+				binary := &PluginArchiveMetadata{
+					Filename: filename.String() + "." + ext,
+					OS:       goos,
+					Arch:     arch,
+				}
+				plugin.Archives = append(plugin.Archives, binary)
+			}
+		}
+		plugins = append(plugins, plugin)
+	}
+	return &Metadata{Plugins: plugins}, nil
+}
+
+func osArchList(goos string) []string {
+	switch goos {
+	case "linux":
+		return []string{"amd64", "arm64", "386"}
+	case "darwin":
+		return []string{"amd64", "arm64"}
+	case "windows":
+		return []string{"amd64", "386", "arm64"}
+	default:
+		return []string{}
+	}
+}
+
+func writeMetadata(metadata *Metadata) error {
+	f, err := os.Create(output)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+	enc := json.NewEncoder(f)
+	enc.SetIndent("", " ")
+	err = enc.Encode(metadata)
+	if err != nil {
+		return err
+	}
+	return nil
+}
diff --git a/tools/pluginmeta/metadata.go b/tools/pluginmeta/metadata.go
new file mode 100644
index 00000000..03dafa13
--- /dev/null
+++ b/tools/pluginmeta/metadata.go
@@ -0,0 +1,18 @@
+package main
+
+type Metadata struct {
+	Plugins []*PluginMetadata `json:"plugins"`
+}
+
+type PluginMetadata struct {
+	Name     string                   `json:"name"`
+	Version  string                   `json:"version"`
+	Archives []*PluginArchiveMetadata `json:"archives"`
+}
+
+type PluginArchiveMetadata struct {
+	Filename       string `json:"filename"`
+	OS             string `json:"os"`
+	Arch           string `json:"arch"`
+	BinaryChecksum string `json:"binary_checksum"`
+}
diff --git a/tools/pluginmeta/releaser_config.go b/tools/pluginmeta/releaser_config.go
new file mode 100644
index 00000000..12fab0d6
--- /dev/null
+++ b/tools/pluginmeta/releaser_config.go
@@ -0,0 +1,23 @@
+package main
+
+type ReleaserConfig struct {
+	Builds   []ReleaserBuild   `yaml:"builds"`
+	Archives []ReleaserArchive `yaml:"archives"`
+}
+
+type ReleaserBuild struct {
+	ID   string   `yaml:"id"`
+	GOOS []string `yaml:"goos"`
+}
+
+type ReleaserArchive struct {
+	ID           string   `yaml:"id"`
+	Format       string   `yaml:"format"`
+	Builds       []string `yaml:"builds"`
+	NameTemplate string   `yaml:"name_template"`
+}
+
+type ReleaserFormatOverride struct {
+	GOOS   string `yaml:"goos"`
+	Format string `yaml:"format"`
+}