diff --git a/Makefile b/Makefile index 8abf9634fc9c..173ffb36474e 100644 --- a/Makefile +++ b/Makefile @@ -194,6 +194,7 @@ proto: bootstrap protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/*.proto protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/v5/proto/*.proto protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/plugin/pb/*.proto + protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/helper/pluginutil/*.proto # No additional sed expressions should be added to this list. Going forward # we should just use the variable names choosen by protobuf. These are left diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 8572b0947eea..a88f2b6b068e 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -110,6 +110,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } type databaseBackend struct { + // connections holds configured database connections by config name connections map[string]*dbPluginInstance logger log.Logger diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index bbe5769ea350..554d83fe3935 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -329,6 +329,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { } config.ConnectionDetails = initResp.Config + b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName) + b.Lock() defer b.Unlock() @@ -365,6 +367,9 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { "Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName)) } + if len(resp.Warnings) == 0 { + return nil, nil + } return resp, nil } } diff --git a/changelog/14033.txt b/changelog/14033.txt new file mode 100644 index 000000000000..9e113a813954 --- /dev/null +++ b/changelog/14033.txt @@ -0,0 +1,3 @@ +```release-note:feature +**Database plugin multiplexing**: manage multiple database connections with a single plugin process +``` diff --git a/helper/builtinplugins/registry.go b/helper/builtinplugins/registry.go index 325658654b50..b0d96530bd11 100644 --- a/helper/builtinplugins/registry.go +++ b/helper/builtinplugins/registry.go @@ -150,8 +150,8 @@ type registry struct { logicalBackends map[string]logical.Factory } -// Get returns the BuiltinFactory func for a particular backend plugin -// from the plugins map. +// Get returns the Factory func for a particular backend plugin from the +// plugins map. func (r *registry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) { switch pluginType { case consts.PluginTypeCredential: diff --git a/helper/forwarding/types.pb.go b/helper/forwarding/types.pb.go index b7ffa70569e0..3a036f4726aa 100644 --- a/helper/forwarding/types.pb.go +++ b/helper/forwarding/types.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: helper/forwarding/types.proto package forwarding diff --git a/helper/identity/mfa/types.pb.go b/helper/identity/mfa/types.pb.go index 5e5bf2a854e1..5cb27bea548d 100644 --- a/helper/identity/mfa/types.pb.go +++ b/helper/identity/mfa/types.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: helper/identity/mfa/types.proto package mfa diff --git a/helper/identity/types.pb.go b/helper/identity/types.pb.go index c4730c775b89..a392d24bc313 100644 --- a/helper/identity/types.pb.go +++ b/helper/identity/types.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: helper/identity/types.proto package identity diff --git a/helper/storagepacker/types.pb.go b/helper/storagepacker/types.pb.go index 4c5b14edd404..bd7b780cd5a9 100644 --- a/helper/storagepacker/types.pb.go +++ b/helper/storagepacker/types.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: helper/storagepacker/types.proto package storagepacker diff --git a/physical/raft/types.pb.go b/physical/raft/types.pb.go index 98dc72982881..5fca8f6c3e81 100644 --- a/physical/raft/types.pb.go +++ b/physical/raft/types.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: physical/raft/types.proto package raft diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index dba0bf74595d..a3826d3a83f0 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -48,7 +48,6 @@ var ( singleQuotedPhrases = regexp.MustCompile(`('.*?')`) ) -// New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { db := new() // Wrap the plugin with middleware to sanitize errors diff --git a/sdk/database/dbplugin/database.pb.go b/sdk/database/dbplugin/database.pb.go index 4e8b0098ffca..7c9e08a9b03e 100644 --- a/sdk/database/dbplugin/database.pb.go +++ b/sdk/database/dbplugin/database.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: sdk/database/dbplugin/database.proto package dbplugin diff --git a/sdk/database/dbplugin/v5/grpc_database_plugin.go b/sdk/database/dbplugin/v5/grpc_database_plugin.go index 96d296ad799c..92ed2dcc66a1 100644 --- a/sdk/database/dbplugin/v5/grpc_database_plugin.go +++ b/sdk/database/dbplugin/v5/grpc_database_plugin.go @@ -5,6 +5,7 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "google.golang.org/grpc" ) @@ -12,14 +13,17 @@ import ( // a plugin and host. If the handshake fails, a user friendly error is shown. // This prevents users from executing bad plugins or executing a plugin // directory. It is a UX feature, not a security feature. -var handshakeConfig = plugin.HandshakeConfig{ - ProtocolVersion: 5, +var HandshakeConfig = plugin.HandshakeConfig{ MagicCookieKey: "VAULT_DATABASE_PLUGIN", MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", } +// Factory is the factory function to create a dbplugin Database. +type Factory func() (interface{}, error) + type GRPCDatabasePlugin struct { - Impl Database + FactoryFunc Factory + Impl Database // Embeding this will disable the netRPC protocol plugin.NetRPCUnsupportedPlugin @@ -31,7 +35,25 @@ var ( ) func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error { - proto.RegisterDatabaseServer(s, gRPCServer{impl: d.Impl}) + var server gRPCServer + + if d.Impl != nil { + server = gRPCServer{singleImpl: d.Impl} + } else { + // multiplexing is supported + server = gRPCServer{ + factoryFunc: d.FactoryFunc, + instances: make(map[string]Database), + } + + // Multiplexing is enabled for this plugin, register the server so we + // can tell the client in Vault. + pluginutil.RegisterPluginMultiplexingServer(s, pluginutil.PluginMultiplexingServerImpl{ + Supported: true, + }) + } + + proto.RegisterDatabaseServer(s, &server) return nil } diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index c75fa8ef0ed0..88c65d4d238c 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -3,24 +3,113 @@ package dbplugin import ( "context" "fmt" + "sync" "time" "github.com/golang/protobuf/ptypes" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) -var _ proto.DatabaseServer = gRPCServer{} +var _ proto.DatabaseServer = &gRPCServer{} type gRPCServer struct { proto.UnimplementedDatabaseServer - impl Database + // holds the non-multiplexed Database + // when this is set the plugin does not support multiplexing + singleImpl Database + + // instances holds the multiplexed Databases + instances map[string]Database + factoryFunc func() (interface{}, error) + + sync.RWMutex +} + +func getMultiplexIDFromContext(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", fmt.Errorf("missing plugin multiplexing metadata") + } + + multiplexIDs := md[pluginutil.MultiplexingCtxKey] + if len(multiplexIDs) != 1 { + return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs)) + } + + multiplexID := multiplexIDs[0] + if multiplexID == "" { + return "", fmt.Errorf("empty multiplex ID in metadata") + } + + return multiplexID, nil +} + +func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { + g.Lock() + defer g.Unlock() + + if g.singleImpl != nil { + return g.singleImpl, nil + } + + id, err := getMultiplexIDFromContext(ctx) + if err != nil { + return nil, err + } + + if db, ok := g.instances[id]; ok { + return db, nil + } + + db, err := g.factoryFunc() + if err != nil { + return nil, err + } + + database := db.(Database) + g.instances[id] = database + + return database, nil +} + +// getDatabaseInternal returns the database but does not hold a lock +func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) { + if g.singleImpl != nil { + return g.singleImpl, nil + } + + id, err := getMultiplexIDFromContext(ctx) + if err != nil { + return nil, err + } + + if db, ok := g.instances[id]; ok { + return db, nil + } + + return nil, fmt.Errorf("no database instance found") +} + +// getDatabase holds a read lock and returns the database +func (g *gRPCServer) getDatabase(ctx context.Context) (Database, error) { + g.RLock() + impl, err := g.getDatabaseInternal(ctx) + g.RUnlock() + return impl, err } // Initialize the database plugin -func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) { +func (g *gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) { + impl, err := g.getOrCreateDatabase(ctx) + if err != nil { + return nil, err + } + rawConfig := structToMap(request.ConfigData) dbReq := InitializeRequest{ @@ -28,7 +117,7 @@ func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeReq VerifyConnection: request.VerifyConnection, } - dbResp, err := g.impl.Initialize(ctx, dbReq) + dbResp, err := impl.Initialize(ctx, dbReq) if err != nil { return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to initialize: %s", err) } @@ -45,7 +134,7 @@ func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeReq return resp, nil } -func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) { +func (g *gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) { if req.GetUsernameConfig() == nil { return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "missing username config") } @@ -60,6 +149,11 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr expiration = exp } + impl, err := g.getDatabase(ctx) + if err != nil { + return nil, err + } + dbReq := NewUserRequest{ UsernameConfig: UsernameMetadata{ DisplayName: req.GetUsernameConfig().GetDisplayName(), @@ -71,7 +165,7 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr RollbackStatements: getStatementsFromProto(req.GetRollbackStatements()), } - dbResp, err := g.impl.NewUser(ctx, dbReq) + dbResp, err := impl.NewUser(ctx, dbReq) if err != nil { return &proto.NewUserResponse{}, status.Errorf(codes.Internal, "unable to create new user: %s", err) } @@ -82,7 +176,7 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr return resp, nil } -func (g gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) { +func (g *gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) { if req.GetUsername() == "" { return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided") } @@ -92,7 +186,12 @@ func (g gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, err.Error()) } - _, err = g.impl.UpdateUser(ctx, dbReq) + impl, err := g.getDatabase(ctx) + if err != nil { + return nil, err + } + + _, err = impl.UpdateUser(ctx, dbReq) if err != nil { return &proto.UpdateUserResponse{}, status.Errorf(codes.Internal, "unable to update user: %s", err) } @@ -144,7 +243,7 @@ func hasChange(dbReq UpdateUserRequest) bool { return false } -func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) { +func (g *gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) { if req.GetUsername() == "" { return &proto.DeleteUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided") } @@ -153,15 +252,25 @@ func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest Statements: getStatementsFromProto(req.GetStatements()), } - _, err := g.impl.DeleteUser(ctx, dbReq) + impl, err := g.getDatabase(ctx) + if err != nil { + return nil, err + } + + _, err = impl.DeleteUser(ctx, dbReq) if err != nil { return &proto.DeleteUserResponse{}, status.Errorf(codes.Internal, "unable to delete user: %s", err) } return &proto.DeleteUserResponse{}, nil } -func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) { - t, err := g.impl.Type() +func (g *gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) { + impl, err := g.getOrCreateDatabase(ctx) + if err != nil { + return nil, err + } + + t, err := impl.Type() if err != nil { return &proto.TypeResponse{}, status.Errorf(codes.Internal, "unable to retrieve type: %s", err) } @@ -172,11 +281,29 @@ func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeRespon return resp, nil } -func (g gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) { - err := g.impl.Close() +func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) { + g.Lock() + defer g.Unlock() + + impl, err := g.getDatabaseInternal(ctx) + if err != nil { + return nil, err + } + + err = impl.Close() if err != nil { return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err) } + + if g.singleImpl == nil { + // only cleanup instances map when multiplexing is supported + id, err := getMultiplexIDFromContext(ctx) + if err != nil { + return nil, err + } + delete(g.instances, id) + } + return &proto.Empty{}, nil } diff --git a/sdk/database/dbplugin/v5/grpc_server_test.go b/sdk/database/dbplugin/v5/grpc_server_test.go index d3861c25445d..4f45e54bb74f 100644 --- a/sdk/database/dbplugin/v5/grpc_server_test.go +++ b/sdk/database/dbplugin/v5/grpc_server_test.go @@ -13,7 +13,9 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/timestamp" "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -22,11 +24,12 @@ var invalidExpiration = time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC) func TestGRPCServer_Initialize(t *testing.T) { type testCase struct { - db Database - req *proto.InitializeRequest - expectedResp *proto.InitializeResponse - expectErr bool - expectCode codes.Code + db Database + req *proto.InitializeRequest + expectedResp *proto.InitializeResponse + expectErr bool + expectCode codes.Code + grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer) } tests := map[string]testCase{ @@ -34,10 +37,11 @@ func TestGRPCServer_Initialize(t *testing.T) { db: fakeDatabase{ initErr: errors.New("initialization error"), }, - req: &proto.InitializeRequest{}, - expectedResp: &proto.InitializeResponse{}, - expectErr: true, - expectCode: codes.Internal, + req: &proto.InitializeRequest{}, + expectedResp: &proto.InitializeResponse{}, + expectErr: true, + expectCode: codes.Internal, + grpcSetupFunc: testGrpcServer, }, "newConfig can't marshal to JSON": { db: fakeDatabase{ @@ -47,12 +51,13 @@ func TestGRPCServer_Initialize(t *testing.T) { }, }, }, - req: &proto.InitializeRequest{}, - expectedResp: &proto.InitializeResponse{}, - expectErr: true, - expectCode: codes.Internal, + req: &proto.InitializeRequest{}, + expectedResp: &proto.InitializeResponse{}, + expectErr: true, + expectCode: codes.Internal, + grpcSetupFunc: testGrpcServer, }, - "happy path with config data": { + "happy path with config data for multiplexed plugin": { db: fakeDatabase{ initResp: InitializeResponse{ Config: map[string]interface{}{ @@ -70,21 +75,39 @@ func TestGRPCServer_Initialize(t *testing.T) { "foo": "bar", }), }, - expectErr: false, - expectCode: codes.OK, + expectErr: false, + expectCode: codes.OK, + grpcSetupFunc: testGrpcServer, + }, + "happy path with config data for non-multiplexed plugin": { + db: fakeDatabase{ + initResp: InitializeResponse{ + Config: map[string]interface{}{ + "foo": "bar", + }, + }, + }, + req: &proto.InitializeRequest{ + ConfigData: marshal(t, map[string]interface{}{ + "foo": "bar", + }), + }, + expectedResp: &proto.InitializeResponse{ + ConfigData: marshal(t, map[string]interface{}{ + "foo": "bar", + }), + }, + expectErr: false, + expectCode: codes.OK, + grpcSetupFunc: testGrpcServerSingleImpl, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } + idCtx, g := test.grpcSetupFunc(t, test.db) + resp, err := g.Initialize(idCtx, test.req) - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() - - resp, err := g.Initialize(ctx, test.req) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -252,14 +275,9 @@ func TestGRPCServer_NewUser(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } - - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() + idCtx, g := testGrpcServer(t, test.db) + resp, err := g.NewUser(idCtx, test.req) - resp, err := g.NewUser(ctx, test.req) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -362,14 +380,9 @@ func TestGRPCServer_UpdateUser(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } - - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() + idCtx, g := testGrpcServer(t, test.db) + resp, err := g.UpdateUser(idCtx, test.req) - resp, err := g.UpdateUser(ctx, test.req) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -430,14 +443,9 @@ func TestGRPCServer_DeleteUser(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } - - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() + idCtx, g := testGrpcServer(t, test.db) + resp, err := g.DeleteUser(idCtx, test.req) - resp, err := g.DeleteUser(ctx, test.req) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -488,14 +496,9 @@ func TestGRPCServer_Type(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } + idCtx, g := testGrpcServer(t, test.db) + resp, err := g.Type(idCtx, &proto.Empty{}) - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() - - resp, err := g.Type(ctx, &proto.Empty{}) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -517,9 +520,11 @@ func TestGRPCServer_Type(t *testing.T) { func TestGRPCServer_Close(t *testing.T) { type testCase struct { - db Database - expectErr bool - expectCode codes.Code + db Database + expectErr bool + expectCode codes.Code + grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer) + assertFunc func(t *testing.T, g gRPCServer) } tests := map[string]testCase{ @@ -527,26 +532,36 @@ func TestGRPCServer_Close(t *testing.T) { db: fakeDatabase{ closeErr: errors.New("close error"), }, - expectErr: true, - expectCode: codes.Internal, + expectErr: true, + expectCode: codes.Internal, + grpcSetupFunc: testGrpcServer, + assertFunc: nil, }, - "happy path": { - db: fakeDatabase{}, - expectErr: false, - expectCode: codes.OK, + "happy path for multiplexed plugin": { + db: fakeDatabase{}, + expectErr: false, + expectCode: codes.OK, + grpcSetupFunc: testGrpcServer, + assertFunc: func(t *testing.T, g gRPCServer) { + if len(g.instances) != 0 { + t.Fatalf("err expected instances map to be empty") + } + }, + }, + "happy path for non-multiplexed plugin": { + db: fakeDatabase{}, + expectErr: false, + expectCode: codes.OK, + grpcSetupFunc: testGrpcServerSingleImpl, + assertFunc: nil, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { - g := gRPCServer{ - impl: test.db, - } - - // Context doesn't need to timeout since this is just passed through - ctx := context.Background() + idCtx, g := test.grpcSetupFunc(t, test.db) + _, err := g.Close(idCtx, &proto.Empty{}) - _, err := g.Close(ctx, &proto.Empty{}) if test.expectErr && err == nil { t.Fatalf("err expected, got nil") } @@ -558,10 +573,105 @@ func TestGRPCServer_Close(t *testing.T) { if actualCode != test.expectCode { t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode) } + + if test.assertFunc != nil { + test.assertFunc(t, g) + } }) } } +func TestGetMultiplexIDFromContext(t *testing.T) { + type testCase struct { + ctx context.Context + expectedResp string + expectedErr error + } + + tests := map[string]testCase{ + "missing plugin multiplexing metadata": { + ctx: context.Background(), + expectedResp: "", + expectedErr: fmt.Errorf("missing plugin multiplexing metadata"), + }, + "unexpected number of IDs in metadata": { + ctx: idCtx(t, "12345", "67891"), + expectedResp: "", + expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"), + }, + "empty multiplex ID in metadata": { + ctx: idCtx(t, ""), + expectedResp: "", + expectedErr: fmt.Errorf("empty multiplex ID in metadata"), + }, + "happy path, id is returned from metadata": { + ctx: idCtx(t, "12345"), + expectedResp: "12345", + expectedErr: nil, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + resp, err := getMultiplexIDFromContext(test.ctx) + + if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil { + t.Fatalf("err expected, got nil") + } else if !reflect.DeepEqual(err, test.expectedErr) { + t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr) + } + + if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + + if !reflect.DeepEqual(resp, test.expectedResp) { + t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) + } + }) + } +} + +// testGrpcServer is a test helper that returns a context with an ID set in its +// metadata and a gRPCServer instance for a multiplexed plugin +func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) { + t.Helper() + g := gRPCServer{ + factoryFunc: func() (interface{}, error) { + return db, nil + }, + instances: make(map[string]Database), + } + + id := "12345" + idCtx := idCtx(t, id) + g.instances[id] = db + + return idCtx, g +} + +// testGrpcServerSingleImpl is a test helper that returns a context and a +// gRPCServer instance for a non-multiplexed plugin +func testGrpcServerSingleImpl(t *testing.T, db Database) (context.Context, gRPCServer) { + t.Helper() + return context.Background(), gRPCServer{ + singleImpl: db, + } +} + +// idCtx is a test helper that will return a context with the IDs set in its +// metadata +func idCtx(t *testing.T, ids ...string) context.Context { + t.Helper() + // Context doesn't need to timeout since this is just passed through + ctx := context.Background() + md := metadata.MD{} + for _, id := range ids { + md.Append(pluginutil.MultiplexingCtxKey, id) + } + return metadata.NewIncomingContext(ctx, md) +} + func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct { t.Helper() diff --git a/sdk/database/dbplugin/v5/plugin_client.go b/sdk/database/dbplugin/v5/plugin_client.go index d2e096110472..4ef24cb8e242 100644 --- a/sdk/database/dbplugin/v5/plugin_client.go +++ b/sdk/database/dbplugin/v5/plugin_client.go @@ -3,19 +3,14 @@ package dbplugin import ( "context" "errors" - "sync" - log "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" "github.com/hashicorp/vault/sdk/helper/pluginutil" ) -// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's Close -// method to also call Kill() on the plugin.Client. type DatabasePluginClient struct { - client *plugin.Client - sync.Mutex - + client pluginutil.PluginClient Database } @@ -23,42 +18,31 @@ type DatabasePluginClient struct { // and kill the plugin. func (dc *DatabasePluginClient) Close() error { err := dc.Database.Close() - dc.client.Kill() + dc.client.Close() return err } -// NewPluginClient returns a databaseRPCClient with a connection to a running -// plugin. The client is wrapped in a DatabasePluginClient object to ensure the -// plugin is killed on call of Close(). -func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (Database, error) { - // pluginSets is the map of plugins we can dispense. - pluginSets := map[int]plugin.PluginSet{ - 5: { - "database": new(GRPCDatabasePlugin), - }, - } - - client, err := pluginRunner.RunConfig(ctx, - pluginutil.Runner(sys), - pluginutil.PluginSets(pluginSets), - pluginutil.HandshakeConfig(handshakeConfig), - pluginutil.Logger(logger), - pluginutil.MetadataMode(isMetadataMode), - pluginutil.AutoMTLS(true), - ) - if err != nil { - return nil, err - } +// pluginSets is the map of plugins we can dispense. +var PluginSets = map[int]plugin.PluginSet{ + 5: { + "database": &GRPCDatabasePlugin{}, + }, + 6: { + "database": &GRPCDatabasePlugin{}, + }, +} - // Connect via RPC - rpcClient, err := client.Client() +// NewPluginClient returns a databaseRPCClient with a connection to a running +// plugin. +func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, config pluginutil.PluginClientConfig) (Database, error) { + pluginClient, err := sys.NewPluginClient(ctx, config) if err != nil { return nil, err } // Request the plugin - raw, err := rpcClient.Dispense("database") + raw, err := pluginClient.Dispense("database") if err != nil { return nil, err } @@ -66,16 +50,19 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne // We should have a database type now. This feels like a normal interface // implementation but is in fact over an RPC connection. var db Database - switch raw.(type) { + switch c := raw.(type) { case gRPCClient: - db = raw.(gRPCClient) + // This is an abstraction leak from go-plugin but it is necessary in + // order to enable multiplexing on multiplexed plugins + c.client = proto.NewDatabaseClient(pluginClient.Conn()) + + db = c default: return nil, errors.New("unsupported client type") } - // Wrap RPC implementation in DatabasePluginClient return &DatabasePluginClient{ - client: client, + client: pluginClient, Database: db, }, nil } diff --git a/sdk/database/dbplugin/v5/plugin_factory.go b/sdk/database/dbplugin/v5/plugin_factory.go index f203f1ed4cc7..b87dc3a75a60 100644 --- a/sdk/database/dbplugin/v5/plugin_factory.go +++ b/sdk/database/dbplugin/v5/plugin_factory.go @@ -40,8 +40,17 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu transport = "builtin" } else { + config := pluginutil.PluginClientConfig{ + Name: pluginName, + PluginType: consts.PluginTypeDatabase, + PluginSets: PluginSets, + HandshakeConfig: HandshakeConfig, + Logger: namedLogger, + IsMetadataMode: false, + AutoMTLS: true, + } // create a DatabasePluginClient instance - db, err = NewPluginClient(ctx, sys, pluginRunner, namedLogger, false) + db, err = NewPluginClient(ctx, sys, config) if err != nil { return nil, err } @@ -59,6 +68,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu if err != nil { return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err) } + logger.Debug("got database plugin instance", "type", typeStr) // Wrap with metrics middleware db = &databaseMetricsMiddleware{ diff --git a/sdk/database/dbplugin/v5/plugin_server.go b/sdk/database/dbplugin/v5/plugin_server.go index 11d04e6450a6..090894ae5521 100644 --- a/sdk/database/dbplugin/v5/plugin_server.go +++ b/sdk/database/dbplugin/v5/plugin_server.go @@ -31,7 +31,49 @@ func ServeConfig(db Database) *plugin.ServeConfig { } conf := &plugin.ServeConfig{ - HandshakeConfig: handshakeConfig, + HandshakeConfig: HandshakeConfig, + VersionedPlugins: pluginSets, + GRPCServer: plugin.DefaultGRPCServer, + } + + return conf +} + +func ServeMultiplex(factory Factory) { + plugin.Serve(ServeConfigMultiplex(factory)) +} + +func ServeConfigMultiplex(factory Factory) *plugin.ServeConfig { + err := pluginutil.OptionallyEnableMlock() + if err != nil { + fmt.Println(err) + return nil + } + + db, err := factory() + if err != nil { + fmt.Println(err) + return nil + } + + database := db.(Database) + + // pluginSets is the map of plugins we can dispense. + pluginSets := map[int]plugin.PluginSet{ + 5: { + "database": &GRPCDatabasePlugin{ + Impl: database, + }, + }, + 6: { + "database": &GRPCDatabasePlugin{ + FactoryFunc: factory, + }, + }, + } + + conf := &plugin.ServeConfig{ + HandshakeConfig: HandshakeConfig, VersionedPlugins: pluginSets, GRPCServer: plugin.DefaultGRPCServer, } diff --git a/sdk/database/dbplugin/v5/proto/database.pb.go b/sdk/database/dbplugin/v5/proto/database.pb.go index 4416e0acc40e..3699a9d662ff 100644 --- a/sdk/database/dbplugin/v5/proto/database.pb.go +++ b/sdk/database/dbplugin/v5/proto/database.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: sdk/database/dbplugin/v5/proto/database.proto package proto diff --git a/sdk/helper/pluginutil/multiplexing.go b/sdk/helper/pluginutil/multiplexing.go new file mode 100644 index 000000000000..cbf50335d0bf --- /dev/null +++ b/sdk/helper/pluginutil/multiplexing.go @@ -0,0 +1,47 @@ +package pluginutil + +import ( + context "context" + "fmt" + + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +type PluginMultiplexingServerImpl struct { + UnimplementedPluginMultiplexingServer + + Supported bool +} + +func (pm PluginMultiplexingServerImpl) MultiplexingSupport(ctx context.Context, req *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) { + return &MultiplexingSupportResponse{ + Supported: pm.Supported, + }, nil +} + +func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface) (bool, error) { + if cc == nil { + return false, fmt.Errorf("client connection is nil") + } + + req := new(MultiplexingSupportRequest) + resp, err := NewPluginMultiplexingClient(cc).MultiplexingSupport(ctx, req) + if err != nil { + + // If the server does not implement the multiplexing server then we can + // assume it is not multiplexed + if status.Code(err) == codes.Unimplemented { + return false, nil + } + + return false, err + } + if resp == nil { + // Somehow got a nil response, assume not multiplexed + return false, nil + } + + return resp.Supported, nil +} diff --git a/sdk/helper/pluginutil/multiplexing.pb.go b/sdk/helper/pluginutil/multiplexing.pb.go new file mode 100644 index 000000000000..d0ff51e57b24 --- /dev/null +++ b/sdk/helper/pluginutil/multiplexing.pb.go @@ -0,0 +1,213 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.27.1 +// protoc v3.19.4 +// source: sdk/helper/pluginutil/multiplexing.proto + +package pluginutil + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type MultiplexingSupportRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *MultiplexingSupportRequest) Reset() { + *x = MultiplexingSupportRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MultiplexingSupportRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MultiplexingSupportRequest) ProtoMessage() {} + +func (x *MultiplexingSupportRequest) ProtoReflect() protoreflect.Message { + mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MultiplexingSupportRequest.ProtoReflect.Descriptor instead. +func (*MultiplexingSupportRequest) Descriptor() ([]byte, []int) { + return file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP(), []int{0} +} + +type MultiplexingSupportResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Supported bool `protobuf:"varint,1,opt,name=supported,proto3" json:"supported,omitempty"` +} + +func (x *MultiplexingSupportResponse) Reset() { + *x = MultiplexingSupportResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MultiplexingSupportResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MultiplexingSupportResponse) ProtoMessage() {} + +func (x *MultiplexingSupportResponse) ProtoReflect() protoreflect.Message { + mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MultiplexingSupportResponse.ProtoReflect.Descriptor instead. +func (*MultiplexingSupportResponse) Descriptor() ([]byte, []int) { + return file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP(), []int{1} +} + +func (x *MultiplexingSupportResponse) GetSupported() bool { + if x != nil { + return x.Supported + } + return false +} + +var File_sdk_helper_pluginutil_multiplexing_proto protoreflect.FileDescriptor + +var file_sdk_helper_pluginutil_multiplexing_proto_rawDesc = []byte{ + 0x0a, 0x28, 0x73, 0x64, 0x6b, 0x2f, 0x68, 0x65, 0x6c, 0x70, 0x65, 0x72, 0x2f, 0x70, 0x6c, 0x75, + 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2f, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, + 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x70, 0x6c, 0x75, 0x67, + 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, + 0x69, 0x6e, 0x67, 0x22, 0x1c, 0x0a, 0x1a, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, + 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x22, 0x3b, 0x0a, 0x1b, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, + 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x09, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x32, 0x97, + 0x01, 0x0a, 0x12, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, + 0x65, 0x78, 0x69, 0x6e, 0x67, 0x12, 0x80, 0x01, 0x0a, 0x13, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, + 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x33, 0x2e, + 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x6d, 0x75, 0x6c, 0x74, 0x69, + 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, + 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x34, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, + 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x75, 0x6c, + 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, + 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x68, 0x65, 0x6c, 0x70, 0x65, + 0x72, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_sdk_helper_pluginutil_multiplexing_proto_rawDescOnce sync.Once + file_sdk_helper_pluginutil_multiplexing_proto_rawDescData = file_sdk_helper_pluginutil_multiplexing_proto_rawDesc +) + +func file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP() []byte { + file_sdk_helper_pluginutil_multiplexing_proto_rawDescOnce.Do(func() { + file_sdk_helper_pluginutil_multiplexing_proto_rawDescData = protoimpl.X.CompressGZIP(file_sdk_helper_pluginutil_multiplexing_proto_rawDescData) + }) + return file_sdk_helper_pluginutil_multiplexing_proto_rawDescData +} + +var file_sdk_helper_pluginutil_multiplexing_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_sdk_helper_pluginutil_multiplexing_proto_goTypes = []interface{}{ + (*MultiplexingSupportRequest)(nil), // 0: pluginutil.multiplexing.MultiplexingSupportRequest + (*MultiplexingSupportResponse)(nil), // 1: pluginutil.multiplexing.MultiplexingSupportResponse +} +var file_sdk_helper_pluginutil_multiplexing_proto_depIdxs = []int32{ + 0, // 0: pluginutil.multiplexing.PluginMultiplexing.MultiplexingSupport:input_type -> pluginutil.multiplexing.MultiplexingSupportRequest + 1, // 1: pluginutil.multiplexing.PluginMultiplexing.MultiplexingSupport:output_type -> pluginutil.multiplexing.MultiplexingSupportResponse + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_sdk_helper_pluginutil_multiplexing_proto_init() } +func file_sdk_helper_pluginutil_multiplexing_proto_init() { + if File_sdk_helper_pluginutil_multiplexing_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MultiplexingSupportRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MultiplexingSupportResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_sdk_helper_pluginutil_multiplexing_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_sdk_helper_pluginutil_multiplexing_proto_goTypes, + DependencyIndexes: file_sdk_helper_pluginutil_multiplexing_proto_depIdxs, + MessageInfos: file_sdk_helper_pluginutil_multiplexing_proto_msgTypes, + }.Build() + File_sdk_helper_pluginutil_multiplexing_proto = out.File + file_sdk_helper_pluginutil_multiplexing_proto_rawDesc = nil + file_sdk_helper_pluginutil_multiplexing_proto_goTypes = nil + file_sdk_helper_pluginutil_multiplexing_proto_depIdxs = nil +} diff --git a/sdk/helper/pluginutil/multiplexing.proto b/sdk/helper/pluginutil/multiplexing.proto new file mode 100644 index 000000000000..aa2438b070ff --- /dev/null +++ b/sdk/helper/pluginutil/multiplexing.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; +package pluginutil.multiplexing; + +option go_package = "github.com/hashicorp/vault/sdk/helper/pluginutil"; + +message MultiplexingSupportRequest {} +message MultiplexingSupportResponse { + bool supported = 1; +} + +service PluginMultiplexing { + rpc MultiplexingSupport(MultiplexingSupportRequest) returns (MultiplexingSupportResponse); +} diff --git a/sdk/helper/pluginutil/multiplexing_grpc.pb.go b/sdk/helper/pluginutil/multiplexing_grpc.pb.go new file mode 100644 index 000000000000..aa8d0e47ba84 --- /dev/null +++ b/sdk/helper/pluginutil/multiplexing_grpc.pb.go @@ -0,0 +1,101 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package pluginutil + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// PluginMultiplexingClient is the client API for PluginMultiplexing service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type PluginMultiplexingClient interface { + MultiplexingSupport(ctx context.Context, in *MultiplexingSupportRequest, opts ...grpc.CallOption) (*MultiplexingSupportResponse, error) +} + +type pluginMultiplexingClient struct { + cc grpc.ClientConnInterface +} + +func NewPluginMultiplexingClient(cc grpc.ClientConnInterface) PluginMultiplexingClient { + return &pluginMultiplexingClient{cc} +} + +func (c *pluginMultiplexingClient) MultiplexingSupport(ctx context.Context, in *MultiplexingSupportRequest, opts ...grpc.CallOption) (*MultiplexingSupportResponse, error) { + out := new(MultiplexingSupportResponse) + err := c.cc.Invoke(ctx, "/pluginutil.multiplexing.PluginMultiplexing/MultiplexingSupport", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// PluginMultiplexingServer is the server API for PluginMultiplexing service. +// All implementations must embed UnimplementedPluginMultiplexingServer +// for forward compatibility +type PluginMultiplexingServer interface { + MultiplexingSupport(context.Context, *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) + mustEmbedUnimplementedPluginMultiplexingServer() +} + +// UnimplementedPluginMultiplexingServer must be embedded to have forward compatible implementations. +type UnimplementedPluginMultiplexingServer struct { +} + +func (UnimplementedPluginMultiplexingServer) MultiplexingSupport(context.Context, *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method MultiplexingSupport not implemented") +} +func (UnimplementedPluginMultiplexingServer) mustEmbedUnimplementedPluginMultiplexingServer() {} + +// UnsafePluginMultiplexingServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to PluginMultiplexingServer will +// result in compilation errors. +type UnsafePluginMultiplexingServer interface { + mustEmbedUnimplementedPluginMultiplexingServer() +} + +func RegisterPluginMultiplexingServer(s grpc.ServiceRegistrar, srv PluginMultiplexingServer) { + s.RegisterService(&PluginMultiplexing_ServiceDesc, srv) +} + +func _PluginMultiplexing_MultiplexingSupport_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(MultiplexingSupportRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(PluginMultiplexingServer).MultiplexingSupport(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/pluginutil.multiplexing.PluginMultiplexing/MultiplexingSupport", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(PluginMultiplexingServer).MultiplexingSupport(ctx, req.(*MultiplexingSupportRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// PluginMultiplexing_ServiceDesc is the grpc.ServiceDesc for PluginMultiplexing service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var PluginMultiplexing_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "pluginutil.multiplexing.PluginMultiplexing", + HandlerType: (*PluginMultiplexingServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "MultiplexingSupport", + Handler: _PluginMultiplexing_MultiplexingSupport_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "sdk/helper/pluginutil/multiplexing.proto", +} diff --git a/sdk/helper/pluginutil/run_config.go b/sdk/helper/pluginutil/run_config.go index f801287d7d4d..cb804f60d873 100644 --- a/sdk/helper/pluginutil/run_config.go +++ b/sdk/helper/pluginutil/run_config.go @@ -9,9 +9,21 @@ import ( log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/version" ) +type PluginClientConfig struct { + Name string + PluginType consts.PluginType + PluginSets map[int]plugin.PluginSet + HandshakeConfig plugin.HandshakeConfig + Logger log.Logger + IsMetadataMode bool + AutoMTLS bool + MLock bool +} + type runConfig struct { // Provided by PluginRunner command string @@ -21,12 +33,9 @@ type runConfig struct { // Initialized with what's in PluginRunner.Env, but can be added to env []string - wrapper RunnerUtil - pluginSets map[int]plugin.PluginSet - hs plugin.HandshakeConfig - logger log.Logger - isMetadataMode bool - autoMTLS bool + wrapper RunnerUtil + + PluginClientConfig } func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error) { @@ -34,19 +43,19 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error cmd.Env = append(cmd.Env, rc.env...) // Add the mlock setting to the ENV of the plugin - if rc.wrapper != nil && rc.wrapper.MlockEnabled() { + if rc.MLock || (rc.wrapper != nil && rc.wrapper.MlockEnabled()) { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) } cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version)) - if rc.isMetadataMode { - rc.logger = rc.logger.With("metadata", "true") + if rc.IsMetadataMode { + rc.Logger = rc.Logger.With("metadata", "true") } - metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.isMetadataMode) + metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.IsMetadataMode) cmd.Env = append(cmd.Env, metadataEnv) var clientTLSConfig *tls.Config - if !rc.autoMTLS && !rc.isMetadataMode { + if !rc.AutoMTLS && !rc.IsMetadataMode { // Get a CA TLS Certificate certBytes, key, err := generateCert() if err != nil { @@ -76,17 +85,17 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error } clientConfig := &plugin.ClientConfig{ - HandshakeConfig: rc.hs, - VersionedPlugins: rc.pluginSets, + HandshakeConfig: rc.HandshakeConfig, + VersionedPlugins: rc.PluginSets, Cmd: cmd, SecureConfig: secureConfig, TLSConfig: clientTLSConfig, - Logger: rc.logger, + Logger: rc.Logger, AllowedProtocols: []plugin.Protocol{ plugin.ProtocolNetRPC, plugin.ProtocolGRPC, }, - AutoMTLS: rc.autoMTLS, + AutoMTLS: rc.AutoMTLS, } return clientConfig, nil } @@ -117,31 +126,37 @@ func Runner(wrapper RunnerUtil) RunOpt { func PluginSets(pluginSets map[int]plugin.PluginSet) RunOpt { return func(rc *runConfig) { - rc.pluginSets = pluginSets + rc.PluginSets = pluginSets } } func HandshakeConfig(hs plugin.HandshakeConfig) RunOpt { return func(rc *runConfig) { - rc.hs = hs + rc.HandshakeConfig = hs } } func Logger(logger log.Logger) RunOpt { return func(rc *runConfig) { - rc.logger = logger + rc.Logger = logger } } func MetadataMode(isMetadataMode bool) RunOpt { return func(rc *runConfig) { - rc.isMetadataMode = isMetadataMode + rc.IsMetadataMode = isMetadataMode } } func AutoMTLS(autoMTLS bool) RunOpt { return func(rc *runConfig) { - rc.autoMTLS = autoMTLS + rc.AutoMTLS = autoMTLS + } +} + +func MLock(mlock bool) RunOpt { + return func(rc *runConfig) { + rc.MLock = mlock } } diff --git a/sdk/helper/pluginutil/run_config_test.go b/sdk/helper/pluginutil/run_config_test.go index 239d36852382..f2373fe9b4a5 100644 --- a/sdk/helper/pluginutil/run_config_test.go +++ b/sdk/helper/pluginutil/run_config_test.go @@ -38,19 +38,21 @@ func TestMakeConfig(t *testing.T) { args: []string{"foo", "bar"}, sha256: []byte("some_sha256"), env: []string{"initial=true"}, - pluginSets: map[int]plugin.PluginSet{ - 1: { - "bogus": nil, + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + IsMetadataMode: true, + AutoMTLS: false, }, - hs: plugin.HandshakeConfig{ - ProtocolVersion: 1, - MagicCookieKey: "magic_cookie_key", - MagicCookieValue: "magic_cookie_value", - }, - logger: hclog.NewNullLogger(), - isMetadataMode: true, - autoMTLS: false, }, responseWrapInfoTimes: 0, @@ -97,19 +99,21 @@ func TestMakeConfig(t *testing.T) { args: []string{"foo", "bar"}, sha256: []byte("some_sha256"), env: []string{"initial=true"}, - pluginSets: map[int]plugin.PluginSet{ - 1: { - "bogus": nil, + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + IsMetadataMode: false, + AutoMTLS: false, }, - hs: plugin.HandshakeConfig{ - ProtocolVersion: 1, - MagicCookieKey: "magic_cookie_key", - MagicCookieValue: "magic_cookie_value", - }, - logger: hclog.NewNullLogger(), - isMetadataMode: false, - autoMTLS: false, }, responseWrapInfo: &wrapping.ResponseWrapInfo{ @@ -161,19 +165,21 @@ func TestMakeConfig(t *testing.T) { args: []string{"foo", "bar"}, sha256: []byte("some_sha256"), env: []string{"initial=true"}, - pluginSets: map[int]plugin.PluginSet{ - 1: { - "bogus": nil, + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + IsMetadataMode: true, + AutoMTLS: true, }, - hs: plugin.HandshakeConfig{ - ProtocolVersion: 1, - MagicCookieKey: "magic_cookie_key", - MagicCookieValue: "magic_cookie_value", - }, - logger: hclog.NewNullLogger(), - isMetadataMode: true, - autoMTLS: true, }, responseWrapInfoTimes: 0, @@ -220,19 +226,21 @@ func TestMakeConfig(t *testing.T) { args: []string{"foo", "bar"}, sha256: []byte("some_sha256"), env: []string{"initial=true"}, - pluginSets: map[int]plugin.PluginSet{ - 1: { - "bogus": nil, + PluginClientConfig: PluginClientConfig{ + PluginSets: map[int]plugin.PluginSet{ + 1: { + "bogus": nil, + }, }, + HandshakeConfig: plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "magic_cookie_key", + MagicCookieValue: "magic_cookie_value", + }, + Logger: hclog.NewNullLogger(), + IsMetadataMode: false, + AutoMTLS: true, }, - hs: plugin.HandshakeConfig{ - ProtocolVersion: 1, - MagicCookieKey: "magic_cookie_key", - MagicCookieValue: "magic_cookie_value", - }, - logger: hclog.NewNullLogger(), - isMetadataMode: false, - autoMTLS: true, }, responseWrapInfoTimes: 0, @@ -329,6 +337,11 @@ type mockRunnerUtil struct { mock.Mock } +func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error) { + args := m.Called(ctx, config) + return args.Get(0).(PluginClient), args.Error(1) +} + func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { args := m.Called(ctx, data, ttl, jwt) return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1) diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index ecd60eeb3459..3e58cebde3aa 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -8,6 +8,7 @@ import ( plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/wrapping" + "google.golang.org/grpc" ) // Looker defines the plugin Lookup function that looks into the plugin catalog @@ -21,6 +22,7 @@ type Looker interface { // configuration and wrapping data in a response wrapped token. // logical.SystemView implementations satisfy this interface. type RunnerUtil interface { + NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) MlockEnabled() bool } @@ -31,17 +33,25 @@ type LookRunnerUtil interface { RunnerUtil } +type PluginClient interface { + Conn() grpc.ClientConnInterface + plugin.ClientProtocol +} + +const MultiplexingCtxKey string = "multiplex_id" + // PluginRunner defines the metadata needed to run a plugin securely with // go-plugin. type PluginRunner struct { - Name string `json:"name" structs:"name"` - Type consts.PluginType `json:"type" structs:"type"` - Command string `json:"command" structs:"command"` - Args []string `json:"args" structs:"args"` - Env []string `json:"env" structs:"env"` - Sha256 []byte `json:"sha256" structs:"sha256"` - Builtin bool `json:"builtin" structs:"builtin"` - BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"` + Name string `json:"name" structs:"name"` + Type consts.PluginType `json:"type" structs:"type"` + Command string `json:"command" structs:"command"` + Args []string `json:"args" structs:"args"` + Env []string `json:"env" structs:"env"` + Sha256 []byte `json:"sha256" structs:"sha256"` + Builtin bool `json:"builtin" structs:"builtin"` + BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"` + MultiplexingSupport bool `json:"multiplexing_support" structs:"multiplexing_support"` } // Run takes a wrapper RunnerUtil instance along with the go-plugin parameters and diff --git a/sdk/logical/identity.pb.go b/sdk/logical/identity.pb.go index b221ccc3b325..0a68eadf69d3 100644 --- a/sdk/logical/identity.pb.go +++ b/sdk/logical/identity.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: sdk/logical/identity.proto package logical diff --git a/sdk/logical/plugin.pb.go b/sdk/logical/plugin.pb.go index 46de77666df8..b16f0a75af97 100644 --- a/sdk/logical/plugin.pb.go +++ b/sdk/logical/plugin.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: sdk/logical/plugin.proto package logical diff --git a/sdk/logical/system_view.go b/sdk/logical/system_view.go index 8ea6766b9941..83b4a951e842 100644 --- a/sdk/logical/system_view.go +++ b/sdk/logical/system_view.go @@ -56,6 +56,10 @@ type SystemView interface { // name. Returns a PluginRunner or an error if a plugin can not be found. LookupPlugin(context.Context, string, consts.PluginType) (*pluginutil.PluginRunner, error) + // NewPluginClient returns a client for managing the lifecycle of plugin + // processes + NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) + // MlockEnabled returns the configuration setting for enabling mlock on // plugins. MlockEnabled() bool @@ -152,6 +156,10 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState { return d.ReplicationStateVal } +func (d StaticSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) { + return nil, errors.New("NewPluginClient is not implemented in StaticSystemView") +} + func (d StaticSystemView) ResponseWrapData(_ context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { return nil, errors.New("ResponseWrapData is not implemented in StaticSystemView") } diff --git a/sdk/plugin/grpc_system.go b/sdk/plugin/grpc_system.go index ca7db03176fd..81bc324bf0fa 100644 --- a/sdk/plugin/grpc_system.go +++ b/sdk/plugin/grpc_system.go @@ -99,6 +99,10 @@ func (s *gRPCSystemViewClient) ResponseWrapData(ctx context.Context, data map[st return info, nil } +func (s *gRPCSystemViewClient) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) { + return nil, fmt.Errorf("cannot call NewPluginClient from a plugin backend") +} + func (s *gRPCSystemViewClient) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) { return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend") } diff --git a/sdk/plugin/pb/backend.pb.go b/sdk/plugin/pb/backend.pb.go index 342670676c19..dbad4da977ce 100644 --- a/sdk/plugin/pb/backend.pb.go +++ b/sdk/plugin/pb/backend.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: sdk/plugin/pb/backend.proto package pb diff --git a/vault/activity/activity_log.pb.go b/vault/activity/activity_log.pb.go index d59d3d3e17a8..5388f9f78670 100644 --- a/vault/activity/activity_log.pb.go +++ b/vault/activity/activity_log.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: vault/activity/activity_log.proto package activity diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index ab63ae61bc2e..0c64f09329d1 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -215,6 +215,22 @@ func (d dynamicSystemView) ResponseWrapData(ctx context.Context, data map[string return resp.WrapInfo, nil } +func (d dynamicSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) { + if d.core == nil { + return nil, fmt.Errorf("system view core is nil") + } + if d.core.pluginCatalog == nil { + return nil, fmt.Errorf("system view core plugin catalog is nil") + } + + c, err := d.core.pluginCatalog.NewPluginClient(ctx, config) + if err != nil { + return nil, err + } + + return c, nil +} + // LookupPlugin looks for a plugin with the given name in the plugin catalog. It // returns a PluginRunner or an error if no plugin was found. func (d dynamicSystemView) LookupPlugin(ctx context.Context, name string, pluginType consts.PluginType) (*pluginutil.PluginRunner, error) { diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 71d4603e6e19..7eaa08690d9c 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -12,6 +12,8 @@ import ( log "github.com/hashicorp/go-hclog" multierror "github.com/hashicorp/go-multierror" + plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/go-secure-stdlib/base62" v4 "github.com/hashicorp/vault/sdk/database/dbplugin" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/helper/consts" @@ -19,6 +21,8 @@ import ( "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" backendplugin "github.com/hashicorp/vault/sdk/plugin" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) var ( @@ -35,21 +39,62 @@ type PluginCatalog struct { builtinRegistry BuiltinRegistry catalogView *BarrierView directory string + logger log.Logger + + // externalPlugins holds plugin process connections by plugin name + // + // This allows plugins that suppport multiplexing to use a single grpc + // connection to communicate with multiple "backends". Each backend + // configuration using the same plugin will be routed to the existing + // plugin process. + externalPlugins map[string]*externalPlugin + mlockPlugins bool lock sync.RWMutex } +// externalPlugin holds client connections for multiplexed and +// non-multiplexed plugin processes +type externalPlugin struct { + // name is the plugin name + name string + + // connections holds client connections by ID + connections map[string]*pluginClient + + multiplexingSupport bool +} + +// pluginClient represents a connection to a plugin process +type pluginClient struct { + logger log.Logger + + // id is the connection ID + id string + + // client handles the lifecycle of a plugin process + // multiplexed plugins share the same client + client *plugin.Client + clientConn grpc.ClientConnInterface + cleanupFunc func() error + + plugin.ClientProtocol +} + func (c *Core) setupPluginCatalog(ctx context.Context) error { c.pluginCatalog = &PluginCatalog{ builtinRegistry: c.builtinRegistry, catalogView: NewBarrierView(c.barrier, pluginCatalogPath), directory: c.pluginDirectory, + logger: c.logger, + mlockPlugins: c.enableMlock, } // Run upgrade if untyped plugins exist err := c.pluginCatalog.UpgradePlugins(ctx, c.logger) if err != nil { c.logger.Error("error while upgrading plugin storage", "error", err) + return err } if c.logger.IsInfo() { @@ -59,14 +104,205 @@ func (c *Core) setupPluginCatalog(ctx context.Context) error { return nil } +type pluginClientConn struct { + *grpc.ClientConn + id string +} + +var _ grpc.ClientConnInterface = &pluginClientConn{} + +func (d *pluginClientConn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { + // Inject ID to the context + md := metadata.Pairs(pluginutil.MultiplexingCtxKey, d.id) + idCtx := metadata.NewOutgoingContext(ctx, md) + + return d.ClientConn.Invoke(idCtx, method, args, reply, opts...) +} + +func (d *pluginClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + // Inject ID to the context + md := metadata.Pairs(pluginutil.MultiplexingCtxKey, d.id) + idCtx := metadata.NewOutgoingContext(ctx, md) + + return d.ClientConn.NewStream(idCtx, desc, method, opts...) +} + +func (p *pluginClient) Conn() grpc.ClientConnInterface { + return p.clientConn +} + +// Close calls the plugin client's cleanupFunc to do any necessary cleanup on +// the plugin client and the PluginCatalog. This implements the +// plugin.ClientProtocol interface. +func (p *pluginClient) Close() error { + p.logger.Debug("cleaning up plugin client connection", "id", p.id) + return p.cleanupFunc() +} + +// cleanupExternalPlugin will kill plugin processes and perform any necessary +// cleanup on the externalPlugins map for multiplexed and non-multiplexed +// plugins. This should be called with the write lock held. +func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error { + extPlugin, ok := c.externalPlugins[name] + if !ok { + return fmt.Errorf("plugin client not found") + } + + pluginClient := extPlugin.connections[id] + + delete(extPlugin.connections, id) + if !extPlugin.multiplexingSupport { + pluginClient.client.Kill() + + if len(extPlugin.connections) == 0 { + delete(c.externalPlugins, name) + } + } else if len(extPlugin.connections) == 0 { + pluginClient.client.Kill() + delete(c.externalPlugins, name) + } + + return nil +} + +func (c *PluginCatalog) getExternalPlugin(pluginName string) *externalPlugin { + if extPlugin, ok := c.externalPlugins[pluginName]; ok { + return extPlugin + } + + return c.newExternalPlugin(pluginName) +} + +func (c *PluginCatalog) newExternalPlugin(pluginName string) *externalPlugin { + if c.externalPlugins == nil { + c.externalPlugins = make(map[string]*externalPlugin) + } + + extPlugin := &externalPlugin{ + connections: make(map[string]*pluginClient), + name: pluginName, + } + + c.externalPlugins[pluginName] = extPlugin + return extPlugin +} + +// NewPluginClient returns a client for managing the lifecycle of a plugin +// process +func (c *PluginCatalog) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (*pluginClient, error) { + c.lock.Lock() + defer c.lock.Unlock() + + if config.Name == "" { + return nil, fmt.Errorf("no name provided for plugin") + } + if config.PluginType == consts.PluginTypeUnknown { + return nil, fmt.Errorf("no plugin type provided") + } + + pluginRunner, err := c.get(ctx, config.Name, config.PluginType) + if err != nil { + return nil, fmt.Errorf("failed to lookup plugin: %w", err) + } + if pluginRunner == nil { + return nil, fmt.Errorf("no plugin found") + } + pc, err := c.newPluginClient(ctx, pluginRunner, config) + return pc, err +} + +// newPluginClient returns a client for managing the lifecycle of a plugin +// process. Callers should have the write lock held. +func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (*pluginClient, error) { + if pluginRunner == nil { + return nil, fmt.Errorf("no plugin found") + } + + extPlugin := c.getExternalPlugin(pluginRunner.Name) + id, err := base62.Random(10) + if err != nil { + return nil, err + } + + pc := &pluginClient{ + id: id, + logger: c.logger.Named(pluginRunner.Name), + cleanupFunc: func() error { + c.lock.Lock() + defer c.lock.Unlock() + return c.cleanupExternalPlugin(pluginRunner.Name, id) + }, + } + + if !pluginRunner.MultiplexingSupport || len(extPlugin.connections) == 0 { + c.logger.Debug("spawning a new plugin process", "plugin_name", pluginRunner.Name, "id", id) + client, err := pluginRunner.RunConfig(ctx, + pluginutil.PluginSets(config.PluginSets), + pluginutil.HandshakeConfig(config.HandshakeConfig), + pluginutil.Logger(config.Logger), + pluginutil.MetadataMode(config.IsMetadataMode), + pluginutil.MLock(c.mlockPlugins), + + // NewPluginClient only supports AutoMTLS today + pluginutil.AutoMTLS(true), + ) + if err != nil { + return nil, err + } + + pc.client = client + } else { + c.logger.Debug("returning existing plugin client for multiplexed plugin", "id", id) + + // get the first client, since they are all the same + for k := range extPlugin.connections { + pc.client = extPlugin.connections[k].client + break + } + + if pc.client == nil { + return nil, fmt.Errorf("plugin client is nil") + } + } + + // Get the protocol client for this connection. + // Subsequent calls to this will return the same client. + rpcClient, err := pc.client.Client() + if err != nil { + return nil, err + } + + clientConn := rpcClient.(*plugin.GRPCClient).Conn + + if pluginRunner.MultiplexingSupport { + // Wrap rpcClient with our implementation so that we can inject the + // ID into the context + pc.clientConn = &pluginClientConn{ + ClientConn: clientConn, + id: id, + } + } else { + pc.clientConn = clientConn + } + + pc.ClientProtocol = rpcClient + + extPlugin.connections[id] = pc + extPlugin.name = pluginRunner.Name + extPlugin.multiplexingSupport = pluginRunner.MultiplexingSupport + + return extPlugin.connections[id], nil +} + // getPluginTypeFromUnknown will attempt to run the plugin to determine the -// type. It will first attempt to run as a database plugin then a backend -// plugin. Both of these will be run in metadata mode. -func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) { +// type and if it supports multiplexing. It will first attempt to run as a +// database plugin then a backend plugin. Both of these will be run in metadata +// mode. +func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, bool, error) { merr := &multierror.Error{} - err := isDatabasePlugin(ctx, plugin) + multiplexingSupport, err := c.isDatabasePlugin(ctx, plugin) if err == nil { - return consts.PluginTypeDatabase, nil + return consts.PluginTypeDatabase, multiplexingSupport, nil } merr = multierror.Append(merr, err) @@ -75,7 +311,7 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log if err == nil { err := client.Setup(ctx, &logical.BackendConfig{}) if err != nil { - return consts.PluginTypeUnknown, err + return consts.PluginTypeUnknown, false, err } backendType := client.Type() @@ -83,9 +319,9 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log switch backendType { case logical.TypeCredential: - return consts.PluginTypeCredential, nil + return consts.PluginTypeCredential, false, nil case logical.TypeLogical: - return consts.PluginTypeSecrets, nil + return consts.PluginTypeSecrets, false, nil } } else { merr = multierror.Append(merr, err) @@ -102,29 +338,55 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log "error", merr.Error()) } - return consts.PluginTypeUnknown, nil + return consts.PluginTypeUnknown, false, nil } -func isDatabasePlugin(ctx context.Context, plugin *pluginutil.PluginRunner) error { +// isDatabasePlugin returns true if the plugin supports multiplexing. An error +// is returned if the plugin is not a database plugin. +func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (bool, error) { merr := &multierror.Error{} - // Attempt to run as database V5 plugin - v5Client, err := v5.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true) + config := pluginutil.PluginClientConfig{ + Name: pluginRunner.Name, + PluginSets: v5.PluginSets, + PluginType: consts.PluginTypeDatabase, + HandshakeConfig: v5.HandshakeConfig, + Logger: log.NewNullLogger(), + IsMetadataMode: true, + AutoMTLS: true, + } + // Attempt to run as database V5 or V6 multiplexed plugin + v5Client, err := c.newPluginClient(ctx, pluginRunner, config) if err == nil { + // At this point the pluginRunner does not know if multiplexing is + // supported or not. So we need to ask the plugin client itself. + multiplexingSupport, err := pluginutil.MultiplexingSupported(ctx, v5Client.clientConn) + if err != nil { + return false, err + } + // Close the client and cleanup the plugin process - v5Client.Close() - return nil + err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id) + if err != nil { + c.logger.Error("error closing plugin client", "error", err) + } + + return multiplexingSupport, nil } merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err)) - v4Client, err := v4.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true) + v4Client, err := v4.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true) if err == nil { // Close the client and cleanup the plugin process - v4Client.Close() - return nil + err = v4Client.Close() + if err != nil { + c.logger.Error("error closing plugin client", "error", err) + } + + return false, nil } merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err)) - return merr.ErrorOrNil() + return false, merr.ErrorOrNil() } // UpdatePlugins will loop over all the plugins of unknown type and attempt to @@ -170,7 +432,7 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e cmdOld := plugin.Command plugin.Command = filepath.Join(c.directory, plugin.Command) - pluginType, err := c.getPluginTypeFromUnknown(ctx, logger, plugin) + pluginType, multiplexingSupport, err := c.getPluginTypeFromUnknown(ctx, logger, plugin) if err != nil { retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err)) continue @@ -181,7 +443,7 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e } // Upgrade the storage - err = c.setInternal(ctx, pluginName, pluginType, cmdOld, plugin.Args, plugin.Env, plugin.Sha256) + err = c.setInternal(ctx, pluginName, pluginType, multiplexingSupport, cmdOld, plugin.Args, plugin.Env, plugin.Sha256) if err != nil { retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err)) continue @@ -269,10 +531,14 @@ func (c *PluginCatalog) Set(ctx context.Context, name string, pluginType consts. c.lock.Lock() defer c.lock.Unlock() - return c.setInternal(ctx, name, pluginType, command, args, env, sha256) + // During plugin registration, we can't know if a plugin is multiplexed or + // not until we run it. So we set it to false here. Once started, we ask + // the plugin if it is multiplexed and set this value accordingly. + multiplexingSupport := false + return c.setInternal(ctx, name, pluginType, multiplexingSupport, command, args, env, sha256) } -func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, command string, args []string, env []string, sha256 []byte) error { +func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, multiplexingSupport bool, command string, args []string, env []string, sha256 []byte) error { // Best effort check to make sure the command isn't breaking out of the // configured plugin directory. commandFull := filepath.Join(c.directory, command) @@ -294,15 +560,16 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType // entryTmp should only be used for the below type check, it uses the // full command instead of the relative command. entryTmp := &pluginutil.PluginRunner{ - Name: name, - Command: commandFull, - Args: args, - Env: env, - Sha256: sha256, - Builtin: false, + Name: name, + Command: commandFull, + Args: args, + Env: env, + Sha256: sha256, + Builtin: false, + MultiplexingSupport: multiplexingSupport, } - pluginType, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp) + pluginType, multiplexingSupport, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp) if err != nil { return err } @@ -312,13 +579,14 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType } entry := &pluginutil.PluginRunner{ - Name: name, - Type: pluginType, - Command: command, - Args: args, - Env: env, - Sha256: sha256, - Builtin: false, + Name: name, + Type: pluginType, + Command: command, + Args: args, + Env: env, + Sha256: sha256, + Builtin: false, + MultiplexingSupport: multiplexingSupport, } buf, err := json.Marshal(entry) diff --git a/vault/request_forwarding_service.pb.go b/vault/request_forwarding_service.pb.go index 62962be0c670..d16aa5d07155 100644 --- a/vault/request_forwarding_service.pb.go +++ b/vault/request_forwarding_service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.27.1 -// protoc v3.17.3 +// protoc v3.19.4 // source: vault/request_forwarding_service.proto package vault