From b1037baf76f25ab007af1bdf160baa100f3842cd Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Thu, 3 Feb 2022 17:37:15 -0600 Subject: [PATCH 01/22] feat: DB plugin multiplexing (#13734) * WIP: start from main and get a plugin runner from core * move MultiplexedClient map to plugin catalog - call sys.NewPluginClient from PluginFactory - updates to getPluginClient - thread through isMetadataMode * use go-plugin ClientProtocol interface - call sys.NewPluginClient from dbplugin.NewPluginClient * move PluginSets to dbplugin package - export dbplugin HandshakeConfig - small refactor of PluginCatalog.getPluginClient * add removeMultiplexedClient; clean up on Close() - call client.Kill from plugin catalog - set rpcClient when muxed client exists * add ID to dbplugin.DatabasePluginClient struct * only create one plugin process per plugin type * update NewPluginClient to return connection ID to sdk - wrap grpc.ClientConn so we can inject the ID into context - get ID from context on grpc server * add v6 multiplexing protocol version * WIP: backwards compat for db plugins * Ensure locking on plugin catalog access - Create public GetPluginClient method for plugin catalog - rename postgres db plugin * use the New constructor for db plugins * grpc server: use write lock for Close and rlock for CRUD * cleanup MultiplexedClients on Close * remove TODO * fix multiplexing regression with grpc server connection * cleanup grpc server instances on close * embed ClientProtocol in Multiplexer interface * use PluginClientConfig arg to make NewPluginClient plugin type agnostic * create a new plugin process for non-muxed plugins --- builtin/logical/database/backend.go | 1 + .../database/path_config_connection.go | 4 + helper/builtinplugins/registry.go | 4 +- .../postgresql-database-plugin/main.go | 7 +- plugins/database/postgresql/postgresql.go | 1 - .../dbplugin/v5/grpc_database_plugin.go | 39 ++- sdk/database/dbplugin/v5/grpc_server.go | 127 ++++++++- sdk/database/dbplugin/v5/plugin_client.go | 65 ++--- sdk/database/dbplugin/v5/plugin_factory.go | 11 +- sdk/database/dbplugin/v5/plugin_server.go | 31 ++- sdk/helper/pluginutil/run_config.go | 44 +-- sdk/helper/pluginutil/runner.go | 31 ++- sdk/logical/system_view.go | 6 + sdk/plugin/grpc_system.go | 4 + vault/dynamic_system_view.go | 9 + vault/plugin_catalog.go | 262 +++++++++++++++--- 16 files changed, 528 insertions(+), 118 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 8572b0947eea..a88f2b6b068e 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -110,6 +110,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } type databaseBackend struct { + // connections holds configured database connections by config name connections map[string]*dbPluginInstance logger log.Logger diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index bbe5769ea350..96a60d5a16f3 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -317,6 +317,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { if err != nil { return logical.ErrorResponse("error creating database object: %s", err), nil } + b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName) initReq := v5.InitializeRequest{ Config: config.ConnectionDetails, @@ -365,6 +366,9 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { "Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName)) } + if len(resp.Warnings) == 0 { + return nil, nil + } return resp, nil } } diff --git a/helper/builtinplugins/registry.go b/helper/builtinplugins/registry.go index 325658654b50..b0d96530bd11 100644 --- a/helper/builtinplugins/registry.go +++ b/helper/builtinplugins/registry.go @@ -150,8 +150,8 @@ type registry struct { logicalBackends map[string]logical.Factory } -// Get returns the BuiltinFactory func for a particular backend plugin -// from the plugins map. +// Get returns the Factory func for a particular backend plugin from the +// plugins map. func (r *registry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) { switch pluginType { case consts.PluginTypeCredential: diff --git a/plugins/database/postgresql/postgresql-database-plugin/main.go b/plugins/database/postgresql/postgresql-database-plugin/main.go index 3d2e14cd9aab..de168012ceb2 100644 --- a/plugins/database/postgresql/postgresql-database-plugin/main.go +++ b/plugins/database/postgresql/postgresql-database-plugin/main.go @@ -18,12 +18,7 @@ func main() { // Run instantiates a PostgreSQL object, and runs the RPC server for the plugin func Run() error { - dbType, err := postgresql.New() - if err != nil { - return err - } - - dbplugin.Serve(dbType.(dbplugin.Database)) + dbplugin.ServeMultiplex(postgresql.New) return nil } diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index dba0bf74595d..a3826d3a83f0 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -48,7 +48,6 @@ var ( singleQuotedPhrases = regexp.MustCompile(`('.*?')`) ) -// New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { db := new() // Wrap the plugin with middleware to sanitize errors diff --git a/sdk/database/dbplugin/v5/grpc_database_plugin.go b/sdk/database/dbplugin/v5/grpc_database_plugin.go index 96d296ad799c..43840430c9fc 100644 --- a/sdk/database/dbplugin/v5/grpc_database_plugin.go +++ b/sdk/database/dbplugin/v5/grpc_database_plugin.go @@ -6,20 +6,26 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) // handshakeConfigs are used to just do a basic handshake between // a plugin and host. If the handshake fails, a user friendly error is shown. // This prevents users from executing bad plugins or executing a plugin // directory. It is a UX feature, not a security feature. -var handshakeConfig = plugin.HandshakeConfig{ - ProtocolVersion: 5, +var HandshakeConfig = plugin.HandshakeConfig{ MagicCookieKey: "VAULT_DATABASE_PLUGIN", MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", } +const multiplexingCtxKey string = "multiplex_id" + +// Factory is the factory function to create a dbplugin Database. +type Factory func() (interface{}, error) + type GRPCDatabasePlugin struct { - Impl Database + FactoryFunc Factory + Impl Database // Embeding this will disable the netRPC protocol plugin.NetRPCUnsupportedPlugin @@ -31,7 +37,9 @@ var ( ) func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error { - proto.RegisterDatabaseServer(s, gRPCServer{impl: d.Impl}) + server := gRPCServer{factoryFunc: d.FactoryFunc, instances: make(map[string]Database)} + + proto.RegisterDatabaseServer(s, server) return nil } @@ -42,3 +50,26 @@ func (GRPCDatabasePlugin) GRPCClient(doneCtx context.Context, _ *plugin.GRPCBrok } return client, nil } + +type databaseClientConn struct { + *grpc.ClientConn + id string +} + +var _ grpc.ClientConnInterface = &databaseClientConn{} + +func (d *databaseClientConn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { + // Inject ID to the context + md := metadata.Pairs(multiplexingCtxKey, d.id) + idCtx := metadata.NewOutgoingContext(ctx, md) + + return d.ClientConn.Invoke(idCtx, method, args, reply, opts...) +} + +func (d *databaseClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + // Inject ID to the context + md := metadata.Pairs(multiplexingCtxKey, d.id) + idCtx := metadata.NewOutgoingContext(ctx, md) + + return d.ClientConn.NewStream(idCtx, desc, method, opts...) +} diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index c75fa8ef0ed0..a5e4d0b72c64 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -3,11 +3,13 @@ package dbplugin import ( "context" "fmt" + "sync" "time" "github.com/golang/protobuf/ptypes" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -16,11 +18,87 @@ var _ proto.DatabaseServer = gRPCServer{} type gRPCServer struct { proto.UnimplementedDatabaseServer - impl Database + factoryFunc func() (interface{}, error) + instances map[string]Database + sync.RWMutex +} + +func getMultiplexIDFromContext(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", fmt.Errorf("missing plugin multiplexing metadata") + } + + multiplexIDs := md[multiplexingCtxKey] + if len(multiplexIDs) != 1 { + return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs)) + } + + multiplexID := multiplexIDs[0] + if multiplexID == "" { + return "", fmt.Errorf("empty multiplex ID in metadata") + } + + return multiplexID, nil +} + +func (g gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { + g.Lock() + defer g.Unlock() + + id, err := getMultiplexIDFromContext(ctx) + if err != nil { + return nil, err + } + + if db, ok := g.instances[id]; ok { + return db, nil + } + + db, err := g.factoryFunc() + if err != nil { + return nil, err + } + + database := db.(Database) + g.instances[id] = database + + return database, nil +} + +// getDatabaseInternal returns the database but does not hold a lock +func (g gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) { + id, err := getMultiplexIDFromContext(ctx) + if err != nil { + return nil, err + } + + if id == "" { + return nil, fmt.Errorf("no instance ID found for multiplexed plugin") + } + + if db, ok := g.instances[id]; ok { + return db, nil + } + + return nil, fmt.Errorf("no database instance found") +} + +// getDatabase holds a read lock and returns the database +func (g gRPCServer) getDatabase(ctx context.Context) (Database, error) { + g.RLock() + impl, err := g.getDatabaseInternal(ctx) + g.RUnlock() + return impl, err } // Initialize the database plugin func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) { + impl, err := g.getOrCreateDatabase(ctx) + if err != nil { + return nil, err + } + rawConfig := structToMap(request.ConfigData) dbReq := InitializeRequest{ @@ -28,7 +106,7 @@ func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeReq VerifyConnection: request.VerifyConnection, } - dbResp, err := g.impl.Initialize(ctx, dbReq) + dbResp, err := impl.Initialize(ctx, dbReq) if err != nil { return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to initialize: %s", err) } @@ -60,6 +138,11 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr expiration = exp } + impl, err := g.getDatabase(ctx) + if err != nil { + return nil, err + } + dbReq := NewUserRequest{ UsernameConfig: UsernameMetadata{ DisplayName: req.GetUsernameConfig().GetDisplayName(), @@ -71,7 +154,7 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr RollbackStatements: getStatementsFromProto(req.GetRollbackStatements()), } - dbResp, err := g.impl.NewUser(ctx, dbReq) + dbResp, err := impl.NewUser(ctx, dbReq) if err != nil { return &proto.NewUserResponse{}, status.Errorf(codes.Internal, "unable to create new user: %s", err) } @@ -92,7 +175,12 @@ func (g gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, err.Error()) } - _, err = g.impl.UpdateUser(ctx, dbReq) + impl, err := g.getDatabase(ctx) + if err != nil { + return nil, err + } + + _, err = impl.UpdateUser(ctx, dbReq) if err != nil { return &proto.UpdateUserResponse{}, status.Errorf(codes.Internal, "unable to update user: %s", err) } @@ -153,7 +241,12 @@ func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest Statements: getStatementsFromProto(req.GetStatements()), } - _, err := g.impl.DeleteUser(ctx, dbReq) + impl, err := g.getDatabase(ctx) + if err != nil { + return nil, err + } + + _, err = impl.DeleteUser(ctx, dbReq) if err != nil { return &proto.DeleteUserResponse{}, status.Errorf(codes.Internal, "unable to delete user: %s", err) } @@ -161,7 +254,12 @@ func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest } func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) { - t, err := g.impl.Type() + impl, err := g.getOrCreateDatabase(ctx) + if err != nil { + return nil, err + } + + t, err := impl.Type() if err != nil { return &proto.TypeResponse{}, status.Errorf(codes.Internal, "unable to retrieve type: %s", err) } @@ -173,10 +271,25 @@ func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeRespon } func (g gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) { - err := g.impl.Close() + g.Lock() + defer g.Unlock() + + impl, err := g.getDatabaseInternal(ctx) + if err != nil { + return nil, err + } + + err = impl.Close() if err != nil { return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err) } + + id, err := getMultiplexIDFromContext(ctx) + if err != nil { + return nil, err + } + delete(g.instances, id) + return &proto.Empty{}, nil } diff --git a/sdk/database/dbplugin/v5/plugin_client.go b/sdk/database/dbplugin/v5/plugin_client.go index d2e096110472..ac3af9a16dc5 100644 --- a/sdk/database/dbplugin/v5/plugin_client.go +++ b/sdk/database/dbplugin/v5/plugin_client.go @@ -3,19 +3,14 @@ package dbplugin import ( "context" "errors" - "sync" - log "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" "github.com/hashicorp/vault/sdk/helper/pluginutil" ) -// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's Close -// method to also call Kill() on the plugin.Client. type DatabasePluginClient struct { - client *plugin.Client - sync.Mutex - + client pluginutil.Multiplexer Database } @@ -23,42 +18,31 @@ type DatabasePluginClient struct { // and kill the plugin. func (dc *DatabasePluginClient) Close() error { err := dc.Database.Close() - dc.client.Kill() + dc.client.Close() return err } -// NewPluginClient returns a databaseRPCClient with a connection to a running -// plugin. The client is wrapped in a DatabasePluginClient object to ensure the -// plugin is killed on call of Close(). -func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (Database, error) { - // pluginSets is the map of plugins we can dispense. - pluginSets := map[int]plugin.PluginSet{ - 5: { - "database": new(GRPCDatabasePlugin), - }, - } - - client, err := pluginRunner.RunConfig(ctx, - pluginutil.Runner(sys), - pluginutil.PluginSets(pluginSets), - pluginutil.HandshakeConfig(handshakeConfig), - pluginutil.Logger(logger), - pluginutil.MetadataMode(isMetadataMode), - pluginutil.AutoMTLS(true), - ) - if err != nil { - return nil, err - } +// pluginSets is the map of plugins we can dispense. +var PluginSets = map[int]plugin.PluginSet{ + 5: { + "database": &GRPCDatabasePlugin{}, + }, + 6: { + "database": &GRPCDatabasePlugin{}, + }, +} - // Connect via RPC - rpcClient, err := client.Client() +// NewPluginClient returns a databaseRPCClient with a connection to a running +// plugin. +func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (Database, error) { + pluginClient, err := sys.NewPluginClient(ctx, pluginRunner, config) if err != nil { return nil, err } // Request the plugin - raw, err := rpcClient.Dispense("database") + raw, err := pluginClient.Dispense("database") if err != nil { return nil, err } @@ -68,14 +52,23 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne var db Database switch raw.(type) { case gRPCClient: - db = raw.(gRPCClient) + + gRPCClient := raw.(gRPCClient) + + // Wrap clientConn with our implementation so that we can inject the + // ID into the context + cc := &databaseClientConn{ + ClientConn: pluginClient.Conn(), + id: pluginClient.ID(), + } + gRPCClient.client = proto.NewDatabaseClient(cc) + db = gRPCClient default: return nil, errors.New("unsupported client type") } - // Wrap RPC implementation in DatabasePluginClient return &DatabasePluginClient{ - client: client, + client: pluginClient, Database: db, }, nil } diff --git a/sdk/database/dbplugin/v5/plugin_factory.go b/sdk/database/dbplugin/v5/plugin_factory.go index f203f1ed4cc7..c8b777870049 100644 --- a/sdk/database/dbplugin/v5/plugin_factory.go +++ b/sdk/database/dbplugin/v5/plugin_factory.go @@ -40,8 +40,16 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu transport = "builtin" } else { + config := pluginutil.PluginClientConfig{ + Name: pluginName, + PluginSets: PluginSets, + HandshakeConfig: HandshakeConfig, + Logger: namedLogger, + IsMetadataMode: false, + AutoMTLS: true, + } // create a DatabasePluginClient instance - db, err = NewPluginClient(ctx, sys, pluginRunner, namedLogger, false) + db, err = NewPluginClient(ctx, sys, pluginRunner, config) if err != nil { return nil, err } @@ -59,6 +67,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu if err != nil { return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err) } + logger.Debug("got database plugin instance", "type", typeStr) // Wrap with metrics middleware db = &databaseMetricsMiddleware{ diff --git a/sdk/database/dbplugin/v5/plugin_server.go b/sdk/database/dbplugin/v5/plugin_server.go index 11d04e6450a6..c1fc495ba757 100644 --- a/sdk/database/dbplugin/v5/plugin_server.go +++ b/sdk/database/dbplugin/v5/plugin_server.go @@ -31,7 +31,36 @@ func ServeConfig(db Database) *plugin.ServeConfig { } conf := &plugin.ServeConfig{ - HandshakeConfig: handshakeConfig, + HandshakeConfig: HandshakeConfig, + VersionedPlugins: pluginSets, + GRPCServer: plugin.DefaultGRPCServer, + } + + return conf +} + +func ServeMultiplex(factory Factory) { + plugin.Serve(ServeConfigMultiplex(factory)) +} + +func ServeConfigMultiplex(factory Factory) *plugin.ServeConfig { + err := pluginutil.OptionallyEnableMlock() + if err != nil { + fmt.Println(err) + return nil + } + + // pluginSets is the map of plugins we can dispense. + pluginSets := map[int]plugin.PluginSet{ + 6: { + "database": &GRPCDatabasePlugin{ + FactoryFunc: factory, + }, + }, + } + + conf := &plugin.ServeConfig{ + HandshakeConfig: HandshakeConfig, VersionedPlugins: pluginSets, GRPCServer: plugin.DefaultGRPCServer, } diff --git a/sdk/helper/pluginutil/run_config.go b/sdk/helper/pluginutil/run_config.go index f801287d7d4d..4c654c4991ea 100644 --- a/sdk/helper/pluginutil/run_config.go +++ b/sdk/helper/pluginutil/run_config.go @@ -12,6 +12,15 @@ import ( "github.com/hashicorp/vault/sdk/version" ) +type PluginClientConfig struct { + Name string + PluginSets map[int]plugin.PluginSet + HandshakeConfig plugin.HandshakeConfig + Logger log.Logger + IsMetadataMode bool + AutoMTLS bool +} + type runConfig struct { // Provided by PluginRunner command string @@ -21,12 +30,9 @@ type runConfig struct { // Initialized with what's in PluginRunner.Env, but can be added to env []string - wrapper RunnerUtil - pluginSets map[int]plugin.PluginSet - hs plugin.HandshakeConfig - logger log.Logger - isMetadataMode bool - autoMTLS bool + wrapper RunnerUtil + + PluginClientConfig } func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error) { @@ -39,14 +45,14 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error } cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version)) - if rc.isMetadataMode { - rc.logger = rc.logger.With("metadata", "true") + if rc.IsMetadataMode { + rc.Logger = rc.Logger.With("metadata", "true") } - metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.isMetadataMode) + metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.IsMetadataMode) cmd.Env = append(cmd.Env, metadataEnv) var clientTLSConfig *tls.Config - if !rc.autoMTLS && !rc.isMetadataMode { + if !rc.AutoMTLS && !rc.IsMetadataMode { // Get a CA TLS Certificate certBytes, key, err := generateCert() if err != nil { @@ -76,17 +82,17 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error } clientConfig := &plugin.ClientConfig{ - HandshakeConfig: rc.hs, - VersionedPlugins: rc.pluginSets, + HandshakeConfig: rc.HandshakeConfig, + VersionedPlugins: rc.PluginSets, Cmd: cmd, SecureConfig: secureConfig, TLSConfig: clientTLSConfig, - Logger: rc.logger, + Logger: rc.Logger, AllowedProtocols: []plugin.Protocol{ plugin.ProtocolNetRPC, plugin.ProtocolGRPC, }, - AutoMTLS: rc.autoMTLS, + AutoMTLS: rc.AutoMTLS, } return clientConfig, nil } @@ -117,31 +123,31 @@ func Runner(wrapper RunnerUtil) RunOpt { func PluginSets(pluginSets map[int]plugin.PluginSet) RunOpt { return func(rc *runConfig) { - rc.pluginSets = pluginSets + rc.PluginSets = pluginSets } } func HandshakeConfig(hs plugin.HandshakeConfig) RunOpt { return func(rc *runConfig) { - rc.hs = hs + rc.HandshakeConfig = hs } } func Logger(logger log.Logger) RunOpt { return func(rc *runConfig) { - rc.logger = logger + rc.Logger = logger } } func MetadataMode(isMetadataMode bool) RunOpt { return func(rc *runConfig) { - rc.isMetadataMode = isMetadataMode + rc.IsMetadataMode = isMetadataMode } } func AutoMTLS(autoMTLS bool) RunOpt { return func(rc *runConfig) { - rc.autoMTLS = autoMTLS + rc.AutoMTLS = autoMTLS } } diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index ecd60eeb3459..8c9be9cb6928 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -8,6 +8,7 @@ import ( plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/wrapping" + "google.golang.org/grpc" ) // Looker defines the plugin Lookup function that looks into the plugin catalog @@ -21,6 +22,7 @@ type Looker interface { // configuration and wrapping data in a response wrapped token. // logical.SystemView implementations satisfy this interface. type RunnerUtil interface { + NewPluginClient(ctx context.Context, pluginRunner *PluginRunner, config PluginClientConfig) (Multiplexer, error) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) MlockEnabled() bool } @@ -31,17 +33,26 @@ type LookRunnerUtil interface { RunnerUtil } +type Multiplexer interface { + ID() string + Conn() *grpc.ClientConn + MultiplexingSupport() bool + + plugin.ClientProtocol +} + // PluginRunner defines the metadata needed to run a plugin securely with // go-plugin. type PluginRunner struct { - Name string `json:"name" structs:"name"` - Type consts.PluginType `json:"type" structs:"type"` - Command string `json:"command" structs:"command"` - Args []string `json:"args" structs:"args"` - Env []string `json:"env" structs:"env"` - Sha256 []byte `json:"sha256" structs:"sha256"` - Builtin bool `json:"builtin" structs:"builtin"` - BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"` + Name string `json:"name" structs:"name"` + Type consts.PluginType `json:"type" structs:"type"` + Command string `json:"command" structs:"command"` + Args []string `json:"args" structs:"args"` + Env []string `json:"env" structs:"env"` + Sha256 []byte `json:"sha256" structs:"sha256"` + Builtin bool `json:"builtin" structs:"builtin"` + BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"` + MultiplexingSupport bool `json:"multiplexing_support" structs:"multiplexing_support"` } // Run takes a wrapper RunnerUtil instance along with the go-plugin parameters and @@ -86,3 +97,7 @@ func CtxCancelIfCanceled(f context.CancelFunc, ctxCanceler context.Context) chan }() return quitCh } + +func MultiplexingSupport(version int) bool { + return version == 6 +} diff --git a/sdk/logical/system_view.go b/sdk/logical/system_view.go index 8ea6766b9941..e029b7764e94 100644 --- a/sdk/logical/system_view.go +++ b/sdk/logical/system_view.go @@ -56,6 +56,8 @@ type SystemView interface { // name. Returns a PluginRunner or an error if a plugin can not be found. LookupPlugin(context.Context, string, consts.PluginType) (*pluginutil.PluginRunner, error) + NewPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (pluginutil.Multiplexer, error) + // MlockEnabled returns the configuration setting for enabling mlock on // plugins. MlockEnabled() bool @@ -152,6 +154,10 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState { return d.ReplicationStateVal } +func (d StaticSystemView) NewPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (pluginutil.Multiplexer, error) { + return nil, errors.New("NewPluginClient is not implemented in StaticSystemView") +} + func (d StaticSystemView) ResponseWrapData(_ context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { return nil, errors.New("ResponseWrapData is not implemented in StaticSystemView") } diff --git a/sdk/plugin/grpc_system.go b/sdk/plugin/grpc_system.go index ca7db03176fd..a6eae146ad3f 100644 --- a/sdk/plugin/grpc_system.go +++ b/sdk/plugin/grpc_system.go @@ -99,6 +99,10 @@ func (s *gRPCSystemViewClient) ResponseWrapData(ctx context.Context, data map[st return info, nil } +func (s *gRPCSystemViewClient) NewPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (pluginutil.Multiplexer, error) { + return nil, fmt.Errorf("cannot call NewPluginClient from a plugin backend") +} + func (s *gRPCSystemViewClient) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) { return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend") } diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index ab63ae61bc2e..4c82cac2e19a 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -215,6 +215,15 @@ func (d dynamicSystemView) ResponseWrapData(ctx context.Context, data map[string return resp.WrapInfo, nil } +func (d dynamicSystemView) NewPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (pluginutil.Multiplexer, error) { + c, err := d.core.pluginCatalog.GetPluginClient(ctx, d, pluginRunner, config) + if err != nil { + return nil, err + } + + return c, nil +} + // LookupPlugin looks for a plugin with the given name in the plugin catalog. It // returns a PluginRunner or an error if no plugin was found. func (d dynamicSystemView) LookupPlugin(ctx context.Context, name string, pluginType consts.PluginType) (*pluginutil.PluginRunner, error) { diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 71d4603e6e19..3af050689fa9 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -12,6 +12,8 @@ import ( log "github.com/hashicorp/go-hclog" multierror "github.com/hashicorp/go-multierror" + plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/go-secure-stdlib/base62" v4 "github.com/hashicorp/vault/sdk/database/dbplugin" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/helper/consts" @@ -19,6 +21,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" backendplugin "github.com/hashicorp/vault/sdk/plugin" + "google.golang.org/grpc" ) var ( @@ -35,15 +38,103 @@ type PluginCatalog struct { builtinRegistry BuiltinRegistry catalogView *BarrierView directory string + logger log.Logger + + // multiplexedClients holds plugin process connections by plugin name + // This allows a single grpc connection to communicate with multiple + // databases. Each database configuration using the same plugin will be + // routed to the existing plugin process. + multiplexedClients map[string]*MultiplexedClient lock sync.RWMutex } +type PluginClient struct { + logger log.Logger + id string + protocol plugin.ClientProtocol + // client handles the lifecycle of a plugin process + client *plugin.Client +} + +type MultiplexedClient struct { + logger log.Logger + + // name is the plugin name + name string + + client *plugin.Client + + connections map[string]*PluginClient + multiplexingSupport bool +} + +func (p *PluginClient) Conn() *grpc.ClientConn { + gc := p.protocol.(*plugin.GRPCClient) + return gc.Conn +} + +func (p *PluginClient) ID() string { + return p.id +} + +func (p *MultiplexedClient) ByID(id string) *PluginClient { + return p.connections[id] +} + +func (p *PluginClient) MultiplexingSupport() bool { + if p.client == nil { + return false + } + return pluginutil.MultiplexingSupport(p.client.NegotiatedVersion()) +} + +func (c *PluginCatalog) removeMultiplexedClient(name, id string) { + mpc, ok := c.multiplexedClients[name] + if !ok { + return + } + + mpc.connections[id].Close() + + delete(mpc.connections, id) + if len(mpc.connections) == 0 { + delete(c.multiplexedClients, name) + } +} + +func (p *PluginClient) Close() error { + p.logger.Debug("attempting to kill plugin process", "id", p.id) + p.client.Kill() + err := p.protocol.Close() + p.client = nil + p.protocol = nil + p.logger.Debug("killed plugin process", "id", p.id) + return err +} + +func (p *PluginClient) Dispense(name string) (interface{}, error) { + pluginInstance, err := p.protocol.Dispense(name) + if err != nil { + return nil, err + } + return pluginInstance, nil +} + +func (p *PluginClient) Ping() error { + err := p.protocol.Ping() + if err != nil { + return err + } + return nil +} + func (c *Core) setupPluginCatalog(ctx context.Context) error { c.pluginCatalog = &PluginCatalog{ builtinRegistry: c.builtinRegistry, catalogView: NewBarrierView(c.barrier, pluginCatalogPath), directory: c.pluginDirectory, + logger: c.logger, } // Run upgrade if untyped plugins exist @@ -59,14 +150,108 @@ func (c *Core) setupPluginCatalog(ctx context.Context) error { return nil } +func (c *PluginCatalog) getMultiplexedClient(pluginName string) *MultiplexedClient { + if mpc, ok := c.multiplexedClients[pluginName]; ok { + c.logger.Debug("MultiplexedClient exists", "pluginName", pluginName) + + return mpc + } + + c.logger.Debug("MultiplexedClient does not exist", "pluginName", pluginName) + + return c.newMultiplexedClient(pluginName) +} + +func (c *PluginCatalog) newMultiplexedClient(pluginName string) *MultiplexedClient { + if c.multiplexedClients == nil { + c.multiplexedClients = make(map[string]*MultiplexedClient) + c.logger.Debug("created multiplexedClients map") + } + + mpc := &MultiplexedClient{connections: make(map[string]*PluginClient), logger: c.logger} + + // set the MultiplexedClient for the given plugin name + c.multiplexedClients[pluginName] = mpc + c.logger.Debug("set the MultiplexedClient for", "pluginName", pluginName) + + return mpc +} + +// GetPluginClient returns a client for managing the lifecycle of a plugin +// process +func (c *PluginCatalog) GetPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (*PluginClient, error) { + c.lock.Lock() + pc, err := c.getPluginClient(ctx, sys, pluginRunner, config) + c.lock.Unlock() + return pc, err +} + +// getPluginClient returns a client for managing the lifecycle of a plugin +// process +func (c *PluginCatalog) getPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (*PluginClient, error) { + mpc := c.getMultiplexedClient(pluginRunner.Name) + c.logger.Debug("getPluginClient", "pluginRunner.MultiplexingSupport", pluginRunner.MultiplexingSupport) + + id, err := base62.Random(10) + if err != nil { + return nil, err + } + + pc := &PluginClient{ + id: id, + logger: c.logger, + } + if !pluginRunner.MultiplexingSupport || mpc.client == nil { + // get a new client + c.logger.Debug("spawning a new plugin process") + client, err := pluginRunner.RunConfig(ctx, + pluginutil.Runner(sys), + pluginutil.PluginSets(config.PluginSets), + pluginutil.HandshakeConfig(config.HandshakeConfig), + pluginutil.Logger(config.Logger), + pluginutil.MetadataMode(config.IsMetadataMode), + pluginutil.AutoMTLS(config.AutoMTLS), + ) + if err != nil { + return nil, err + } + + mpc.client = client + pc.client = client + + } + + if pluginRunner.MultiplexingSupport && mpc.client != nil { + // return existing client + c.logger.Debug("return existing client") + + pc.client = mpc.client + } + + // Get the protocol client for this connection. + // Subsequent calls to this will return the same client. + rpcClient, err := pc.client.Client() + if err != nil { + return nil, err + } + + pc.protocol = rpcClient + mpc.connections[id] = pc + mpc.multiplexingSupport = pc.MultiplexingSupport() + + mpc.name = pluginRunner.Name + + return mpc.connections[id], nil +} + // getPluginTypeFromUnknown will attempt to run the plugin to determine the // type. It will first attempt to run as a database plugin then a backend // plugin. Both of these will be run in metadata mode. -func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) { +func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, bool, error) { merr := &multierror.Error{} - err := isDatabasePlugin(ctx, plugin) + multiplexingSupport, err := c.isDatabasePlugin(ctx, plugin) if err == nil { - return consts.PluginTypeDatabase, nil + return consts.PluginTypeDatabase, multiplexingSupport, nil } merr = multierror.Append(merr, err) @@ -75,7 +260,7 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log if err == nil { err := client.Setup(ctx, &logical.BackendConfig{}) if err != nil { - return consts.PluginTypeUnknown, err + return consts.PluginTypeUnknown, false, err } backendType := client.Type() @@ -83,9 +268,9 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log switch backendType { case logical.TypeCredential: - return consts.PluginTypeCredential, nil + return consts.PluginTypeCredential, false, nil case logical.TypeLogical: - return consts.PluginTypeSecrets, nil + return consts.PluginTypeSecrets, false, nil } } else { merr = multierror.Append(merr, err) @@ -102,29 +287,38 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log "error", merr.Error()) } - return consts.PluginTypeUnknown, nil + return consts.PluginTypeUnknown, false, nil } -func isDatabasePlugin(ctx context.Context, plugin *pluginutil.PluginRunner) error { +func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (bool, error) { merr := &multierror.Error{} - // Attempt to run as database V5 plugin - v5Client, err := v5.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true) + config := pluginutil.PluginClientConfig{ + Name: pluginRunner.Name, + PluginSets: v5.PluginSets, + HandshakeConfig: v5.HandshakeConfig, + Logger: log.NewNullLogger(), + IsMetadataMode: true, + AutoMTLS: true, + } + // Attempt to run as database V5 or V6 plugin + v5Client, err := c.getPluginClient(ctx, nil, pluginRunner, config) if err == nil { + multiplexingSupport := v5Client.MultiplexingSupport() // Close the client and cleanup the plugin process - v5Client.Close() - return nil + c.removeMultiplexedClient(pluginRunner.Name, v5Client.ID()) + return multiplexingSupport, nil } merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err)) - v4Client, err := v4.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true) + v4Client, err := v4.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true) if err == nil { // Close the client and cleanup the plugin process v4Client.Close() - return nil + return false, nil } merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err)) - return merr.ErrorOrNil() + return false, merr.ErrorOrNil() } // UpdatePlugins will loop over all the plugins of unknown type and attempt to @@ -170,7 +364,7 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e cmdOld := plugin.Command plugin.Command = filepath.Join(c.directory, plugin.Command) - pluginType, err := c.getPluginTypeFromUnknown(ctx, logger, plugin) + pluginType, multiplexingSupport, err := c.getPluginTypeFromUnknown(ctx, logger, plugin) if err != nil { retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err)) continue @@ -181,7 +375,7 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e } // Upgrade the storage - err = c.setInternal(ctx, pluginName, pluginType, cmdOld, plugin.Args, plugin.Env, plugin.Sha256) + err = c.setInternal(ctx, pluginName, pluginType, multiplexingSupport, cmdOld, plugin.Args, plugin.Env, plugin.Sha256) if err != nil { retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err)) continue @@ -269,10 +463,10 @@ func (c *PluginCatalog) Set(ctx context.Context, name string, pluginType consts. c.lock.Lock() defer c.lock.Unlock() - return c.setInternal(ctx, name, pluginType, command, args, env, sha256) + return c.setInternal(ctx, name, pluginType, false, command, args, env, sha256) } -func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, command string, args []string, env []string, sha256 []byte) error { +func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, multiplexingSupport bool, command string, args []string, env []string, sha256 []byte) error { // Best effort check to make sure the command isn't breaking out of the // configured plugin directory. commandFull := filepath.Join(c.directory, command) @@ -294,15 +488,16 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType // entryTmp should only be used for the below type check, it uses the // full command instead of the relative command. entryTmp := &pluginutil.PluginRunner{ - Name: name, - Command: commandFull, - Args: args, - Env: env, - Sha256: sha256, - Builtin: false, + Name: name, + Command: commandFull, + Args: args, + Env: env, + Sha256: sha256, + Builtin: false, + MultiplexingSupport: multiplexingSupport, } - pluginType, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp) + pluginType, multiplexingSupport, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp) if err != nil { return err } @@ -312,13 +507,14 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType } entry := &pluginutil.PluginRunner{ - Name: name, - Type: pluginType, - Command: command, - Args: args, - Env: env, - Sha256: sha256, - Builtin: false, + Name: name, + Type: pluginType, + Command: command, + Args: args, + Env: env, + Sha256: sha256, + Builtin: false, + MultiplexingSupport: multiplexingSupport, } buf, err := json.Marshal(entry) From 4782d45bfdc4619887b13eca2525a89d071cf212 Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Thu, 10 Feb 2022 12:11:04 -0600 Subject: [PATCH 02/22] feat: plugin multiplexing: handle plugin client cleanup (#13896) * use closure for plugin client cleanup * log and return errors; add comments * move rpcClient wrapping to core for ID injection * refactor core plugin client and sdk * remove unused ID method * refactor and only wrap clientConn on multiplexed plugins * rename structs and do not export types * Slight refactor of system view interface * Revert "Slight refactor of system view interface" This reverts commit 73d420e5cd2f0415e000c5a9284ea72a58016dd6. * Revert "Revert "Slight refactor of system view interface"" This reverts commit f75527008a1db06d04a23e04c3059674be8adb5f. * only provide pluginRunner arg to the internal newPluginClient method * embed ClientProtocol in pluginClient and name logger * Add back MLock support * remove enableMlock arg from setupPluginCatalog * rename plugin util interface to PluginClient Co-authored-by: Brian Kassouf --- .../database/path_config_connection.go | 3 +- .../dbplugin/v5/grpc_database_plugin.go | 26 -- sdk/database/dbplugin/v5/grpc_server.go | 3 +- sdk/database/dbplugin/v5/plugin_client.go | 22 +- sdk/database/dbplugin/v5/plugin_factory.go | 3 +- sdk/helper/pluginutil/run_config.go | 11 +- sdk/helper/pluginutil/runner.go | 13 +- sdk/logical/system_view.go | 6 +- sdk/plugin/grpc_system.go | 2 +- vault/dynamic_system_view.go | 11 +- vault/plugin_catalog.go | 315 +++++++++++------- 11 files changed, 236 insertions(+), 179 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 96a60d5a16f3..554d83fe3935 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -317,7 +317,6 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { if err != nil { return logical.ErrorResponse("error creating database object: %s", err), nil } - b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName) initReq := v5.InitializeRequest{ Config: config.ConnectionDetails, @@ -330,6 +329,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { } config.ConnectionDetails = initResp.Config + b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName) + b.Lock() defer b.Unlock() diff --git a/sdk/database/dbplugin/v5/grpc_database_plugin.go b/sdk/database/dbplugin/v5/grpc_database_plugin.go index 43840430c9fc..0419e350eb81 100644 --- a/sdk/database/dbplugin/v5/grpc_database_plugin.go +++ b/sdk/database/dbplugin/v5/grpc_database_plugin.go @@ -6,7 +6,6 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" "google.golang.org/grpc" - "google.golang.org/grpc/metadata" ) // handshakeConfigs are used to just do a basic handshake between @@ -18,8 +17,6 @@ var HandshakeConfig = plugin.HandshakeConfig{ MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", } -const multiplexingCtxKey string = "multiplex_id" - // Factory is the factory function to create a dbplugin Database. type Factory func() (interface{}, error) @@ -50,26 +47,3 @@ func (GRPCDatabasePlugin) GRPCClient(doneCtx context.Context, _ *plugin.GRPCBrok } return client, nil } - -type databaseClientConn struct { - *grpc.ClientConn - id string -} - -var _ grpc.ClientConnInterface = &databaseClientConn{} - -func (d *databaseClientConn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { - // Inject ID to the context - md := metadata.Pairs(multiplexingCtxKey, d.id) - idCtx := metadata.NewOutgoingContext(ctx, md) - - return d.ClientConn.Invoke(idCtx, method, args, reply, opts...) -} - -func (d *databaseClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { - // Inject ID to the context - md := metadata.Pairs(multiplexingCtxKey, d.id) - idCtx := metadata.NewOutgoingContext(ctx, md) - - return d.ClientConn.NewStream(idCtx, desc, method, opts...) -} diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index a5e4d0b72c64..f9ceef6dd937 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -8,6 +8,7 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -29,7 +30,7 @@ func getMultiplexIDFromContext(ctx context.Context) (string, error) { return "", fmt.Errorf("missing plugin multiplexing metadata") } - multiplexIDs := md[multiplexingCtxKey] + multiplexIDs := md[pluginutil.MultiplexingCtxKey] if len(multiplexIDs) != 1 { return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs)) } diff --git a/sdk/database/dbplugin/v5/plugin_client.go b/sdk/database/dbplugin/v5/plugin_client.go index ac3af9a16dc5..4ef24cb8e242 100644 --- a/sdk/database/dbplugin/v5/plugin_client.go +++ b/sdk/database/dbplugin/v5/plugin_client.go @@ -10,7 +10,7 @@ import ( ) type DatabasePluginClient struct { - client pluginutil.Multiplexer + client pluginutil.PluginClient Database } @@ -35,8 +35,8 @@ var PluginSets = map[int]plugin.PluginSet{ // NewPluginClient returns a databaseRPCClient with a connection to a running // plugin. -func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (Database, error) { - pluginClient, err := sys.NewPluginClient(ctx, pluginRunner, config) +func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, config pluginutil.PluginClientConfig) (Database, error) { + pluginClient, err := sys.NewPluginClient(ctx, config) if err != nil { return nil, err } @@ -50,19 +50,13 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne // We should have a database type now. This feels like a normal interface // implementation but is in fact over an RPC connection. var db Database - switch raw.(type) { + switch c := raw.(type) { case gRPCClient: + // This is an abstraction leak from go-plugin but it is necessary in + // order to enable multiplexing on multiplexed plugins + c.client = proto.NewDatabaseClient(pluginClient.Conn()) - gRPCClient := raw.(gRPCClient) - - // Wrap clientConn with our implementation so that we can inject the - // ID into the context - cc := &databaseClientConn{ - ClientConn: pluginClient.Conn(), - id: pluginClient.ID(), - } - gRPCClient.client = proto.NewDatabaseClient(cc) - db = gRPCClient + db = c default: return nil, errors.New("unsupported client type") } diff --git a/sdk/database/dbplugin/v5/plugin_factory.go b/sdk/database/dbplugin/v5/plugin_factory.go index c8b777870049..b87dc3a75a60 100644 --- a/sdk/database/dbplugin/v5/plugin_factory.go +++ b/sdk/database/dbplugin/v5/plugin_factory.go @@ -42,6 +42,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu } else { config := pluginutil.PluginClientConfig{ Name: pluginName, + PluginType: consts.PluginTypeDatabase, PluginSets: PluginSets, HandshakeConfig: HandshakeConfig, Logger: namedLogger, @@ -49,7 +50,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu AutoMTLS: true, } // create a DatabasePluginClient instance - db, err = NewPluginClient(ctx, sys, pluginRunner, config) + db, err = NewPluginClient(ctx, sys, config) if err != nil { return nil, err } diff --git a/sdk/helper/pluginutil/run_config.go b/sdk/helper/pluginutil/run_config.go index 4c654c4991ea..cb804f60d873 100644 --- a/sdk/helper/pluginutil/run_config.go +++ b/sdk/helper/pluginutil/run_config.go @@ -9,16 +9,19 @@ import ( log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/version" ) type PluginClientConfig struct { Name string + PluginType consts.PluginType PluginSets map[int]plugin.PluginSet HandshakeConfig plugin.HandshakeConfig Logger log.Logger IsMetadataMode bool AutoMTLS bool + MLock bool } type runConfig struct { @@ -40,7 +43,7 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error cmd.Env = append(cmd.Env, rc.env...) // Add the mlock setting to the ENV of the plugin - if rc.wrapper != nil && rc.wrapper.MlockEnabled() { + if rc.MLock || (rc.wrapper != nil && rc.wrapper.MlockEnabled()) { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) } cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version)) @@ -151,6 +154,12 @@ func AutoMTLS(autoMTLS bool) RunOpt { } } +func MLock(mlock bool) RunOpt { + return func(rc *runConfig) { + rc.MLock = mlock + } +} + func (r *PluginRunner) RunConfig(ctx context.Context, opts ...RunOpt) (*plugin.Client, error) { rc := runConfig{ command: r.Command, diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index 8c9be9cb6928..fdb82b6007f6 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -22,7 +22,7 @@ type Looker interface { // configuration and wrapping data in a response wrapped token. // logical.SystemView implementations satisfy this interface. type RunnerUtil interface { - NewPluginClient(ctx context.Context, pluginRunner *PluginRunner, config PluginClientConfig) (Multiplexer, error) + NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) MlockEnabled() bool } @@ -33,14 +33,15 @@ type LookRunnerUtil interface { RunnerUtil } -type Multiplexer interface { - ID() string - Conn() *grpc.ClientConn - MultiplexingSupport() bool +type PluginClient interface { + Conn() grpc.ClientConnInterface + MultiplexingSupport() (bool, error) plugin.ClientProtocol } +const MultiplexingCtxKey string = "multiplex_id" + // PluginRunner defines the metadata needed to run a plugin securely with // go-plugin. type PluginRunner struct { @@ -98,6 +99,8 @@ func CtxCancelIfCanceled(f context.CancelFunc, ctxCanceler context.Context) chan return quitCh } +// MultiplexingSupport returns true if a plugin supports multiplexing. +// Currently this is hardcoded for database plugins. func MultiplexingSupport(version int) bool { return version == 6 } diff --git a/sdk/logical/system_view.go b/sdk/logical/system_view.go index e029b7764e94..83b4a951e842 100644 --- a/sdk/logical/system_view.go +++ b/sdk/logical/system_view.go @@ -56,7 +56,9 @@ type SystemView interface { // name. Returns a PluginRunner or an error if a plugin can not be found. LookupPlugin(context.Context, string, consts.PluginType) (*pluginutil.PluginRunner, error) - NewPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (pluginutil.Multiplexer, error) + // NewPluginClient returns a client for managing the lifecycle of plugin + // processes + NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) // MlockEnabled returns the configuration setting for enabling mlock on // plugins. @@ -154,7 +156,7 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState { return d.ReplicationStateVal } -func (d StaticSystemView) NewPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (pluginutil.Multiplexer, error) { +func (d StaticSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) { return nil, errors.New("NewPluginClient is not implemented in StaticSystemView") } diff --git a/sdk/plugin/grpc_system.go b/sdk/plugin/grpc_system.go index a6eae146ad3f..81bc324bf0fa 100644 --- a/sdk/plugin/grpc_system.go +++ b/sdk/plugin/grpc_system.go @@ -99,7 +99,7 @@ func (s *gRPCSystemViewClient) ResponseWrapData(ctx context.Context, data map[st return info, nil } -func (s *gRPCSystemViewClient) NewPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (pluginutil.Multiplexer, error) { +func (s *gRPCSystemViewClient) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) { return nil, fmt.Errorf("cannot call NewPluginClient from a plugin backend") } diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 4c82cac2e19a..0c64f09329d1 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -215,8 +215,15 @@ func (d dynamicSystemView) ResponseWrapData(ctx context.Context, data map[string return resp.WrapInfo, nil } -func (d dynamicSystemView) NewPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (pluginutil.Multiplexer, error) { - c, err := d.core.pluginCatalog.GetPluginClient(ctx, d, pluginRunner, config) +func (d dynamicSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) { + if d.core == nil { + return nil, fmt.Errorf("system view core is nil") + } + if d.core.pluginCatalog == nil { + return nil, fmt.Errorf("system view core plugin catalog is nil") + } + + c, err := d.core.pluginCatalog.NewPluginClient(ctx, config) if err != nil { return nil, err } diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 3af050689fa9..48bd0b9ac74e 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -22,6 +22,7 @@ import ( "github.com/hashicorp/vault/sdk/logical" backendplugin "github.com/hashicorp/vault/sdk/plugin" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) var ( @@ -40,93 +41,42 @@ type PluginCatalog struct { directory string logger log.Logger - // multiplexedClients holds plugin process connections by plugin name - // This allows a single grpc connection to communicate with multiple - // databases. Each database configuration using the same plugin will be - // routed to the existing plugin process. - multiplexedClients map[string]*MultiplexedClient + // externalPlugins holds plugin process connections by plugin name + // + // This allows plugins that suppport multiplexing to use a single grpc + // connection to communicate with multiple "backends". Each backend + // configuration using the same plugin will be routed to the existing + // plugin process. + externalPlugins map[string]*externalPlugin + mlockPlugins bool lock sync.RWMutex } -type PluginClient struct { - logger log.Logger - id string - protocol plugin.ClientProtocol - // client handles the lifecycle of a plugin process - client *plugin.Client -} - -type MultiplexedClient struct { +// pluginClient represents a connection to a plugin process +type pluginClient struct { logger log.Logger - // name is the plugin name - name string - - client *plugin.Client - - connections map[string]*PluginClient - multiplexingSupport bool -} - -func (p *PluginClient) Conn() *grpc.ClientConn { - gc := p.protocol.(*plugin.GRPCClient) - return gc.Conn -} - -func (p *PluginClient) ID() string { - return p.id -} - -func (p *MultiplexedClient) ByID(id string) *PluginClient { - return p.connections[id] -} - -func (p *PluginClient) MultiplexingSupport() bool { - if p.client == nil { - return false - } - return pluginutil.MultiplexingSupport(p.client.NegotiatedVersion()) -} - -func (c *PluginCatalog) removeMultiplexedClient(name, id string) { - mpc, ok := c.multiplexedClients[name] - if !ok { - return - } - - mpc.connections[id].Close() + // id is the connection ID + id string - delete(mpc.connections, id) - if len(mpc.connections) == 0 { - delete(c.multiplexedClients, name) - } -} + // client handles the lifecycle of a plugin process + // multiplexed plugins share the same client + client *plugin.Client + clientConn grpc.ClientConnInterface + cleanupFunc func() error -func (p *PluginClient) Close() error { - p.logger.Debug("attempting to kill plugin process", "id", p.id) - p.client.Kill() - err := p.protocol.Close() - p.client = nil - p.protocol = nil - p.logger.Debug("killed plugin process", "id", p.id) - return err + plugin.ClientProtocol } -func (p *PluginClient) Dispense(name string) (interface{}, error) { - pluginInstance, err := p.protocol.Dispense(name) - if err != nil { - return nil, err - } - return pluginInstance, nil -} +// externalPlugin holds client connections for multiplexed and +// non-multiplexed plugin processes +type externalPlugin struct { + // name is the plugin name + name string -func (p *PluginClient) Ping() error { - err := p.protocol.Ping() - if err != nil { - return err - } - return nil + // connections holds client connections by ID + connections map[string]*pluginClient } func (c *Core) setupPluginCatalog(ctx context.Context) error { @@ -135,12 +85,14 @@ func (c *Core) setupPluginCatalog(ctx context.Context) error { catalogView: NewBarrierView(c.barrier, pluginCatalogPath), directory: c.pluginDirectory, logger: c.logger, + mlockPlugins: c.enableMlock, } // Run upgrade if untyped plugins exist err := c.pluginCatalog.UpgradePlugins(ctx, c.logger) if err != nil { c.logger.Error("error while upgrading plugin storage", "error", err) + return err } if c.logger.IsInfo() { @@ -150,82 +102,171 @@ func (c *Core) setupPluginCatalog(ctx context.Context) error { return nil } -func (c *PluginCatalog) getMultiplexedClient(pluginName string) *MultiplexedClient { - if mpc, ok := c.multiplexedClients[pluginName]; ok { - c.logger.Debug("MultiplexedClient exists", "pluginName", pluginName) +type pluginClientConn struct { + *grpc.ClientConn + id string +} + +var _ grpc.ClientConnInterface = &pluginClientConn{} + +func (d *pluginClientConn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { + // Inject ID to the context + md := metadata.Pairs(pluginutil.MultiplexingCtxKey, d.id) + idCtx := metadata.NewOutgoingContext(ctx, md) + + return d.ClientConn.Invoke(idCtx, method, args, reply, opts...) +} + +func (d *pluginClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + // Inject ID to the context + md := metadata.Pairs(pluginutil.MultiplexingCtxKey, d.id) + idCtx := metadata.NewOutgoingContext(ctx, md) + + return d.ClientConn.NewStream(idCtx, desc, method, opts...) +} + +func (p *pluginClient) Conn() grpc.ClientConnInterface { + return p.clientConn +} + +// MultiplexingSupport determines if a plugin client supports multiplexing +func (p *pluginClient) MultiplexingSupport() (bool, error) { + if p.client == nil { + return false, fmt.Errorf("plugin client is nil") + } + return pluginutil.MultiplexingSupport(p.client.NegotiatedVersion()), nil +} + +// Close calls the plugin client's cleanupFunc to do any necessary cleanup on +// the plugin client and the PluginCatalog. This implements the +// plugin.ClientProtocol interface. +func (p *pluginClient) Close() error { + p.logger.Debug("cleaning up plugin client connection", "id", p.id) + return p.cleanupFunc() +} + +func (c *PluginCatalog) removePluginClient(name, id string) error { + var err error + extPlugin, ok := c.externalPlugins[name] + if !ok { + return fmt.Errorf("plugin client not found") + } - return mpc + pluginClient := extPlugin.connections[id] + multiplexingSupport, err := pluginClient.MultiplexingSupport() + if err != nil { + return err } - c.logger.Debug("MultiplexedClient does not exist", "pluginName", pluginName) + delete(extPlugin.connections, id) + if !multiplexingSupport { + pluginClient.client.Kill() - return c.newMultiplexedClient(pluginName) + if len(extPlugin.connections) == 0 { + delete(c.externalPlugins, name) + } + } else if len(extPlugin.connections) == 0 { + pluginClient.client.Kill() + delete(c.externalPlugins, name) + } + return err } -func (c *PluginCatalog) newMultiplexedClient(pluginName string) *MultiplexedClient { - if c.multiplexedClients == nil { - c.multiplexedClients = make(map[string]*MultiplexedClient) - c.logger.Debug("created multiplexedClients map") +func (c *PluginCatalog) getExternalPlugin(pluginName string) *externalPlugin { + if extPlugin, ok := c.externalPlugins[pluginName]; ok { + return extPlugin } - mpc := &MultiplexedClient{connections: make(map[string]*PluginClient), logger: c.logger} + return c.newExternalPlugin(pluginName) +} - // set the MultiplexedClient for the given plugin name - c.multiplexedClients[pluginName] = mpc - c.logger.Debug("set the MultiplexedClient for", "pluginName", pluginName) +func (c *PluginCatalog) newExternalPlugin(pluginName string) *externalPlugin { + if c.externalPlugins == nil { + c.externalPlugins = make(map[string]*externalPlugin) + } - return mpc + extPlugin := &externalPlugin{ + connections: make(map[string]*pluginClient), + name: pluginName, + } + + c.externalPlugins[pluginName] = extPlugin + return extPlugin } -// GetPluginClient returns a client for managing the lifecycle of a plugin +// NewPluginClient returns a client for managing the lifecycle of a plugin // process -func (c *PluginCatalog) GetPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (*PluginClient, error) { +func (c *PluginCatalog) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (*pluginClient, error) { c.lock.Lock() - pc, err := c.getPluginClient(ctx, sys, pluginRunner, config) - c.lock.Unlock() + defer c.lock.Unlock() + if config.Name == "" { + return nil, fmt.Errorf("no name provided for plugin") + } + if config.PluginType == consts.PluginTypeUnknown { + return nil, fmt.Errorf("no plugin type provided") + } + + pluginRunner, err := c.get(ctx, config.Name, config.PluginType) + if err != nil { + return nil, fmt.Errorf("failed to lookup plugin: %w", err) + } + if pluginRunner == nil { + return nil, fmt.Errorf("no plugin found") + } + pc, err := c.newPluginClient(ctx, pluginRunner, config) return pc, err } -// getPluginClient returns a client for managing the lifecycle of a plugin +// newPluginClient returns a client for managing the lifecycle of a plugin // process -func (c *PluginCatalog) getPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (*PluginClient, error) { - mpc := c.getMultiplexedClient(pluginRunner.Name) - c.logger.Debug("getPluginClient", "pluginRunner.MultiplexingSupport", pluginRunner.MultiplexingSupport) +func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (*pluginClient, error) { + if pluginRunner == nil { + return nil, fmt.Errorf("no plugin found") + } + extPlugin := c.getExternalPlugin(pluginRunner.Name) id, err := base62.Random(10) if err != nil { return nil, err } - pc := &PluginClient{ + pc := &pluginClient{ id: id, - logger: c.logger, + logger: c.logger.Named(pluginRunner.Name), + cleanupFunc: func() error { + return c.removePluginClient(pluginRunner.Name, id) + }, } - if !pluginRunner.MultiplexingSupport || mpc.client == nil { - // get a new client - c.logger.Debug("spawning a new plugin process") + + if !pluginRunner.MultiplexingSupport || len(extPlugin.connections) == 0 { + c.logger.Debug("spawning a new plugin process", "id", id) client, err := pluginRunner.RunConfig(ctx, - pluginutil.Runner(sys), pluginutil.PluginSets(config.PluginSets), pluginutil.HandshakeConfig(config.HandshakeConfig), pluginutil.Logger(config.Logger), pluginutil.MetadataMode(config.IsMetadataMode), - pluginutil.AutoMTLS(config.AutoMTLS), + pluginutil.MLock(c.mlockPlugins), + + // NewPluginClient only supports AutoMTLS today + pluginutil.AutoMTLS(true), ) if err != nil { return nil, err } - mpc.client = client pc.client = client + } else { + c.logger.Debug("returning existing plugin client for multiplexed plugin", "id", id) - } - - if pluginRunner.MultiplexingSupport && mpc.client != nil { - // return existing client - c.logger.Debug("return existing client") + // get the first client, since they are all the same + for k := range extPlugin.connections { + pc.client = extPlugin.connections[k].client + break + } - pc.client = mpc.client + if pc.client == nil { + return nil, fmt.Errorf("plugin client is nil") + } } // Get the protocol client for this connection. @@ -235,18 +276,29 @@ func (c *PluginCatalog) getPluginClient(ctx context.Context, sys pluginutil.Runn return nil, err } - pc.protocol = rpcClient - mpc.connections[id] = pc - mpc.multiplexingSupport = pc.MultiplexingSupport() + clientConn := rpcClient.(*plugin.GRPCClient).Conn - mpc.name = pluginRunner.Name + if pluginRunner.MultiplexingSupport { + // Wrap rpcClient with our implementation so that we can inject the + // ID into the context + pc.clientConn = &pluginClientConn{ + ClientConn: clientConn, + id: id, + } + } else { + pc.clientConn = clientConn + } + pc.ClientProtocol = rpcClient + extPlugin.connections[id] = pc + extPlugin.name = pluginRunner.Name - return mpc.connections[id], nil + return extPlugin.connections[id], nil } // getPluginTypeFromUnknown will attempt to run the plugin to determine the -// type. It will first attempt to run as a database plugin then a backend -// plugin. Both of these will be run in metadata mode. +// type and if it supports multiplexing. It will first attempt to run as a +// database plugin then a backend plugin. Both of these will be run in metadata +// mode. func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, bool, error) { merr := &multierror.Error{} multiplexingSupport, err := c.isDatabasePlugin(ctx, plugin) @@ -290,22 +342,33 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log return consts.PluginTypeUnknown, false, nil } +// isDatabasePlugin returns true if the plugin supports multiplexing. An error +// is returned if the plugin is not a database plugin. func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (bool, error) { merr := &multierror.Error{} config := pluginutil.PluginClientConfig{ Name: pluginRunner.Name, PluginSets: v5.PluginSets, + PluginType: consts.PluginTypeDatabase, HandshakeConfig: v5.HandshakeConfig, Logger: log.NewNullLogger(), IsMetadataMode: true, AutoMTLS: true, } - // Attempt to run as database V5 or V6 plugin - v5Client, err := c.getPluginClient(ctx, nil, pluginRunner, config) + // Attempt to run as database V5 or V6 multiplexed plugin + v5Client, err := c.newPluginClient(ctx, pluginRunner, config) if err == nil { - multiplexingSupport := v5Client.MultiplexingSupport() + // At this point the pluginRunner does not know if multiplexing is + // supported or not. So we need to ask the plugin client itself. + multiplexingSupport, err := v5Client.MultiplexingSupport() + if err != nil { + return false, err + } + // Close the client and cleanup the plugin process - c.removeMultiplexedClient(pluginRunner.Name, v5Client.ID()) + err = v5Client.Close() + c.logger.Error("error closing plugin client", "error", err) + return multiplexingSupport, nil } merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err)) @@ -313,7 +376,9 @@ func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *plug v4Client, err := v4.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true) if err == nil { // Close the client and cleanup the plugin process - v4Client.Close() + err = v4Client.Close() + c.logger.Error("error closing plugin client", "error", err) + return false, nil } merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err)) From 25620a151245963710677ec72ef5b41fb099a725 Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Fri, 11 Feb 2022 11:34:40 -0600 Subject: [PATCH 03/22] feature: multiplexing: fix unit tests (#14007) * fix grpc_server tests and add coverage * update run_config tests * add happy path test case for grpc_server ID from context * update test helpers --- sdk/database/dbplugin/v5/grpc_server.go | 4 - sdk/database/dbplugin/v5/grpc_server_test.go | 136 +++++++++++++------ sdk/helper/pluginutil/run_config_test.go | 101 ++++++++------ 3 files changed, 151 insertions(+), 90 deletions(-) diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index f9ceef6dd937..555fd2e34491 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -74,10 +74,6 @@ func (g gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) { return nil, err } - if id == "" { - return nil, fmt.Errorf("no instance ID found for multiplexed plugin") - } - if db, ok := g.instances[id]; ok { return db, nil } diff --git a/sdk/database/dbplugin/v5/grpc_server_test.go b/sdk/database/dbplugin/v5/grpc_server_test.go index d3861c25445d..52b488266481 100644 --- a/sdk/database/dbplugin/v5/grpc_server_test.go +++ b/sdk/database/dbplugin/v5/grpc_server_test.go @@ -13,7 +13,9 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/timestamp" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -77,14 +79,9 @@ func TestGRPCServer_Initialize(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } - - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() + idCtx, g := testGrpcServer(t, test.db) + resp, err := g.Initialize(idCtx, test.req) - resp, err := g.Initialize(ctx, test.req) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -252,14 +249,9 @@ func TestGRPCServer_NewUser(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } - - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() + idCtx, g := testGrpcServer(t, test.db) + resp, err := g.NewUser(idCtx, test.req) - resp, err := g.NewUser(ctx, test.req) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -362,14 +354,9 @@ func TestGRPCServer_UpdateUser(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } - - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() + idCtx, g := testGrpcServer(t, test.db) + resp, err := g.UpdateUser(idCtx, test.req) - resp, err := g.UpdateUser(ctx, test.req) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -430,14 +417,9 @@ func TestGRPCServer_DeleteUser(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } + idCtx, g := testGrpcServer(t, test.db) + resp, err := g.DeleteUser(idCtx, test.req) - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() - - resp, err := g.DeleteUser(ctx, test.req) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -488,14 +470,9 @@ func TestGRPCServer_Type(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } - - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() + idCtx, g := testGrpcServer(t, test.db) + resp, err := g.Type(idCtx, &proto.Empty{}) - resp, err := g.Type(ctx, &proto.Empty{}) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -539,14 +516,9 @@ func TestGRPCServer_Close(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } - - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() + idCtx, g := testGrpcServer(t, test.db) + _, err := g.Close(idCtx, &proto.Empty{}) - _, err := g.Close(ctx, &proto.Empty{}) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -562,6 +534,86 @@ func TestGRPCServer_Close(t *testing.T) { } } +func TestGetMultiplexIDFromContext(t *testing.T) { + type testCase struct { + ctx context.Context + expectedResp string + expectedErr error + } + + tests := map[string]testCase{ + "missing plugin multiplexing metadata": { + ctx: context.Background(), + expectedResp: "", + expectedErr: fmt.Errorf("missing plugin multiplexing metadata"), + }, + "unexpected number of IDs in metadata": { + ctx: idCtx(t, "12345", "67891"), + expectedResp: "", + expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"), + }, + "empty multiplex ID in metadata": { + ctx: idCtx(t, ""), + expectedResp: "", + expectedErr: fmt.Errorf("empty multiplex ID in metadata"), + }, + "happy path, id is returned from metadata": { + ctx: idCtx(t, "12345"), + expectedResp: "12345", + expectedErr: nil, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + resp, err := getMultiplexIDFromContext(test.ctx) + + if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil { + t.Fatalf("err expected, got nil") + } else if !reflect.DeepEqual(err, test.expectedErr) { + t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr) + } + + if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + + if !reflect.DeepEqual(resp, test.expectedResp) { + t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) + } + }) + } +} + +func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) { + t.Helper() + g := gRPCServer{ + factoryFunc: func() (interface{}, error) { + return db, nil + }, + instances: make(map[string]Database), + } + + id := "12345" + idCtx := idCtx(t, id) + g.instances[id] = db + + return idCtx, g +} + +// idCtx is a test helper that will return a context with the IDs set in its +// metadata +func idCtx(t *testing.T, ids ...string) context.Context { + t.Helper() + // Context doesn't need to timeout since this is just passed through + ctx := context.Background() + md := metadata.MD{} + for _, id := range ids { + md.Append(pluginutil.MultiplexingCtxKey, id) + } + return metadata.NewIncomingContext(ctx, md) +} + func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct { t.Helper() diff --git a/sdk/helper/pluginutil/run_config_test.go b/sdk/helper/pluginutil/run_config_test.go index 239d36852382..f2373fe9b4a5 100644 --- a/sdk/helper/pluginutil/run_config_test.go +++ b/sdk/helper/pluginutil/run_config_test.go @@ -38,19 +38,21 @@ func TestMakeConfig(t *testing.T) { args: []string{"foo", "bar"}, sha256: []byte("some_sha256"), env: []string{"initial=true"}, - pluginSets: map[int]plugin.PluginSet{ - 1: { - "bogus": nil, + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + IsMetadataMode: true, + AutoMTLS: false, }, - hs: plugin.HandshakeConfig{ - ProtocolVersion: 1, - MagicCookieKey: "magic_cookie_key", - MagicCookieValue: "magic_cookie_value", - }, - logger: hclog.NewNullLogger(), - isMetadataMode: true, - autoMTLS: false, }, responseWrapInfoTimes: 0, @@ -97,19 +99,21 @@ func TestMakeConfig(t *testing.T) { args: []string{"foo", "bar"}, sha256: []byte("some_sha256"), env: []string{"initial=true"}, - pluginSets: map[int]plugin.PluginSet{ - 1: { - "bogus": nil, + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + IsMetadataMode: false, + AutoMTLS: false, }, - hs: plugin.HandshakeConfig{ - ProtocolVersion: 1, - MagicCookieKey: "magic_cookie_key", - MagicCookieValue: "magic_cookie_value", - }, - logger: hclog.NewNullLogger(), - isMetadataMode: false, - autoMTLS: false, }, responseWrapInfo: &wrapping.ResponseWrapInfo{ @@ -161,19 +165,21 @@ func TestMakeConfig(t *testing.T) { args: []string{"foo", "bar"}, sha256: []byte("some_sha256"), env: []string{"initial=true"}, - pluginSets: map[int]plugin.PluginSet{ - 1: { - "bogus": nil, + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + IsMetadataMode: true, + AutoMTLS: true, }, - hs: plugin.HandshakeConfig{ - ProtocolVersion: 1, - MagicCookieKey: "magic_cookie_key", - MagicCookieValue: "magic_cookie_value", - }, - logger: hclog.NewNullLogger(), - isMetadataMode: true, - autoMTLS: true, }, responseWrapInfoTimes: 0, @@ -220,19 +226,21 @@ func TestMakeConfig(t *testing.T) { args: []string{"foo", "bar"}, sha256: []byte("some_sha256"), env: []string{"initial=true"}, - pluginSets: map[int]plugin.PluginSet{ - 1: { - "bogus": nil, + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + IsMetadataMode: false, + AutoMTLS: true, }, - hs: plugin.HandshakeConfig{ - ProtocolVersion: 1, - MagicCookieKey: "magic_cookie_key", - MagicCookieValue: "magic_cookie_value", - }, - logger: hclog.NewNullLogger(), - isMetadataMode: false, - autoMTLS: true, }, responseWrapInfoTimes: 0, @@ -329,6 +337,11 @@ type mockRunnerUtil struct { mock.Mock } +func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error) { + args := m.Called(ctx, config) + return args.Get(0).(PluginClient), args.Error(1) +} + func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { args := m.Called(ctx, data, ttl, jwt) return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1) From c4ec4730de31fdfa6f925301a3e5a482d293df13 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Fri, 11 Feb 2022 11:38:00 -0600 Subject: [PATCH 04/22] feat: multiplexing: handle v5 plugin compiled with new sdk --- sdk/database/dbplugin/v5/grpc_database_plugin.go | 3 +++ sdk/database/dbplugin/v5/grpc_server.go | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/sdk/database/dbplugin/v5/grpc_database_plugin.go b/sdk/database/dbplugin/v5/grpc_database_plugin.go index 0419e350eb81..8bd1a032ad43 100644 --- a/sdk/database/dbplugin/v5/grpc_database_plugin.go +++ b/sdk/database/dbplugin/v5/grpc_database_plugin.go @@ -35,6 +35,9 @@ var ( func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error { server := gRPCServer{factoryFunc: d.FactoryFunc, instances: make(map[string]Database)} + if d.Impl != nil { + server = gRPCServer{singleImpl: d.Impl} + } proto.RegisterDatabaseServer(s, server) return nil diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index 555fd2e34491..d7cd10b11c42 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -19,6 +19,8 @@ var _ proto.DatabaseServer = gRPCServer{} type gRPCServer struct { proto.UnimplementedDatabaseServer + singleImpl Database + factoryFunc func() (interface{}, error) instances map[string]Database sync.RWMutex @@ -47,6 +49,10 @@ func (g gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { g.Lock() defer g.Unlock() + if g.singleImpl != nil { + return g.singleImpl, nil + } + id, err := getMultiplexIDFromContext(ctx) if err != nil { return nil, err @@ -69,6 +75,10 @@ func (g gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { // getDatabaseInternal returns the database but does not hold a lock func (g gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) { + if g.singleImpl != nil { + return g.singleImpl, nil + } + id, err := getMultiplexIDFromContext(ctx) if err != nil { return nil, err From 51eabad7476c4a99cf50425c851f628df1ad8991 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Fri, 11 Feb 2022 14:38:34 -0600 Subject: [PATCH 05/22] add mux supported flag and increase test coverage --- .../dbplugin/v5/grpc_database_plugin.go | 15 ++- sdk/database/dbplugin/v5/grpc_server.go | 17 ++- sdk/database/dbplugin/v5/grpc_server_test.go | 109 +++++++++++++----- sdk/database/dbplugin/v5/plugin_client.go | 4 +- 4 files changed, 107 insertions(+), 38 deletions(-) diff --git a/sdk/database/dbplugin/v5/grpc_database_plugin.go b/sdk/database/dbplugin/v5/grpc_database_plugin.go index 8bd1a032ad43..4a68ee6b559b 100644 --- a/sdk/database/dbplugin/v5/grpc_database_plugin.go +++ b/sdk/database/dbplugin/v5/grpc_database_plugin.go @@ -24,6 +24,8 @@ type GRPCDatabasePlugin struct { FactoryFunc Factory Impl Database + multiplexingSupport bool + // Embeding this will disable the netRPC protocol plugin.NetRPCUnsupportedPlugin } @@ -34,9 +36,16 @@ var ( ) func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error { - server := gRPCServer{factoryFunc: d.FactoryFunc, instances: make(map[string]Database)} - if d.Impl != nil { - server = gRPCServer{singleImpl: d.Impl} + var server gRPCServer + + if d.multiplexingSupport { + server = gRPCServer{ + multiplexingSupport: true, + factoryFunc: d.FactoryFunc, + instances: make(map[string]Database), + } + } else { + server = gRPCServer{multiplexingSupport: false, singleImpl: d.Impl} } proto.RegisterDatabaseServer(s, server) diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index d7cd10b11c42..423300a94bd4 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -23,6 +23,9 @@ type gRPCServer struct { factoryFunc func() (interface{}, error) instances map[string]Database + + multiplexingSupport bool + sync.RWMutex } @@ -49,7 +52,7 @@ func (g gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { g.Lock() defer g.Unlock() - if g.singleImpl != nil { + if !g.multiplexingSupport && g.singleImpl != nil { return g.singleImpl, nil } @@ -75,7 +78,7 @@ func (g gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { // getDatabaseInternal returns the database but does not hold a lock func (g gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) { - if g.singleImpl != nil { + if !g.multiplexingSupport && g.singleImpl != nil { return g.singleImpl, nil } @@ -291,11 +294,13 @@ func (g gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, er return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err) } - id, err := getMultiplexIDFromContext(ctx) - if err != nil { - return nil, err + if g.multiplexingSupport { + id, err := getMultiplexIDFromContext(ctx) + if err != nil { + return nil, err + } + delete(g.instances, id) } - delete(g.instances, id) return &proto.Empty{}, nil } diff --git a/sdk/database/dbplugin/v5/grpc_server_test.go b/sdk/database/dbplugin/v5/grpc_server_test.go index 52b488266481..d8bbe932a7a0 100644 --- a/sdk/database/dbplugin/v5/grpc_server_test.go +++ b/sdk/database/dbplugin/v5/grpc_server_test.go @@ -24,11 +24,12 @@ var invalidExpiration = time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC) func TestGRPCServer_Initialize(t *testing.T) { type testCase struct { - db Database - req *proto.InitializeRequest - expectedResp *proto.InitializeResponse - expectErr bool - expectCode codes.Code + db Database + req *proto.InitializeRequest + expectedResp *proto.InitializeResponse + expectErr bool + expectCode codes.Code + grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer) } tests := map[string]testCase{ @@ -36,10 +37,11 @@ func TestGRPCServer_Initialize(t *testing.T) { db: fakeDatabase{ initErr: errors.New("initialization error"), }, - req: &proto.InitializeRequest{}, - expectedResp: &proto.InitializeResponse{}, - expectErr: true, - expectCode: codes.Internal, + req: &proto.InitializeRequest{}, + expectedResp: &proto.InitializeResponse{}, + expectErr: true, + expectCode: codes.Internal, + grpcSetupFunc: testGrpcServer, }, "newConfig can't marshal to JSON": { db: fakeDatabase{ @@ -49,12 +51,13 @@ func TestGRPCServer_Initialize(t *testing.T) { }, }, }, - req: &proto.InitializeRequest{}, - expectedResp: &proto.InitializeResponse{}, - expectErr: true, - expectCode: codes.Internal, + req: &proto.InitializeRequest{}, + expectedResp: &proto.InitializeResponse{}, + expectErr: true, + expectCode: codes.Internal, + grpcSetupFunc: testGrpcServer, }, - "happy path with config data": { + "happy path with config data for multiplexed plugin": { db: fakeDatabase{ initResp: InitializeResponse{ Config: map[string]interface{}{ @@ -72,14 +75,37 @@ func TestGRPCServer_Initialize(t *testing.T) { "foo": "bar", }), }, - expectErr: false, - expectCode: codes.OK, + expectErr: false, + expectCode: codes.OK, + grpcSetupFunc: testGrpcServer, + }, + "happy path with config data for non-multiplexed plugin": { + db: fakeDatabase{ + initResp: InitializeResponse{ + Config: map[string]interface{}{ + "foo": "bar", + }, + }, + }, + req: &proto.InitializeRequest{ + ConfigData: marshal(t, map[string]interface{}{ + "foo": "bar", + }), + }, + expectedResp: &proto.InitializeResponse{ + ConfigData: marshal(t, map[string]interface{}{ + "foo": "bar", + }), + }, + expectErr: false, + expectCode: codes.OK, + grpcSetupFunc: testGrpcServerSingleImpl, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { - idCtx, g := testGrpcServer(t, test.db) + idCtx, g := test.grpcSetupFunc(t, test.db) resp, err := g.Initialize(idCtx, test.req) if test.expectErr && err == nil { @@ -494,9 +520,11 @@ func TestGRPCServer_Type(t *testing.T) { func TestGRPCServer_Close(t *testing.T) { type testCase struct { - db Database - expectErr bool - expectCode codes.Code + db Database + expectErr bool + expectCode codes.Code + grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer) + assertFunc func(t *testing.T, g gRPCServer) } tests := map[string]testCase{ @@ -504,19 +532,33 @@ func TestGRPCServer_Close(t *testing.T) { db: fakeDatabase{ closeErr: errors.New("close error"), }, - expectErr: true, - expectCode: codes.Internal, + expectErr: true, + expectCode: codes.Internal, + grpcSetupFunc: testGrpcServer, + assertFunc: nil, }, - "happy path": { - db: fakeDatabase{}, - expectErr: false, - expectCode: codes.OK, + "happy path for multiplexed plugin": { + db: fakeDatabase{}, + expectErr: false, + expectCode: codes.OK, + grpcSetupFunc: testGrpcServer, + assertFunc: func(t *testing.T, g gRPCServer) { + if len(g.instances) != 0 { + t.Fatalf("err expected instances map to be empty") + } + }, + }, + "happy path for non-multiplexed plugin": { + db: fakeDatabase{}, + expectErr: false, + expectCode: codes.OK, + grpcSetupFunc: testGrpcServerSingleImpl, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { - idCtx, g := testGrpcServer(t, test.db) + idCtx, g := test.grpcSetupFunc(t, test.db) _, err := g.Close(idCtx, &proto.Empty{}) if test.expectErr && err == nil { @@ -585,9 +627,12 @@ func TestGetMultiplexIDFromContext(t *testing.T) { } } +// testGrpcServer is a test helper that returns a context with an ID set in its +// metadata and a gRPCServer instance for a multiplexed plugin func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) { t.Helper() g := gRPCServer{ + multiplexingSupport: true, factoryFunc: func() (interface{}, error) { return db, nil }, @@ -601,6 +646,16 @@ func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) { return idCtx, g } +// testGrpcServerSingleImpl is a test helper that returns a context and a +// gRPCServer instance for a non-multiplexed plugin +func testGrpcServerSingleImpl(t *testing.T, db Database) (context.Context, gRPCServer) { + t.Helper() + return context.Background(), gRPCServer{ + multiplexingSupport: false, + singleImpl: db, + } +} + // idCtx is a test helper that will return a context with the IDs set in its // metadata func idCtx(t *testing.T, ids ...string) context.Context { diff --git a/sdk/database/dbplugin/v5/plugin_client.go b/sdk/database/dbplugin/v5/plugin_client.go index 4ef24cb8e242..9678b5de61db 100644 --- a/sdk/database/dbplugin/v5/plugin_client.go +++ b/sdk/database/dbplugin/v5/plugin_client.go @@ -26,10 +26,10 @@ func (dc *DatabasePluginClient) Close() error { // pluginSets is the map of plugins we can dispense. var PluginSets = map[int]plugin.PluginSet{ 5: { - "database": &GRPCDatabasePlugin{}, + "database": &GRPCDatabasePlugin{multiplexingSupport: false}, }, 6: { - "database": &GRPCDatabasePlugin{}, + "database": &GRPCDatabasePlugin{multiplexingSupport: true}, }, } From 2fc2de58830ca6cded17f125301cd6b794c1cdcb Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Fri, 11 Feb 2022 16:38:00 -0600 Subject: [PATCH 06/22] set multiplexingSupport field in plugin server --- sdk/database/dbplugin/v5/plugin_server.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sdk/database/dbplugin/v5/plugin_server.go b/sdk/database/dbplugin/v5/plugin_server.go index c1fc495ba757..d1784b947db5 100644 --- a/sdk/database/dbplugin/v5/plugin_server.go +++ b/sdk/database/dbplugin/v5/plugin_server.go @@ -25,7 +25,8 @@ func ServeConfig(db Database) *plugin.ServeConfig { pluginSets := map[int]plugin.PluginSet{ 5: { "database": &GRPCDatabasePlugin{ - Impl: db, + Impl: db, + multiplexingSupport: false, }, }, } @@ -54,7 +55,8 @@ func ServeConfigMultiplex(factory Factory) *plugin.ServeConfig { pluginSets := map[int]plugin.PluginSet{ 6: { "database": &GRPCDatabasePlugin{ - FactoryFunc: factory, + FactoryFunc: factory, + multiplexingSupport: true, }, }, } From 774606af1190998936ac9b6f920bb8926b3b97a9 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Fri, 11 Feb 2022 18:11:08 -0600 Subject: [PATCH 07/22] remove multiplexingSupport field in sdk --- sdk/database/dbplugin/v5/grpc_database_plugin.go | 14 ++++++-------- sdk/database/dbplugin/v5/grpc_server.go | 9 ++++----- sdk/database/dbplugin/v5/grpc_server_test.go | 9 ++++++--- sdk/database/dbplugin/v5/plugin_client.go | 4 ++-- sdk/database/dbplugin/v5/plugin_server.go | 6 ++---- 5 files changed, 20 insertions(+), 22 deletions(-) diff --git a/sdk/database/dbplugin/v5/grpc_database_plugin.go b/sdk/database/dbplugin/v5/grpc_database_plugin.go index 4a68ee6b559b..5763e2342171 100644 --- a/sdk/database/dbplugin/v5/grpc_database_plugin.go +++ b/sdk/database/dbplugin/v5/grpc_database_plugin.go @@ -24,8 +24,6 @@ type GRPCDatabasePlugin struct { FactoryFunc Factory Impl Database - multiplexingSupport bool - // Embeding this will disable the netRPC protocol plugin.NetRPCUnsupportedPlugin } @@ -38,14 +36,14 @@ var ( func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error { var server gRPCServer - if d.multiplexingSupport { + if d.Impl != nil { + server = gRPCServer{singleImpl: d.Impl} + } else { + // multiplexing is supported server = gRPCServer{ - multiplexingSupport: true, - factoryFunc: d.FactoryFunc, - instances: make(map[string]Database), + factoryFunc: d.FactoryFunc, + instances: make(map[string]Database), } - } else { - server = gRPCServer{multiplexingSupport: false, singleImpl: d.Impl} } proto.RegisterDatabaseServer(s, server) diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index 423300a94bd4..880964feb72e 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -24,8 +24,6 @@ type gRPCServer struct { factoryFunc func() (interface{}, error) instances map[string]Database - multiplexingSupport bool - sync.RWMutex } @@ -52,7 +50,7 @@ func (g gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { g.Lock() defer g.Unlock() - if !g.multiplexingSupport && g.singleImpl != nil { + if g.singleImpl != nil { return g.singleImpl, nil } @@ -78,7 +76,7 @@ func (g gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { // getDatabaseInternal returns the database but does not hold a lock func (g gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) { - if !g.multiplexingSupport && g.singleImpl != nil { + if g.singleImpl != nil { return g.singleImpl, nil } @@ -294,7 +292,8 @@ func (g gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, er return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err) } - if g.multiplexingSupport { + if g.singleImpl == nil { + // only cleanup instances map when multiplexing is supported id, err := getMultiplexIDFromContext(ctx) if err != nil { return nil, err diff --git a/sdk/database/dbplugin/v5/grpc_server_test.go b/sdk/database/dbplugin/v5/grpc_server_test.go index d8bbe932a7a0..4f45e54bb74f 100644 --- a/sdk/database/dbplugin/v5/grpc_server_test.go +++ b/sdk/database/dbplugin/v5/grpc_server_test.go @@ -553,6 +553,7 @@ func TestGRPCServer_Close(t *testing.T) { expectErr: false, expectCode: codes.OK, grpcSetupFunc: testGrpcServerSingleImpl, + assertFunc: nil, }, } @@ -572,6 +573,10 @@ func TestGRPCServer_Close(t *testing.T) { if actualCode != test.expectCode { t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode) } + + if test.assertFunc != nil { + test.assertFunc(t, g) + } }) } } @@ -632,7 +637,6 @@ func TestGetMultiplexIDFromContext(t *testing.T) { func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) { t.Helper() g := gRPCServer{ - multiplexingSupport: true, factoryFunc: func() (interface{}, error) { return db, nil }, @@ -651,8 +655,7 @@ func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) { func testGrpcServerSingleImpl(t *testing.T, db Database) (context.Context, gRPCServer) { t.Helper() return context.Background(), gRPCServer{ - multiplexingSupport: false, - singleImpl: db, + singleImpl: db, } } diff --git a/sdk/database/dbplugin/v5/plugin_client.go b/sdk/database/dbplugin/v5/plugin_client.go index 9678b5de61db..4ef24cb8e242 100644 --- a/sdk/database/dbplugin/v5/plugin_client.go +++ b/sdk/database/dbplugin/v5/plugin_client.go @@ -26,10 +26,10 @@ func (dc *DatabasePluginClient) Close() error { // pluginSets is the map of plugins we can dispense. var PluginSets = map[int]plugin.PluginSet{ 5: { - "database": &GRPCDatabasePlugin{multiplexingSupport: false}, + "database": &GRPCDatabasePlugin{}, }, 6: { - "database": &GRPCDatabasePlugin{multiplexingSupport: true}, + "database": &GRPCDatabasePlugin{}, }, } diff --git a/sdk/database/dbplugin/v5/plugin_server.go b/sdk/database/dbplugin/v5/plugin_server.go index d1784b947db5..c1fc495ba757 100644 --- a/sdk/database/dbplugin/v5/plugin_server.go +++ b/sdk/database/dbplugin/v5/plugin_server.go @@ -25,8 +25,7 @@ func ServeConfig(db Database) *plugin.ServeConfig { pluginSets := map[int]plugin.PluginSet{ 5: { "database": &GRPCDatabasePlugin{ - Impl: db, - multiplexingSupport: false, + Impl: db, }, }, } @@ -55,8 +54,7 @@ func ServeConfigMultiplex(factory Factory) *plugin.ServeConfig { pluginSets := map[int]plugin.PluginSet{ 6: { "database": &GRPCDatabasePlugin{ - FactoryFunc: factory, - multiplexingSupport: true, + FactoryFunc: factory, }, }, } From 82813ee68e7ffa150ab4babfb087d662fa669023 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Fri, 11 Feb 2022 18:21:23 -0600 Subject: [PATCH 08/22] revert postgres to non-multiplexed --- .../database/postgresql/postgresql-database-plugin/main.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/plugins/database/postgresql/postgresql-database-plugin/main.go b/plugins/database/postgresql/postgresql-database-plugin/main.go index de168012ceb2..3d2e14cd9aab 100644 --- a/plugins/database/postgresql/postgresql-database-plugin/main.go +++ b/plugins/database/postgresql/postgresql-database-plugin/main.go @@ -18,7 +18,12 @@ func main() { // Run instantiates a PostgreSQL object, and runs the RPC server for the plugin func Run() error { - dbplugin.ServeMultiplex(postgresql.New) + dbType, err := postgresql.New() + if err != nil { + return err + } + + dbplugin.Serve(dbType.(dbplugin.Database)) return nil } From a8e4cbaaafbf5ecb57b2b44f57d2cd988afdbe83 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Fri, 11 Feb 2022 19:59:15 -0600 Subject: [PATCH 09/22] add comments on grpc server fields --- sdk/database/dbplugin/v5/grpc_server.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index 880964feb72e..c5d65f78e404 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -19,10 +19,13 @@ var _ proto.DatabaseServer = gRPCServer{} type gRPCServer struct { proto.UnimplementedDatabaseServer + // holds the non-multiplexed Database + // when this is set the plugin does not support multiplexing singleImpl Database - factoryFunc func() (interface{}, error) + // instances holds the multiplexed Databases instances map[string]Database + factoryFunc func() (interface{}, error) sync.RWMutex } From f4a11ed137647c0473a54bc973c1b05b54b94a0b Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Mon, 14 Feb 2022 14:23:45 -0600 Subject: [PATCH 10/22] use pointer receiver on grpc server methods --- sdk/database/dbplugin/v5/grpc_server.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index c5d65f78e404..c430c4e47393 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -49,7 +49,7 @@ func getMultiplexIDFromContext(ctx context.Context) (string, error) { return multiplexID, nil } -func (g gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { +func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { g.Lock() defer g.Unlock() @@ -78,7 +78,7 @@ func (g gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { } // getDatabaseInternal returns the database but does not hold a lock -func (g gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) { +func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) { if g.singleImpl != nil { return g.singleImpl, nil } @@ -96,7 +96,7 @@ func (g gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) { } // getDatabase holds a read lock and returns the database -func (g gRPCServer) getDatabase(ctx context.Context) (Database, error) { +func (g *gRPCServer) getDatabase(ctx context.Context) (Database, error) { g.RLock() impl, err := g.getDatabaseInternal(ctx) g.RUnlock() @@ -104,7 +104,7 @@ func (g gRPCServer) getDatabase(ctx context.Context) (Database, error) { } // Initialize the database plugin -func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) { +func (g *gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) { impl, err := g.getOrCreateDatabase(ctx) if err != nil { return nil, err @@ -134,7 +134,7 @@ func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeReq return resp, nil } -func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) { +func (g *gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) { if req.GetUsernameConfig() == nil { return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "missing username config") } @@ -176,7 +176,7 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr return resp, nil } -func (g gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) { +func (g *gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) { if req.GetUsername() == "" { return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided") } @@ -243,7 +243,7 @@ func hasChange(dbReq UpdateUserRequest) bool { return false } -func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) { +func (g *gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) { if req.GetUsername() == "" { return &proto.DeleteUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided") } @@ -264,7 +264,7 @@ func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest return &proto.DeleteUserResponse{}, nil } -func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) { +func (g *gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) { impl, err := g.getOrCreateDatabase(ctx) if err != nil { return nil, err @@ -281,7 +281,7 @@ func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeRespon return resp, nil } -func (g gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) { +func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) { g.Lock() defer g.Unlock() From 909f23a0fa9e29936db75fefb44dec3b7bd8f7c8 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Mon, 14 Feb 2022 14:23:56 -0600 Subject: [PATCH 11/22] add changelog --- changelog/14033.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 changelog/14033.txt diff --git a/changelog/14033.txt b/changelog/14033.txt new file mode 100644 index 000000000000..9e113a813954 --- /dev/null +++ b/changelog/14033.txt @@ -0,0 +1,3 @@ +```release-note:feature +**Database plugin multiplexing**: manage multiple database connections with a single plugin process +``` From efaa9e7cb529a301ae7a923ab118bd56139fa165 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Mon, 14 Feb 2022 14:32:46 -0600 Subject: [PATCH 12/22] use pointer for grpcserver instance --- sdk/database/dbplugin/v5/grpc_database_plugin.go | 2 +- sdk/database/dbplugin/v5/grpc_server.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/database/dbplugin/v5/grpc_database_plugin.go b/sdk/database/dbplugin/v5/grpc_database_plugin.go index 5763e2342171..d819cfa2617b 100644 --- a/sdk/database/dbplugin/v5/grpc_database_plugin.go +++ b/sdk/database/dbplugin/v5/grpc_database_plugin.go @@ -46,7 +46,7 @@ func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) err } } - proto.RegisterDatabaseServer(s, server) + proto.RegisterDatabaseServer(s, &server) return nil } diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index c430c4e47393..88c65d4d238c 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -14,7 +14,7 @@ import ( "google.golang.org/grpc/status" ) -var _ proto.DatabaseServer = gRPCServer{} +var _ proto.DatabaseServer = &gRPCServer{} type gRPCServer struct { proto.UnimplementedDatabaseServer From 1a5a6099f1671f4202d18ee2b6a3647bb352e9cb Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 14 Feb 2022 13:25:19 -0800 Subject: [PATCH 13/22] Use a gRPC server to determine if a plugin should be multiplexed --- Makefile | 1 + helper/forwarding/types.pb.go | 2 +- helper/identity/mfa/types.pb.go | 2 +- helper/identity/types.pb.go | 2 +- helper/storagepacker/types.pb.go | 2 +- physical/raft/types.pb.go | 2 +- sdk/database/dbplugin/database.pb.go | 2 +- .../dbplugin/v5/grpc_database_plugin.go | 7 + sdk/database/dbplugin/v5/proto/database.pb.go | 2 +- sdk/helper/pluginutil/multiplexing.go | 41 ++++ sdk/helper/pluginutil/multiplexing.pb.go | 213 ++++++++++++++++++ sdk/helper/pluginutil/multiplexing.proto | 13 ++ sdk/helper/pluginutil/multiplexing_grpc.pb.go | 101 +++++++++ sdk/helper/pluginutil/runner.go | 2 - sdk/logical/identity.pb.go | 2 +- sdk/logical/plugin.pb.go | 2 +- sdk/plugin/pb/backend.pb.go | 2 +- vault/activity/activity_log.pb.go | 2 +- vault/plugin_catalog.go | 14 +- vault/request_forwarding_service.pb.go | 2 +- 20 files changed, 395 insertions(+), 21 deletions(-) create mode 100644 sdk/helper/pluginutil/multiplexing.go create mode 100644 sdk/helper/pluginutil/multiplexing.pb.go create mode 100644 sdk/helper/pluginutil/multiplexing.proto create mode 100644 sdk/helper/pluginutil/multiplexing_grpc.pb.go diff --git a/Makefile b/Makefile index 8abf9634fc9c..173ffb36474e 100644 --- a/Makefile +++ b/Makefile @@ -194,6 +194,7 @@ proto: bootstrap protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/*.proto protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/v5/proto/*.proto protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/plugin/pb/*.proto + protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/helper/pluginutil/*.proto # No additional sed expressions should be added to this list. Going forward # we should just use the variable names choosen by protobuf. These are left diff --git a/helper/forwarding/types.pb.go b/helper/forwarding/types.pb.go index b7ffa70569e0..3a036f4726aa 100644 --- a/helper/forwarding/types.pb.go +++ b/helper/forwarding/types.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: helper/forwarding/types.proto package forwarding diff --git a/helper/identity/mfa/types.pb.go b/helper/identity/mfa/types.pb.go index 5e5bf2a854e1..5cb27bea548d 100644 --- a/helper/identity/mfa/types.pb.go +++ b/helper/identity/mfa/types.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: helper/identity/mfa/types.proto package mfa diff --git a/helper/identity/types.pb.go b/helper/identity/types.pb.go index c4730c775b89..a392d24bc313 100644 --- a/helper/identity/types.pb.go +++ b/helper/identity/types.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: helper/identity/types.proto package identity diff --git a/helper/storagepacker/types.pb.go b/helper/storagepacker/types.pb.go index 4c5b14edd404..bd7b780cd5a9 100644 --- a/helper/storagepacker/types.pb.go +++ b/helper/storagepacker/types.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: helper/storagepacker/types.proto package storagepacker diff --git a/physical/raft/types.pb.go b/physical/raft/types.pb.go index 98dc72982881..5fca8f6c3e81 100644 --- a/physical/raft/types.pb.go +++ b/physical/raft/types.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: physical/raft/types.proto package raft diff --git a/sdk/database/dbplugin/database.pb.go b/sdk/database/dbplugin/database.pb.go index 4e8b0098ffca..7c9e08a9b03e 100644 --- a/sdk/database/dbplugin/database.pb.go +++ b/sdk/database/dbplugin/database.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: sdk/database/dbplugin/database.proto package dbplugin diff --git a/sdk/database/dbplugin/v5/grpc_database_plugin.go b/sdk/database/dbplugin/v5/grpc_database_plugin.go index 5763e2342171..aa8688cb7c79 100644 --- a/sdk/database/dbplugin/v5/grpc_database_plugin.go +++ b/sdk/database/dbplugin/v5/grpc_database_plugin.go @@ -5,6 +5,7 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "google.golang.org/grpc" ) @@ -44,6 +45,12 @@ func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) err factoryFunc: d.FactoryFunc, instances: make(map[string]Database), } + + // Multiplexing is enabled for this plugin, register the server so we + // can tell the client in Vault. + pluginutil.RegisterPluginMultiplexingServer(s, pluginutil.PluginMultiplexingServerImpl{ + Supported: true, + }) } proto.RegisterDatabaseServer(s, server) diff --git a/sdk/database/dbplugin/v5/proto/database.pb.go b/sdk/database/dbplugin/v5/proto/database.pb.go index 4416e0acc40e..3699a9d662ff 100644 --- a/sdk/database/dbplugin/v5/proto/database.pb.go +++ b/sdk/database/dbplugin/v5/proto/database.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: sdk/database/dbplugin/v5/proto/database.proto package proto diff --git a/sdk/helper/pluginutil/multiplexing.go b/sdk/helper/pluginutil/multiplexing.go new file mode 100644 index 000000000000..a37964b95af3 --- /dev/null +++ b/sdk/helper/pluginutil/multiplexing.go @@ -0,0 +1,41 @@ +package pluginutil + +import ( + context "context" + + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +type PluginMultiplexingServerImpl struct { + UnimplementedPluginMultiplexingServer + + Supported bool +} + +func (pm PluginMultiplexingServerImpl) MultiplexingSupport(ctx context.Context, req *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) { + return &MultiplexingSupportResponse{ + Supported: pm.Supported, + }, nil +} + +func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface) (bool, error) { + resp, err := NewPluginMultiplexingClient(cc).MultiplexingSupport(ctx, nil) + if err != nil { + + // If the server does not implement the multiplexing server then we can + // assume it is not multiplexed + if status.Code(err) == codes.Unimplemented { + return false, nil + } + + return false, err + } + if resp == nil { + // Somehow got a nil response, assume not multiplexed + return false, nil + } + + return resp.Supported, nil +} diff --git a/sdk/helper/pluginutil/multiplexing.pb.go b/sdk/helper/pluginutil/multiplexing.pb.go new file mode 100644 index 000000000000..d0ff51e57b24 --- /dev/null +++ b/sdk/helper/pluginutil/multiplexing.pb.go @@ -0,0 +1,213 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.27.1 +// protoc v3.19.4 +// source: sdk/helper/pluginutil/multiplexing.proto + +package pluginutil + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type MultiplexingSupportRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *MultiplexingSupportRequest) Reset() { + *x = MultiplexingSupportRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MultiplexingSupportRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MultiplexingSupportRequest) ProtoMessage() {} + +func (x *MultiplexingSupportRequest) ProtoReflect() protoreflect.Message { + mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MultiplexingSupportRequest.ProtoReflect.Descriptor instead. +func (*MultiplexingSupportRequest) Descriptor() ([]byte, []int) { + return file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP(), []int{0} +} + +type MultiplexingSupportResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Supported bool `protobuf:"varint,1,opt,name=supported,proto3" json:"supported,omitempty"` +} + +func (x *MultiplexingSupportResponse) Reset() { + *x = MultiplexingSupportResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MultiplexingSupportResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MultiplexingSupportResponse) ProtoMessage() {} + +func (x *MultiplexingSupportResponse) ProtoReflect() protoreflect.Message { + mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MultiplexingSupportResponse.ProtoReflect.Descriptor instead. +func (*MultiplexingSupportResponse) Descriptor() ([]byte, []int) { + return file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP(), []int{1} +} + +func (x *MultiplexingSupportResponse) GetSupported() bool { + if x != nil { + return x.Supported + } + return false +} + +var File_sdk_helper_pluginutil_multiplexing_proto protoreflect.FileDescriptor + +var file_sdk_helper_pluginutil_multiplexing_proto_rawDesc = []byte{ + 0x0a, 0x28, 0x73, 0x64, 0x6b, 0x2f, 0x68, 0x65, 0x6c, 0x70, 0x65, 0x72, 0x2f, 0x70, 0x6c, 0x75, + 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2f, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, + 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x70, 0x6c, 0x75, 0x67, + 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, + 0x69, 0x6e, 0x67, 0x22, 0x1c, 0x0a, 0x1a, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, + 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x22, 0x3b, 0x0a, 0x1b, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, + 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x09, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x32, 0x97, + 0x01, 0x0a, 0x12, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, + 0x65, 0x78, 0x69, 0x6e, 0x67, 0x12, 0x80, 0x01, 0x0a, 0x13, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, + 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x33, 0x2e, + 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x6d, 0x75, 0x6c, 0x74, 0x69, + 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, + 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x34, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, + 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x75, 0x6c, + 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, + 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x68, 0x65, 0x6c, 0x70, 0x65, + 0x72, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_sdk_helper_pluginutil_multiplexing_proto_rawDescOnce sync.Once + file_sdk_helper_pluginutil_multiplexing_proto_rawDescData = file_sdk_helper_pluginutil_multiplexing_proto_rawDesc +) + +func file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP() []byte { + file_sdk_helper_pluginutil_multiplexing_proto_rawDescOnce.Do(func() { + file_sdk_helper_pluginutil_multiplexing_proto_rawDescData = protoimpl.X.CompressGZIP(file_sdk_helper_pluginutil_multiplexing_proto_rawDescData) + }) + return file_sdk_helper_pluginutil_multiplexing_proto_rawDescData +} + +var file_sdk_helper_pluginutil_multiplexing_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_sdk_helper_pluginutil_multiplexing_proto_goTypes = []interface{}{ + (*MultiplexingSupportRequest)(nil), // 0: pluginutil.multiplexing.MultiplexingSupportRequest + (*MultiplexingSupportResponse)(nil), // 1: pluginutil.multiplexing.MultiplexingSupportResponse +} +var file_sdk_helper_pluginutil_multiplexing_proto_depIdxs = []int32{ + 0, // 0: pluginutil.multiplexing.PluginMultiplexing.MultiplexingSupport:input_type -> pluginutil.multiplexing.MultiplexingSupportRequest + 1, // 1: pluginutil.multiplexing.PluginMultiplexing.MultiplexingSupport:output_type -> pluginutil.multiplexing.MultiplexingSupportResponse + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_sdk_helper_pluginutil_multiplexing_proto_init() } +func file_sdk_helper_pluginutil_multiplexing_proto_init() { + if File_sdk_helper_pluginutil_multiplexing_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MultiplexingSupportRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MultiplexingSupportResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_sdk_helper_pluginutil_multiplexing_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_sdk_helper_pluginutil_multiplexing_proto_goTypes, + DependencyIndexes: file_sdk_helper_pluginutil_multiplexing_proto_depIdxs, + MessageInfos: file_sdk_helper_pluginutil_multiplexing_proto_msgTypes, + }.Build() + File_sdk_helper_pluginutil_multiplexing_proto = out.File + file_sdk_helper_pluginutil_multiplexing_proto_rawDesc = nil + file_sdk_helper_pluginutil_multiplexing_proto_goTypes = nil + file_sdk_helper_pluginutil_multiplexing_proto_depIdxs = nil +} diff --git a/sdk/helper/pluginutil/multiplexing.proto b/sdk/helper/pluginutil/multiplexing.proto new file mode 100644 index 000000000000..aa2438b070ff --- /dev/null +++ b/sdk/helper/pluginutil/multiplexing.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; +package pluginutil.multiplexing; + +option go_package = "github.com/hashicorp/vault/sdk/helper/pluginutil"; + +message MultiplexingSupportRequest {} +message MultiplexingSupportResponse { + bool supported = 1; +} + +service PluginMultiplexing { + rpc MultiplexingSupport(MultiplexingSupportRequest) returns (MultiplexingSupportResponse); +} diff --git a/sdk/helper/pluginutil/multiplexing_grpc.pb.go b/sdk/helper/pluginutil/multiplexing_grpc.pb.go new file mode 100644 index 000000000000..aa8d0e47ba84 --- /dev/null +++ b/sdk/helper/pluginutil/multiplexing_grpc.pb.go @@ -0,0 +1,101 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package pluginutil + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// PluginMultiplexingClient is the client API for PluginMultiplexing service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type PluginMultiplexingClient interface { + MultiplexingSupport(ctx context.Context, in *MultiplexingSupportRequest, opts ...grpc.CallOption) (*MultiplexingSupportResponse, error) +} + +type pluginMultiplexingClient struct { + cc grpc.ClientConnInterface +} + +func NewPluginMultiplexingClient(cc grpc.ClientConnInterface) PluginMultiplexingClient { + return &pluginMultiplexingClient{cc} +} + +func (c *pluginMultiplexingClient) MultiplexingSupport(ctx context.Context, in *MultiplexingSupportRequest, opts ...grpc.CallOption) (*MultiplexingSupportResponse, error) { + out := new(MultiplexingSupportResponse) + err := c.cc.Invoke(ctx, "/pluginutil.multiplexing.PluginMultiplexing/MultiplexingSupport", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// PluginMultiplexingServer is the server API for PluginMultiplexing service. +// All implementations must embed UnimplementedPluginMultiplexingServer +// for forward compatibility +type PluginMultiplexingServer interface { + MultiplexingSupport(context.Context, *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) + mustEmbedUnimplementedPluginMultiplexingServer() +} + +// UnimplementedPluginMultiplexingServer must be embedded to have forward compatible implementations. +type UnimplementedPluginMultiplexingServer struct { +} + +func (UnimplementedPluginMultiplexingServer) MultiplexingSupport(context.Context, *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method MultiplexingSupport not implemented") +} +func (UnimplementedPluginMultiplexingServer) mustEmbedUnimplementedPluginMultiplexingServer() {} + +// UnsafePluginMultiplexingServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to PluginMultiplexingServer will +// result in compilation errors. +type UnsafePluginMultiplexingServer interface { + mustEmbedUnimplementedPluginMultiplexingServer() +} + +func RegisterPluginMultiplexingServer(s grpc.ServiceRegistrar, srv PluginMultiplexingServer) { + s.RegisterService(&PluginMultiplexing_ServiceDesc, srv) +} + +func _PluginMultiplexing_MultiplexingSupport_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(MultiplexingSupportRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(PluginMultiplexingServer).MultiplexingSupport(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/pluginutil.multiplexing.PluginMultiplexing/MultiplexingSupport", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(PluginMultiplexingServer).MultiplexingSupport(ctx, req.(*MultiplexingSupportRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// PluginMultiplexing_ServiceDesc is the grpc.ServiceDesc for PluginMultiplexing service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var PluginMultiplexing_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "pluginutil.multiplexing.PluginMultiplexing", + HandlerType: (*PluginMultiplexingServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "MultiplexingSupport", + Handler: _PluginMultiplexing_MultiplexingSupport_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "sdk/helper/pluginutil/multiplexing.proto", +} diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index fdb82b6007f6..48bf133f9927 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -35,8 +35,6 @@ type LookRunnerUtil interface { type PluginClient interface { Conn() grpc.ClientConnInterface - MultiplexingSupport() (bool, error) - plugin.ClientProtocol } diff --git a/sdk/logical/identity.pb.go b/sdk/logical/identity.pb.go index b221ccc3b325..0a68eadf69d3 100644 --- a/sdk/logical/identity.pb.go +++ b/sdk/logical/identity.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: sdk/logical/identity.proto package logical diff --git a/sdk/logical/plugin.pb.go b/sdk/logical/plugin.pb.go index 46de77666df8..b16f0a75af97 100644 --- a/sdk/logical/plugin.pb.go +++ b/sdk/logical/plugin.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: sdk/logical/plugin.proto package logical diff --git a/sdk/plugin/pb/backend.pb.go b/sdk/plugin/pb/backend.pb.go index 342670676c19..dbad4da977ce 100644 --- a/sdk/plugin/pb/backend.pb.go +++ b/sdk/plugin/pb/backend.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: sdk/plugin/pb/backend.proto package pb diff --git a/vault/activity/activity_log.pb.go b/vault/activity/activity_log.pb.go index d59d3d3e17a8..5388f9f78670 100644 --- a/vault/activity/activity_log.pb.go +++ b/vault/activity/activity_log.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: vault/activity/activity_log.proto package activity diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 48bd0b9ac74e..0813687d4b95 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -130,11 +130,11 @@ func (p *pluginClient) Conn() grpc.ClientConnInterface { } // MultiplexingSupport determines if a plugin client supports multiplexing -func (p *pluginClient) MultiplexingSupport() (bool, error) { - if p.client == nil { +func (p *pluginClient) MultiplexingSupport(ctx context.Context) (bool, error) { + if p.clientConn == nil { return false, fmt.Errorf("plugin client is nil") } - return pluginutil.MultiplexingSupport(p.client.NegotiatedVersion()), nil + return pluginutil.MultiplexingSupported(ctx, p.clientConn) } // Close calls the plugin client's cleanupFunc to do any necessary cleanup on @@ -145,7 +145,7 @@ func (p *pluginClient) Close() error { return p.cleanupFunc() } -func (c *PluginCatalog) removePluginClient(name, id string) error { +func (c *PluginCatalog) removePluginClient(ctx context.Context, name, id string) error { var err error extPlugin, ok := c.externalPlugins[name] if !ok { @@ -153,7 +153,7 @@ func (c *PluginCatalog) removePluginClient(name, id string) error { } pluginClient := extPlugin.connections[id] - multiplexingSupport, err := pluginClient.MultiplexingSupport() + multiplexingSupport, err := pluginClient.MultiplexingSupport(ctx) if err != nil { return err } @@ -234,7 +234,7 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi id: id, logger: c.logger.Named(pluginRunner.Name), cleanupFunc: func() error { - return c.removePluginClient(pluginRunner.Name, id) + return c.removePluginClient(ctx, pluginRunner.Name, id) }, } @@ -360,7 +360,7 @@ func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *plug if err == nil { // At this point the pluginRunner does not know if multiplexing is // supported or not. So we need to ask the plugin client itself. - multiplexingSupport, err := v5Client.MultiplexingSupport() + multiplexingSupport, err := pluginutil.MultiplexingSupported(ctx, v5Client.clientConn) if err != nil { return false, err } diff --git a/vault/request_forwarding_service.pb.go b/vault/request_forwarding_service.pb.go index 62962be0c670..d16aa5d07155 100644 --- a/vault/request_forwarding_service.pb.go +++ b/vault/request_forwarding_service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: vault/request_forwarding_service.proto package vault From c0aaec2afd95a0e56adf596edd5b67a9733021bc Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Tue, 15 Feb 2022 07:47:51 -0600 Subject: [PATCH 14/22] Apply suggestions from code review Co-authored-by: Brian Kassouf --- vault/plugin_catalog.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 48bd0b9ac74e..f6d0b5a48810 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -169,7 +169,7 @@ func (c *PluginCatalog) removePluginClient(name, id string) error { pluginClient.client.Kill() delete(c.externalPlugins, name) } - return err + return nil } func (c *PluginCatalog) getExternalPlugin(pluginName string) *externalPlugin { @@ -218,7 +218,7 @@ func (c *PluginCatalog) NewPluginClient(ctx context.Context, config pluginutil.P } // newPluginClient returns a client for managing the lifecycle of a plugin -// process +// process. Callers should have the write lock held. func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (*pluginClient, error) { if pluginRunner == nil { return nil, fmt.Errorf("no plugin found") @@ -239,7 +239,7 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi } if !pluginRunner.MultiplexingSupport || len(extPlugin.connections) == 0 { - c.logger.Debug("spawning a new plugin process", "id", id) + c.logger.Debug("spawning a new plugin process", "plugin_name", pluginRunner.Name, "id", id) client, err := pluginRunner.RunConfig(ctx, pluginutil.PluginSets(config.PluginSets), pluginutil.HandshakeConfig(config.HandshakeConfig), From b18bdc6f71f2123dac867d8ab049228a056840a2 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Tue, 15 Feb 2022 09:00:06 -0600 Subject: [PATCH 15/22] add lock to removePluginClient --- vault/plugin_catalog.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index f6d0b5a48810..f165c03f3014 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -146,6 +146,9 @@ func (p *pluginClient) Close() error { } func (c *PluginCatalog) removePluginClient(name, id string) error { + c.lock.Lock() + defer c.lock.Unlock() + var err error extPlugin, ok := c.externalPlugins[name] if !ok { @@ -199,6 +202,7 @@ func (c *PluginCatalog) newExternalPlugin(pluginName string) *externalPlugin { func (c *PluginCatalog) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (*pluginClient, error) { c.lock.Lock() defer c.lock.Unlock() + if config.Name == "" { return nil, fmt.Errorf("no name provided for plugin") } From a46478c96506ac659fbdb2ba6fd7f01bc7e22277 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Tue, 15 Feb 2022 09:25:34 -0600 Subject: [PATCH 16/22] add multiplexingSupport field to externalPlugin struct --- sdk/helper/pluginutil/runner.go | 1 - vault/plugin_catalog.go | 19 +++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index fdb82b6007f6..c327ef91243f 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -35,7 +35,6 @@ type LookRunnerUtil interface { type PluginClient interface { Conn() grpc.ClientConnInterface - MultiplexingSupport() (bool, error) plugin.ClientProtocol } diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index f165c03f3014..89ff986e89e2 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -77,6 +77,8 @@ type externalPlugin struct { // connections holds client connections by ID connections map[string]*pluginClient + + multiplexingSupport bool } func (c *Core) setupPluginCatalog(ctx context.Context) error { @@ -145,24 +147,21 @@ func (p *pluginClient) Close() error { return p.cleanupFunc() } -func (c *PluginCatalog) removePluginClient(name, id string) error { +// cleanupExternalPlugin will kill plugin processes and perform any necessary cleanup on the +// externalPlugins map for multiplexed and non-multiplexed plugins. +func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error { c.lock.Lock() defer c.lock.Unlock() - var err error extPlugin, ok := c.externalPlugins[name] if !ok { return fmt.Errorf("plugin client not found") } pluginClient := extPlugin.connections[id] - multiplexingSupport, err := pluginClient.MultiplexingSupport() - if err != nil { - return err - } delete(extPlugin.connections, id) - if !multiplexingSupport { + if !extPlugin.multiplexingSupport { pluginClient.client.Kill() if len(extPlugin.connections) == 0 { @@ -172,6 +171,7 @@ func (c *PluginCatalog) removePluginClient(name, id string) error { pluginClient.client.Kill() delete(c.externalPlugins, name) } + return nil } @@ -238,7 +238,7 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi id: id, logger: c.logger.Named(pluginRunner.Name), cleanupFunc: func() error { - return c.removePluginClient(pluginRunner.Name, id) + return c.cleanupExternalPlugin(pluginRunner.Name, id) }, } @@ -292,9 +292,12 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi } else { pc.clientConn = clientConn } + pc.ClientProtocol = rpcClient + extPlugin.connections[id] = pc extPlugin.name = pluginRunner.Name + extPlugin.multiplexingSupport = pluginRunner.MultiplexingSupport return extPlugin.connections[id], nil } From 189382d6746475e01772f9d9faddf235ea330916 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Tue, 15 Feb 2022 11:14:57 -0600 Subject: [PATCH 17/22] do not send nil to grpc MultiplexingSupport --- sdk/helper/pluginutil/multiplexing.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/helper/pluginutil/multiplexing.go b/sdk/helper/pluginutil/multiplexing.go index a37964b95af3..5abb99381a12 100644 --- a/sdk/helper/pluginutil/multiplexing.go +++ b/sdk/helper/pluginutil/multiplexing.go @@ -21,7 +21,8 @@ func (pm PluginMultiplexingServerImpl) MultiplexingSupport(ctx context.Context, } func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface) (bool, error) { - resp, err := NewPluginMultiplexingClient(cc).MultiplexingSupport(ctx, nil) + req := new(MultiplexingSupportRequest) + resp, err := NewPluginMultiplexingClient(cc).MultiplexingSupport(ctx, req) if err != nil { // If the server does not implement the multiplexing server then we can From 451d9c712e6dfb2bd04b1f306d687fc67090f3af Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Tue, 15 Feb 2022 12:42:12 -0600 Subject: [PATCH 18/22] check err before logging --- vault/plugin_catalog.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 26079996ee0e..6d91b525da54 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -147,8 +147,9 @@ func (p *pluginClient) Close() error { return p.cleanupFunc() } -// cleanupExternalPlugin will kill plugin processes and perform any necessary cleanup on the -// externalPlugins map for multiplexed and non-multiplexed plugins. +// cleanupExternalPlugin will kill plugin processes and perform any necessary +// cleanup on the externalPlugins map for multiplexed and non-multiplexed +// plugins. func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error { extPlugin, ok := c.externalPlugins[name] if !ok { @@ -371,7 +372,9 @@ func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *plug // Close the client and cleanup the plugin process err = v5Client.Close() - c.logger.Error("error closing plugin client", "error", err) + if err != nil { + c.logger.Error("error closing plugin client", "error", err) + } return multiplexingSupport, nil } @@ -381,7 +384,9 @@ func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *plug if err == nil { // Close the client and cleanup the plugin process err = v4Client.Close() - c.logger.Error("error closing plugin client", "error", err) + if err != nil { + c.logger.Error("error closing plugin client", "error", err) + } return false, nil } From baa01f016e3290c971c0216b17955d5f604a44d2 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Tue, 15 Feb 2022 14:18:28 -0600 Subject: [PATCH 19/22] handle locking scenario for cleanupFunc --- sdk/helper/pluginutil/runner.go | 6 ------ vault/plugin_catalog.go | 4 +++- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index 48bf133f9927..3e58cebde3aa 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -96,9 +96,3 @@ func CtxCancelIfCanceled(f context.CancelFunc, ctxCanceler context.Context) chan }() return quitCh } - -// MultiplexingSupport returns true if a plugin supports multiplexing. -// Currently this is hardcoded for database plugins. -func MultiplexingSupport(version int) bool { - return version == 6 -} diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 6d91b525da54..063f3adeb2a9 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -236,6 +236,8 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi id: id, logger: c.logger.Named(pluginRunner.Name), cleanupFunc: func() error { + c.lock.Lock() + defer c.lock.Unlock() return c.cleanupExternalPlugin(pluginRunner.Name, id) }, } @@ -371,7 +373,7 @@ func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *plug } // Close the client and cleanup the plugin process - err = v5Client.Close() + err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id) if err != nil { c.logger.Error("error closing plugin client", "error", err) } From d4d20c73d84b891572482771a5663954540e11c7 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Tue, 15 Feb 2022 14:43:03 -0600 Subject: [PATCH 20/22] allow ServeConfigMultiplex to dispense v5 plugin --- sdk/database/dbplugin/v5/plugin_server.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sdk/database/dbplugin/v5/plugin_server.go b/sdk/database/dbplugin/v5/plugin_server.go index c1fc495ba757..090894ae5521 100644 --- a/sdk/database/dbplugin/v5/plugin_server.go +++ b/sdk/database/dbplugin/v5/plugin_server.go @@ -50,8 +50,21 @@ func ServeConfigMultiplex(factory Factory) *plugin.ServeConfig { return nil } + db, err := factory() + if err != nil { + fmt.Println(err) + return nil + } + + database := db.(Database) + // pluginSets is the map of plugins we can dispense. pluginSets := map[int]plugin.PluginSet{ + 5: { + "database": &GRPCDatabasePlugin{ + Impl: database, + }, + }, 6: { "database": &GRPCDatabasePlugin{ FactoryFunc: factory, From 3c9cd0d84d5286055dab822030c624dceeddc620 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Wed, 16 Feb 2022 07:51:04 -0600 Subject: [PATCH 21/22] reposition structs, add err check and comments --- sdk/helper/pluginutil/multiplexing.go | 5 ++++ vault/plugin_catalog.go | 38 ++++++++++++--------------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/sdk/helper/pluginutil/multiplexing.go b/sdk/helper/pluginutil/multiplexing.go index 5abb99381a12..cbf50335d0bf 100644 --- a/sdk/helper/pluginutil/multiplexing.go +++ b/sdk/helper/pluginutil/multiplexing.go @@ -2,6 +2,7 @@ package pluginutil import ( context "context" + "fmt" grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" @@ -21,6 +22,10 @@ func (pm PluginMultiplexingServerImpl) MultiplexingSupport(ctx context.Context, } func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface) (bool, error) { + if cc == nil { + return false, fmt.Errorf("client connection is nil") + } + req := new(MultiplexingSupportRequest) resp, err := NewPluginMultiplexingClient(cc).MultiplexingSupport(ctx, req) if err != nil { diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 063f3adeb2a9..208eb6bd9e05 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -53,6 +53,18 @@ type PluginCatalog struct { lock sync.RWMutex } +// externalPlugin holds client connections for multiplexed and +// non-multiplexed plugin processes +type externalPlugin struct { + // name is the plugin name + name string + + // connections holds client connections by ID + connections map[string]*pluginClient + + multiplexingSupport bool +} + // pluginClient represents a connection to a plugin process type pluginClient struct { logger log.Logger @@ -69,18 +81,6 @@ type pluginClient struct { plugin.ClientProtocol } -// externalPlugin holds client connections for multiplexed and -// non-multiplexed plugin processes -type externalPlugin struct { - // name is the plugin name - name string - - // connections holds client connections by ID - connections map[string]*pluginClient - - multiplexingSupport bool -} - func (c *Core) setupPluginCatalog(ctx context.Context) error { c.pluginCatalog = &PluginCatalog{ builtinRegistry: c.builtinRegistry, @@ -131,14 +131,6 @@ func (p *pluginClient) Conn() grpc.ClientConnInterface { return p.clientConn } -// MultiplexingSupport determines if a plugin client supports multiplexing -func (p *pluginClient) MultiplexingSupport(ctx context.Context) (bool, error) { - if p.clientConn == nil { - return false, fmt.Errorf("plugin client is nil") - } - return pluginutil.MultiplexingSupported(ctx, p.clientConn) -} - // Close calls the plugin client's cleanupFunc to do any necessary cleanup on // the plugin client and the PluginCatalog. This implements the // plugin.ClientProtocol interface. @@ -539,7 +531,11 @@ func (c *PluginCatalog) Set(ctx context.Context, name string, pluginType consts. c.lock.Lock() defer c.lock.Unlock() - return c.setInternal(ctx, name, pluginType, false, command, args, env, sha256) + // During plugin registration, we can't know if a plugin is multiplexed or + // not until we run it. So we set it to false here. Once started, we ask + // the plugin if it is multiplexed and set this value accordingly. + multiplexingSupport := false + return c.setInternal(ctx, name, pluginType, multiplexingSupport, command, args, env, sha256) } func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, multiplexingSupport bool, command string, args []string, env []string, sha256 []byte) error { From c60d27379fd72105347a1709df6486558fabeeff Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Wed, 16 Feb 2022 14:35:49 -0600 Subject: [PATCH 22/22] add comment on locking for cleanupExternalPlugin --- vault/plugin_catalog.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 208eb6bd9e05..7eaa08690d9c 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -141,7 +141,7 @@ func (p *pluginClient) Close() error { // cleanupExternalPlugin will kill plugin processes and perform any necessary // cleanup on the externalPlugins map for multiplexed and non-multiplexed -// plugins. +// plugins. This should be called with the write lock held. func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error { extPlugin, ok := c.externalPlugins[name] if !ok {