Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plugin backend reload capability #3112

Merged
merged 16 commits into from
Aug 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions logical/plugin/backend_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ func (b *backendPluginClient) HandleRequest(req *logical.Request) (*logical.Resp
}
var reply HandleRequestReply

if req.Connection != nil {
oldConnState := req.Connection.ConnState
req.Connection.ConnState = nil
defer func() {
req.Connection.ConnState = oldConnState
}()
}

err := b.client.Call("Plugin.HandleRequest", args, &reply)
if err != nil {
return nil, err
Expand Down Expand Up @@ -137,6 +145,14 @@ func (b *backendPluginClient) HandleExistenceCheck(req *logical.Request) (bool,
}
var reply HandleExistenceCheckReply

if req.Connection != nil {
oldConnState := req.Connection.ConnState
req.Connection.ConnState = nil
defer func() {
req.Connection.ConnState = oldConnState
}()
}

err := b.client.Call("Plugin.HandleExistenceCheck", args, &reply)
if err != nil {
return false, false, err
Expand Down
19 changes: 15 additions & 4 deletions logical/plugin/mock/path_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,26 @@ import (
// it is used to test the invalidate func.
func pathInternal(b *backend) *framework.Path {
return &framework.Path{
Pattern: "internal",
Fields: map[string]*framework.FieldSchema{},
ExistenceCheck: b.pathExistenceCheck,
Pattern: "internal",
Fields: map[string]*framework.FieldSchema{
"value": &framework.FieldSchema{Type: framework.TypeString},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathInternalRead,
logical.UpdateOperation: b.pathInternalUpdate,
logical.ReadOperation: b.pathInternalRead,
},
}
}

func (b *backend) pathInternalUpdate(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
value := data.Get("value").(string)
b.internal = value
// Return the secret
return nil, nil

}

func (b *backend) pathInternalRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
// Return the secret
Expand Down
62 changes: 62 additions & 0 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,27 @@ func NewSystemBackend(core *Core) *SystemBackend {
HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]),
HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]),
},
&framework.Path{
Pattern: "plugins/backend/reload$",

Fields: map[string]*framework.FieldSchema{
"plugin": &framework.FieldSchema{
Type: framework.TypeString,
Description: strings.TrimSpace(sysHelp["plugin-backend-reload-plugin"][0]),
},
"mounts": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: strings.TrimSpace(sysHelp["plugin-backend-reload-mounts"][0]),
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.handlePluginReloadUpdate,
},

HelpSynopsis: strings.TrimSpace(sysHelp["plugin-reload"][0]),
HelpDescription: strings.TrimSpace(sysHelp["plugin-reload"][1]),
},
},
}

Expand Down Expand Up @@ -975,6 +996,32 @@ func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *frame
return nil, nil
}

func (b *SystemBackend) handlePluginReloadUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
pluginName := d.Get("plugin").(string)
pluginMounts := d.Get("mounts").([]string)

if pluginName != "" && len(pluginMounts) > 0 {
return logical.ErrorResponse("plugin and mounts cannot be set at the same time"), nil
}
if pluginName == "" && len(pluginMounts) == 0 {
return logical.ErrorResponse("plugin or mounts must be provided"), nil
}

if pluginName != "" {
err := b.Core.reloadMatchingPlugin(pluginName)
if err != nil {
return nil, err
}
} else if len(pluginMounts) > 0 {
err := b.Core.reloadMatchingPluginMounts(pluginMounts)
if err != nil {
return nil, err
}
}

return nil, nil
}

// handleAuditedHeaderUpdate creates or overwrites a header entry
func (b *SystemBackend) handleAuditedHeaderUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
header := d.Get("header").(string)
Expand Down Expand Up @@ -2855,4 +2902,19 @@ This path responds to the following HTTP methods.
`The path to list leases under. Example: "aws/creds/deploy"`,
"",
},
"plugin-reload": {
"Reload mounts that use a particular backend plugin.",
`Reload mounts that use a particular backend plugin. Either the plugin name
or the desired plugin backend mounts must be provided, but not both. In the
case that the plugin name is provided, all mounted paths that use that plugin
backend will be reloaded.`,
},
"plugin-backend-reload-plugin": {
`The name of the plugin to reload, as registered in the plugin catalog.`,
"",
},
"plugin-backend-reload-mounts": {
`The mount paths of the plugin backends to reload.`,
"",
},
}
156 changes: 150 additions & 6 deletions vault/logical_system_integ_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vault_test

