From f975ea65c4249e2f37c1a7e23b00853c530035c9 Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Thu, 11 Aug 2022 01:02:05 +0000 Subject: [PATCH] backport of commit fd6f90404a13b9afe6111363c2d0d913a3c5c386 --- changelog/16673.txt | 3 +++ vault/auth.go | 4 +-- vault/auth_test.go | 59 +++++++++++++++++++++++++++++++++++++++++++++ vault/mount.go | 4 +-- vault/mount_test.go | 48 ++++++++++++++++++++++++++++++++++++ vault/testing.go | 10 ++++++++ 6 files changed, 124 insertions(+), 4 deletions(-) create mode 100644 changelog/16673.txt diff --git a/changelog/16673.txt b/changelog/16673.txt new file mode 100644 index 000000000000..e632bfe47523 --- /dev/null +++ b/changelog/16673.txt @@ -0,0 +1,3 @@ +```release-note:bug +plugin/secrets/auth: Fix a bug with aliased backends such as aws-ec2 or generic +``` diff --git a/vault/auth.go b/vault/auth.go index 14685a2b6de3..5f5762da18c8 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -911,12 +911,12 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV f, ok := c.credentialBackends[t] if !ok { - plug, err := c.pluginCatalog.Get(ctx, entry.Type, consts.PluginTypeCredential) + plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential) if err != nil { return nil, err } if plug == nil { - return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, entry.Type) + return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, t) } f = plugin.Factory diff --git a/vault/auth_test.go b/vault/auth_test.go index 5c9e4479025a..1ecf0495588f 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -220,6 +220,65 @@ func TestCore_EnableCredential(t *testing.T) { } } +// TestCore_EnableCredential_aws_ec2 tests that we can successfully mount aws +// auth using the alias "aws-ec2" +func TestCore_EnableCredential_aws_ec2(t *testing.T) { + c, keys, _ := TestCoreUnsealed(t) + c.credentialBackends["aws"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil + } + + me := &MountEntry{ + Table: credentialTableType, + Path: "foo", + Type: "aws-ec2", + } + err := c.enableCredential(namespace.RootContext(nil), me) + if err != nil { + t.Fatalf("err: %v", err) + } + + match := c.router.MatchingMount(namespace.RootContext(nil), "auth/foo/bar") + if match != "auth/foo/" { + t.Fatalf("missing mount, match: %q", match) + } + + inmemSink := metrics.NewInmemSink(1000000*time.Hour, 2000000*time.Hour) + conf := &CoreConfig{ + Physical: c.physical, + DisableMlock: true, + BuiltinRegistry: NewMockBuiltinRegistry(), + MetricSink: metricsutil.NewClusterMetricSink("test-cluster", inmemSink), + MetricsHelper: metricsutil.NewMetricsHelper(inmemSink, false), + } + c2, err := NewCore(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + defer c2.Shutdown() + c2.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil + } + for i, key := range keys { + unseal, err := TestCoreUnseal(c2, key) + if err != nil { + t.Fatalf("err: %v", err) + } + if i+1 == len(keys) && !unseal { + t.Fatalf("should be unsealed") + } + } + + // Verify matching auth tables + if !reflect.DeepEqual(c.auth, c2.auth) { + t.Fatalf("mismatch: %v %v", c.auth, c2.auth) + } +} + // Test that the local table actually gets populated as expected with local // entries, and that upon reading the entries from both are recombined // correctly diff --git a/vault/mount.go b/vault/mount.go index ea51a707c687..eb535a8ceaa1 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -1410,12 +1410,12 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView f, ok := c.logicalBackends[t] if !ok { - plug, err := c.pluginCatalog.Get(ctx, entry.Type, consts.PluginTypeSecrets) + plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeSecrets) if err != nil { return nil, err } if plug == nil { - return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, entry.Type) + return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, t) } f = plugin.Factory diff --git a/vault/mount_test.go b/vault/mount_test.go index 1462d475d9c9..523d064068b7 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -187,6 +187,54 @@ func TestCore_Mount(t *testing.T) { } } +// TestCore_Mount_kv_generic tests that we can successfully mount kv using the +// kv alias "generic" +func TestCore_Mount_kv_generic(t *testing.T) { + c, keys, _ := TestCoreUnsealed(t) + me := &MountEntry{ + Table: mountTableType, + Path: "foo", + Type: "generic", + } + err := c.mount(namespace.RootContext(nil), me) + if err != nil { + t.Fatalf("err: %v", err) + } + + match := c.router.MatchingMount(namespace.RootContext(nil), "foo/bar") + if match != "foo/" { + t.Fatalf("missing mount") + } + + inmemSink := metrics.NewInmemSink(1000000*time.Hour, 2000000*time.Hour) + conf := &CoreConfig{ + Physical: c.physical, + DisableMlock: true, + BuiltinRegistry: NewMockBuiltinRegistry(), + MetricSink: metricsutil.NewClusterMetricSink("test-cluster", inmemSink), + MetricsHelper: metricsutil.NewMetricsHelper(inmemSink, false), + } + c2, err := NewCore(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + defer c2.Shutdown() + for i, key := range keys { + unseal, err := TestCoreUnseal(c2, key) + if err != nil { + t.Fatalf("err: %v", err) + } + if i+1 == len(keys) && !unseal { + t.Fatalf("should be unsealed") + } + } + + // Verify matching mount tables + if diff := deep.Equal(c.mounts.sortEntriesByPath(), c2.mounts.sortEntriesByPath()); len(diff) > 0 { + t.Fatalf("mismatch: %v", diff) + } +} + // Test that the local table actually gets populated as expected with local // entries, and that upon reading the entries from both are recombined // correctly diff --git a/vault/testing.go b/vault/testing.go index 97a9c76d4342..6d8780d4ec92 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -2167,6 +2167,7 @@ func NewMockBuiltinRegistry() *mockBuiltinRegistry { "mysql-database-plugin": consts.PluginTypeDatabase, "postgresql-database-plugin": consts.PluginTypeDatabase, "approle": consts.PluginTypeCredential, + "aws": consts.PluginTypeCredential, }, } } @@ -2189,6 +2190,15 @@ func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (fu return toFunc(approle.Factory), true } + if name == "aws" { + return toFunc(func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { + b := new(framework.Backend) + b.Setup(ctx, config) + b.BackendType = logical.TypeCredential + return b, nil + }), true + } + if name == "postgresql-database-plugin" { return dbPostgres.New, true }