Skip to content

Commit

Permalink
wip redoing interface to support BC
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Hoffman committed Feb 20, 2018
1 parent fd13249 commit b33e114
Show file tree
Hide file tree
Showing 20 changed files with 339 additions and 154 deletions.
18 changes: 9 additions & 9 deletions builtin/logical/database/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,31 +150,31 @@ func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName
case upgradeCh.Statements != nil:
var stmts dbplugin.Statements
if upgradeCh.Statements.CreationStatements != "" {
stmts.CreationStatements = []string{upgradeCh.Statements.CreationStatements}
stmts.Creation = []string{upgradeCh.Statements.CreationStatements}
}
if upgradeCh.Statements.RevocationStatements != "" {
stmts.RevocationStatements = []string{upgradeCh.Statements.RevocationStatements}
stmts.Revocation = []string{upgradeCh.Statements.RevocationStatements}
}
if upgradeCh.Statements.RollbackStatements != "" {
stmts.RollbackStatements = []string{upgradeCh.Statements.RollbackStatements}
stmts.Rollback = []string{upgradeCh.Statements.RollbackStatements}
}
if upgradeCh.Statements.RenewStatements != "" {
stmts.RenewStatements = []string{upgradeCh.Statements.RenewStatements}
stmts.Renewal = []string{upgradeCh.Statements.RenewStatements}
}
result.Statements = stmts
case upgradeCh.OldStatements != nil:
var stmts dbplugin.Statements
if upgradeCh.OldStatements.CreationStatements != "" {
stmts.CreationStatements = []string{upgradeCh.OldStatements.CreationStatements}
stmts.Creation = []string{upgradeCh.OldStatements.CreationStatements}
}
if upgradeCh.OldStatements.RevocationStatements != "" {
stmts.RevocationStatements = []string{upgradeCh.OldStatements.RevocationStatements}
stmts.Revocation = []string{upgradeCh.OldStatements.RevocationStatements}
}
if upgradeCh.OldStatements.RollbackStatements != "" {
stmts.RollbackStatements = []string{upgradeCh.OldStatements.RollbackStatements}
stmts.Rollback = []string{upgradeCh.OldStatements.RollbackStatements}
}
if upgradeCh.OldStatements.RenewStatements != "" {
stmts.RenewStatements = []string{upgradeCh.OldStatements.RenewStatements}
stmts.Renewal = []string{upgradeCh.OldStatements.RenewStatements}
}
result.Statements = stmts
}
Expand Down Expand Up @@ -225,7 +225,7 @@ func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage,
return nil, err
}

_, err = dbp.Initialize(ctx, config.ConnectionDetails, true)
_, err = dbp.Init(ctx, config.ConnectionDetails, true)
if err != nil {
dbp.Close()
return nil, err
Expand Down
262 changes: 181 additions & 81 deletions builtin/logical/database/dbplugin/database.pb.go

Large diffs are not rendered by default.

33 changes: 26 additions & 7 deletions builtin/logical/database/dbplugin/database.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ package dbplugin;
import "google/protobuf/timestamp.proto";

message InitializeRequest {
option deprecated = true;
bytes config = 1;
bool verify_connection = 2;
}

message InitRequest {
bytes config = 1;
bool verify_connection = 2;
}
Expand All @@ -30,18 +36,27 @@ message RotateRootCredentialsRequest {
}

message Statements {
repeated string creation_statements = 1;
repeated string revocation_statements = 2;
repeated string rollback_statements = 3;
repeated string renew_statements = 4;
// DEPRECATED, will be removed in 0.12
string creation_statements = 1;
// DEPRECATED, will be removed in 0.12
string revocation_statements = 2;
// DEPRECATED, will be removed in 0.12
string rollback_statements = 3;
// DEPRECATED, will be removed in 0.12
string renew_statements = 4;

repeated string creation = 5;
repeated string revocation = 6;
repeated string rollback = 7;
repeated string renewal = 8;
}

message UsernameConfig {
string DisplayName = 1;
string RoleName = 2;
}

message InitializeResponse {
message InitResponse {
bytes config = 1;
}

Expand All @@ -66,6 +81,10 @@ service Database {
rpc RenewUser(RenewUserRequest) returns (Empty);
rpc RevokeUser(RevokeUserRequest) returns (Empty);
rpc RotateRootCredentials(RotateRootCredentialsRequest) returns (RotateRootCredentialsResponse);
rpc Initialize(InitializeRequest) returns (InitializeResponse);
rpc Close(Empty) returns (Empty);
rpc Init(InitRequest) returns (InitResponse);
rpc Close(Empty) returns (Empty);

rpc Initialize(InitializeRequest) returns (Empty) {
option deprecated = true;
};
}
27 changes: 21 additions & 6 deletions builtin/logical/database/dbplugin/databasemiddleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,18 @@ func (mw *databaseTracingMiddleware) RotateRootCredentials(ctx context.Context,
return mw.next.RotateRootCredentials(ctx, statements)
}

func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}

func (mw *databaseTracingMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
}(time.Now())

mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Initialize(ctx, conf, verifyConnection)
return mw.next.Init(ctx, conf, verifyConnection)
}

