diff --git a/changelog/16673.txt b/changelog/16673.txt new file mode 100644 index 000000000000..e632bfe47523 --- /dev/null +++ b/changelog/16673.txt @@ -0,0 +1,3 @@ +```release-note:bug +plugin/secrets/auth: Fix a bug with aliased backends such as aws-ec2 or generic +``` diff --git a/vault/auth.go b/vault/auth.go index 14685a2b6de3..5f5762da18c8 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -911,12 +911,12 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV f, ok := c.credentialBackends[t] if !ok { - plug, err := c.pluginCatalog.Get(ctx, entry.Type, consts.PluginTypeCredential) + plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential) if err != nil { return nil, err } if plug == nil { - return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, entry.Type) + return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, t) } f = plugin.Factory diff --git a/vault/auth_test.go b/vault/auth_test.go index 5c9e4479025a..1ecf0495588f 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -220,6 +220,65 @@ func TestCore_EnableCredential(t *testing.T) { } } +// TestCore_EnableCredential_aws_ec2 tests that we can successfully mount aws +// auth using the alias "aws-ec2" +func TestCore_EnableCredential_aws_ec2(t *testing.T) { + c, keys, _ := TestCoreUnsealed(t) + c.credentialBackends["aws"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil + } + + me := &MountEntry{ + Table: credentialTableType, + Path: "foo", + Type: "aws-ec2", + } + err := c.enableCredential(namespace.RootContext(nil), me) + if err != nil { + t.Fatalf("err: %v", err) + } + + match := c.router.MatchingMount(namespace.RootContext(nil), "auth/foo/bar") + if match != "auth/foo/" { + t.Fatalf("missing mount, match: %q", match) + } + + inmemSink := metrics.NewInmemSink(1000000*time.Hour, 2000000*time.Hour) + conf := &CoreConfig{ + Physical: c.physical, + DisableMlock: true, + BuiltinRegistry: NewMockBuiltinRegistry(), + MetricSink: metricsutil.NewClusterMetricSink("test-cluster", inmemSink), + MetricsHelper: metricsutil.NewMetricsHelper(inmemSink, false), + } + c2, err := NewCore(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + defer c2.Shutdown() + c2.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { + return &NoopBackend{ + BackendType: logical.TypeCredential, + }, nil + } + for i, key := range keys { + unseal, err := TestCoreUnseal(c2, key) + if err != nil { + t.Fatalf("err: %v", err) + } + if i+1 == len(keys) && !unseal { + t.Fatalf("should be unsealed") + } + } + + // Verify matching auth tables + if !reflect.DeepEqual(c.auth, c2.auth) { + t.Fatalf("mismatch: %v %v", c.auth, c2.auth) + } +} + // Test that the local table actually gets populated as expected with local // entries, and that upon reading the entries from both are recombined // correctly diff --git a/vault/mount.go b/vault/mount.go index db2d50b5bcf5..06671ba99ae3 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -1419,12 +1419,12 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView f, ok := c.logicalBackends[t] if !ok { - plug, err := c.pluginCatalog.Get(ctx, entry.Type, consts.PluginTypeSecrets) + plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeSecrets) if err != nil { return nil, err } if plug == nil { - return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, entry.Type) + return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, t) } f = plugin.Factory diff --git a/vault/mount_test.go b/vault/mount_test.go index 1462d475d9c9..523d064068b7 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -187,6 +187,54 @@ func TestCore_Mount(t *testing.T) { } } +// TestCore_Mount_kv_generic tests that we can successfully mount kv using the +// kv alias "generic" +func TestCore_Mount_kv_generic(t *testing.T) { + c, keys, _ := TestCoreUnsealed(t) + me := &MountEntry{ + Table: mountTableType, + Path: "foo", + Type: "generic", + } + err := c.mount(namespace.RootContext(nil), me) + if err != nil { + t.Fatalf("err: %v", err) + } + + match := c.router.MatchingMount(namespace.RootContext(nil), "foo/bar") + if match != "foo/" { + t.Fatalf("missing mount") + } + + inmemSink := metrics.NewInmemSink(1000000*time.Hour, 2000000*time.Hour) + conf := &CoreConfig{ + Physical: c.physical, + DisableMlock: true, + BuiltinRegistry: NewMockBuiltinRegistry(), + MetricSink: metricsutil.NewClusterMetricSink("test-cluster", inmemSink), + MetricsHelper: metricsutil.NewMetricsHelper(inmemSink, false), + } + c2, err := NewCore(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + defer c2.Shutdown() + for i, key := range keys { + unseal, err := TestCoreUnseal(c2, key) + if err != nil { + t.Fatalf("err: %v", err) + } + if i+1 == len(keys) && !unseal { + t.Fatalf("should be unsealed") + } + } + + // Verify matching mount tables + if diff := deep.Equal(c.mounts.sortEntriesByPath(), c2.mounts.sortEntriesByPath()); len(diff) > 0 { + t.Fatalf("mismatch: %v", diff) + } +} + // Test that the local table actually gets populated as expected with local // entries, and that upon reading the entries from both are recombined // correctly diff --git a/vault/testing.go b/vault/testing.go index ee275f9a8930..7a9246090a5c 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -2166,6 +2166,7 @@ func NewMockBuiltinRegistry() *mockBuiltinRegistry { "mysql-database-plugin": consts.PluginTypeDatabase, "postgresql-database-plugin": consts.PluginTypeDatabase, "approle": consts.PluginTypeCredential, + "aws": consts.PluginTypeCredential, }, } } @@ -2188,6 +2189,15 @@ func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (fu return toFunc(approle.Factory), true } + if name == "aws" { + return toFunc(func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { + b := new(framework.Backend) + b.Setup(ctx, config) + b.BackendType = logical.TypeCredential + return b, nil + }), true + } + if name == "postgresql-database-plugin" { return toFunc(func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) { b := new(framework.Backend)