import (
"fmt"
"os"
"testing"
"time"
Expand Down Expand Up @@ -28,11 +29,10 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) {
})
cluster.Start()
defer cluster.Cleanup()
cores := cluster.Cores
core := cluster.Cores[0].Core
vault.TestWaitActive(t, core)

core := cores[0]

b := vault.NewSystemBackend(core.Core)
b := vault.NewSystemBackend(core)
logger := logformat.NewVaultLogger(log.LevelTrace)
bc := &logical.BackendConfig{
Logger: logger,
Expand All @@ -49,7 +49,7 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) {

os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)

vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMain")
vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainCredentials")

req := logical.TestRequest(t, logical.UpdateOperation, "auth/mock-plugin")
req.Data["type"] = "plugin"
Expand All @@ -64,7 +64,151 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) {
}
}

func TestBackend_PluginMain(t *testing.T) {
func TestSystemBackend_PluginReload(t *testing.T) {
data := map[string]interface{}{
"plugin": "mock-plugin",
}
t.Run("plugin", func(t *testing.T) { testSystemBackend_PluginReload(t, data) })

data = map[string]interface{}{
"mounts": "mock-0/,mock-1/",
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize this is making sure there's no error, but is there a way to verify that a reload actually happened other than no error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure how this can be done, since the backend has no access/knowledge of the underlying plugin process.

t.Run("mounts", func(t *testing.T) { testSystemBackend_PluginReload(t, data) })
}

func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) {
cluster, b := testSystemBackendMock(t, 2)
defer cluster.Cleanup()

core := cluster.Cores[0]

for i := 0; i < 2; i++ {
// Update internal value in the backend
req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mock-%d/internal", i))
req.ClientToken = core.Client.Token()
req.Data["value"] = "baz"
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
}
}

// Perform plugin reload
req := logical.TestRequest(t, logical.UpdateOperation, "plugins/backend/reload")
req.ClientToken = core.Client.Token()
req.Data = reqData
resp, err := b.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
}

for i := 0; i < 2; i++ {
// Ensure internal backed value is reset
req := logical.TestRequest(t, logical.ReadOperation, "mock-1/internal")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
if resp.Data["value"].(string) == "baz" {
t.Fatal("did not expect backend internal value to be 'baz'")
}
}
}

// testSystemBackendMock returns a systemBackend with the desired number
// of mounted mock plugin backends
func testSystemBackendMock(t *testing.T, numMounts int) (*vault.TestCluster, *vault.SystemBackend) {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"plugin": plugin.Factory,
},
}

cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()

core := cluster.Cores[0].Core
vault.TestWaitActive(t, core)

b := vault.NewSystemBackend(core)
logger := logformat.NewVaultLogger(log.LevelTrace)
bc := &logical.BackendConfig{
Logger: logger,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 32,
},
}

err := b.Backend.Setup(bc)
if err != nil {
t.Fatal(err)
}

os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)

vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainLogical")

for i := 0; i < numMounts; i++ {
req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mounts/mock-%d/", i))
req.Data["type"] = "plugin"
req.Data["config"] = map[string]interface{}{
"plugin_name": "mock-plugin",
}

resp, err := b.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
}
}

return cluster, b
}

func TestBackend_PluginMainLogical(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}

caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
if caPEM == "" {
t.Fatal("CA cert not passed in")
}

factoryFunc := mock.FactoryType(logical.TypeLogical)

args := []string{"--ca-cert=" + caPEM}

apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
err := lplugin.Serve(&lplugin.ServeOpts{
BackendFactoryFunc: factoryFunc,
TLSProviderFunc: tlsProviderFunc,
})
if err != nil {
t.Fatal(err)
}
}

func TestBackend_PluginMainCredentials(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
Expand Down
1 change: 0 additions & 1 deletion vault/mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ func (c *Core) mount(entry *MountEntry) error {
conf["plugin_name"] = entry.Config.PluginName
}

// Consider having plugin name under entry.Options
backend, err := c.newLogicalBackend(entry.Type, sysView, view, conf)
if err != nil {
return err
Expand Down
Loading