func (mw *databaseTracingMiddleware) Close() (err error) {
Expand Down Expand Up @@ -161,7 +166,12 @@ func (mw *databaseMetricsMiddleware) RotateRootCredentials(ctx context.Context,
return mw.next.RotateRootCredentials(ctx, statements)
}

func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}

func (mw *databaseMetricsMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "Initialize"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
Expand All @@ -174,7 +184,7 @@ func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[st

metrics.IncrCounter([]string{"database", "Initialize"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
return mw.next.Initialize(ctx, conf, verifyConnection)
return mw.next.Init(ctx, conf, verifyConnection)
}

func (mw *databaseMetricsMiddleware) Close() (err error) {
Expand Down Expand Up @@ -224,8 +234,13 @@ func (mw *databaseErrorSanitizerMiddleware) RotateRootCredentials(ctx context.Co
return conf, mw.sanitize(err)
}

func (mw *databaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
saveConf, err = mw.next.Initialize(ctx, conf, verifyConnection)
func (mw *databaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}

func (mw *databaseErrorSanitizerMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
saveConf, err = mw.next.Init(ctx, conf, verifyConnection)
return saveConf, mw.sanitize(err)
}

Expand Down
23 changes: 18 additions & 5 deletions builtin/logical/database/dbplugin/grpc_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,22 @@ func (s *gRPCServer) RotateRootCredentials(ctx context.Context, req *RotateRootC
}, err
}

func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*InitializeResponse, error) {
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
_, err := s.Init(ctx, &InitRequest{
Config: req.Config,
VerifyConnection: req.VerifyConnection,
})
return &Empty{}, err
}

func (s *gRPCServer) Init(ctx context.Context, req *InitRequest) (*InitResponse, error) {
config := map[string]interface{}{}
err := json.Unmarshal(req.Config, &config)
if err != nil {
return nil, err
}

resp, err := s.impl.Initialize(ctx, config, req.VerifyConnection)
resp, err := s.impl.Init(ctx, config, req.VerifyConnection)
if err != nil {
return nil, err
}
Expand All @@ -95,7 +103,7 @@ func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*I
return nil, err
}

return &InitializeResponse{
return &InitResponse{
Config: respConfig,
}, err
}
Expand Down Expand Up @@ -224,7 +232,12 @@ func (c *gRPCClient) RotateRootCredentials(ctx context.Context, statements []str
return conf, nil
}

func (c *gRPCClient) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
func (c *gRPCClient) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}

func (c *gRPCClient) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
configRaw, err := json.Marshal(conf)
if err != nil {
return nil, err
Expand All @@ -235,7 +248,7 @@ func (c *gRPCClient) Initialize(ctx context.Context, conf map[string]interface{}
defer close(quitCh)
defer cancel()

resp, err := c.client.Initialize(ctx, &InitializeRequest{
resp, err := c.client.Init(ctx, &InitRequest{
Config: configRaw,
VerifyConnection: verifyConnection,
})
Expand Down
27 changes: 22 additions & 5 deletions builtin/logical/database/dbplugin/netrpc_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,15 @@ func (ds *databasePluginRPCServer) RotateRootCredentials(args *RotateRootCredent
return err
}

func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, resp *InitializeResponse) error {
config, err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
return ds.Init(&InitRequestRPC{
Config: args.Config,
VerifyConnection: args.VerifyConnection,
}, &InitResponse{})
}

func (ds *databasePluginRPCServer) Init(args *InitRequestRPC, resp *InitResponse) error {
config, err := ds.impl.Init(context.Background(), args.Config, args.VerifyConnection)
if err != nil {
return err
}
Expand Down Expand Up @@ -119,14 +126,19 @@ func (dr *databasePluginRPCClient) RotateRootCredentials(_ context.Context, stat
return saveConf, err
}

func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := dr.Init(nil, conf, verifyConnection)
return err
}

func (dr *databasePluginRPCClient) Init(_ context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
req := InitializeRequestRPC{
Config: conf,
VerifyConnection: verifyConnection,
}

var resp InitializeResponse
err = dr.client.Call("Plugin.Initialize", req, &resp)
var resp InitResponse
err = dr.client.Call("Plugin.Init", req, &resp)
if err != nil {
return nil, err
}
Expand All @@ -146,6 +158,11 @@ type InitializeRequestRPC struct {
VerifyConnection bool
}

type InitRequestRPC struct {
Config map[string]interface{}
VerifyConnection bool
}

type CreateUserRequestRPC struct {
Statements Statements
UsernameConfig UsernameConfig
Expand Down
5 changes: 4 additions & 1 deletion builtin/logical/database/dbplugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ type Database interface {

RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error)

Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error)
Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error)
Close() error

// DEPRECATED, will be removed in 0.12
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error)
}

// PluginFactory is used to build plugin database types. It wraps the database
Expand Down
22 changes: 11 additions & 11 deletions builtin/logical/database/dbplugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statement
func (m *mockPlugin) RotateRootCredentials(_ context.Context, statements []string) (map[string]interface{}, error) {
return nil, nil
}
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) (map[string]interface{}, error) {
func (m *mockPlugin) Init(_ context.Context, conf map[string]interface{}, _ bool) (map[string]interface{}, error) {
err := errors.New("err")
if len(conf) != 1 {
return nil, err
Expand Down Expand Up @@ -135,7 +135,7 @@ func TestPlugin_NetRPC_Main(t *testing.T) {
plugin.Serve(serveConf)
}

func TestPlugin_Initialize(t *testing.T) {
func TestPlugin_Init(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

Expand All @@ -148,7 +148,7 @@ func TestPlugin_Initialize(t *testing.T) {
"test": 1,
}

_, err = dbRaw.Initialize(context.Background(), connectionDetails, true)
_, err = dbRaw.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand All @@ -173,7 +173,7 @@ func TestPlugin_CreateUser(t *testing.T) {
"test": 1,
}

_, err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -212,7 +212,7 @@ func TestPlugin_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
_, err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -246,7 +246,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
_, err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -275,7 +275,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
}

// Test the code is still compatible with an old netRPC plugin
func TestPlugin_NetRPC_Initialize(t *testing.T) {
func TestPlugin_NetRPC_Init(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

Expand All @@ -288,7 +288,7 @@ func TestPlugin_NetRPC_Initialize(t *testing.T) {
"test": 1,
}

_, err = dbRaw.Initialize(context.Background(), connectionDetails, true)
_, err = dbRaw.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand All @@ -313,7 +313,7 @@ func TestPlugin_NetRPC_CreateUser(t *testing.T) {
"test": 1,
}

_, err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -352,7 +352,7 @@ func TestPlugin_NetRPC_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
_, err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -386,7 +386,7 @@ func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
_, err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/database/path_config_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
}
connDetails, err := db.Initialize(ctx, data.Raw, verifyConnection)
connDetails, err := db.Init(ctx, data.Raw, verifyConnection)
if err != nil {
db.Close()
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
Expand Down
Loading

0 comments on commit b33e114

Please sign in to comment.