When database plugin Clean is called, close connections in goroutines #15923

Closed
wants to merge 7 commits into from
82 changes: 78 additions & 4 deletions builtin/logical/database/backend.go
@@ -6,8 +6,10 @@ import (
"net/rpc"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/armon/go-metrics"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/go-uuid"
@@ -18,6 +20,7 @@ import (
"github.com/hashicorp/vault/sdk/helper/locksutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/queue"
"golang.org/x/sync/semaphore"
)

const (
@@ -27,6 +30,32 @@ const (
minRootCredRollbackAge = 1 * time.Minute
)

// mutable for tests only
var cleanupMaxWaitTime = 500 * time.Millisecond

// metrics collection
var (
gaugeSync = sync.Mutex{}
gauges = map[string]*int32{}
Collaborator

I'm always discomforted when I see non-readonly globals; can we pull this into databaseBackend instead? I'm also not sure we need a generic gauge manager. Why not just have an atomic numConnections field in databaseBackend? That way we won't have to worry about the lock that protects the map, or any unpleasantness it might cause us.

Contributor Author

A single numConnections wouldn't allow us to see the different types of databases though.

I wanted to avoid putting these into databaseBackend because we might have multiple databaseBackends running, and we probably would like to aggregate these gauges across all of them. With gauges (instead of counters), I don't think that armon/go-metrics or statsd will automatically aggregate them for us (whereas with counters, I believe statsd at least would).

Collaborator

> A single numConnections wouldn't allow us to see the different types of databases though.

Good point.

> I wanted to avoid putting these into databaseBackend because we might have multiple databaseBackends running, and we probably would like to aggregate these gauges across all of them

Maybe? I'm not sure. I don't have a lot of experience with the database backend. Is it common to run more than one instead of just using multiple configs? On the one hand, as a user I'd rather not have the values pre-aggregated, since my metrics store can sum the individual series together but it can't break up a value pre-aggregated by Vault. On the other hand, if people might have 100s or 1000s of db mounts, I don't want us to break telemetry by creating too many series.

Contributor Author

Ideally, I would like to add a tag with the instance of the database backend -- maybe the mount path? But I don't see any way to get that. We could tag by the BackendUUID though.
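
To make the trade-off in this thread concrete, here is a minimal sketch of the per-backend alternative: the counts live on a struct owned by one backend instead of in package-level globals, the map lookup and update share a single critical section, and each emitted gauge carries a backend UUID tag. The names `connGauges`, `gaugeSink`, `label`, and the `backend_uuid` tag are hypothetical and are not part of this PR; `gaugeSink` stands in for a call such as `metrics.SetGaugeWithLabels`.

```go
package main

import (
	"fmt"
	"sync"
)

// label is a hypothetical stand-in for metrics.Label from armon/go-metrics.
type label struct{ Name, Value string }

// gaugeSink is a hypothetical hook; real code would forward to the metrics library.
type gaugeSink func(key []string, val float32, labels []label)

// connGauges tracks open connections per (type, version) for a single backend.
type connGauges struct {
	mu          sync.Mutex
	backendUUID string
	counts      map[string]int32
	emit        gaugeSink
}

func newConnGauges(backendUUID string, emit gaugeSink) *connGauges {
	return &connGauges{backendUUID: backendUUID, counts: map[string]int32{}, emit: emit}
}

// add adjusts the count for one database type and emits the new value.
// The map is only touched while mu is held, so concurrent callers do not race.
func (g *connGauges) add(dbType, version string, amount int32) int32 {
	g.mu.Lock()
	defer g.mu.Unlock()
	name := dbType + "-" + version
	g.counts[name] += amount
	val := g.counts[name]
	g.emit(
		[]string{"secrets", "database", "backend", "connections", "count"},
		float32(val),
		[]label{{"type", dbType}, {"version", version}, {"backend_uuid", g.backendUUID}},
	)
	return val
}

func main() {
	g := newConnGauges("hypothetical-uuid", func(key []string, val float32, labels []label) {
		fmt.Println(key, val, labels)
	})
	g.add("postgresql", "5", 1)
	g.add("postgresql", "5", 1)
	g.add("postgresql", "5", -1)
}
```

Tagging by backend leaves aggregation to the metrics store, at the cost of one series per backend and database type, which is the cardinality concern raised above.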

gaugeKey = []string{"secrets", "database", "backend", "connections", "count"}
)

func createGauge(name string) {
gaugeSync.Lock()
defer gaugeSync.Unlock()
gauges[name] = new(int32)
}

func addToGauge(dbType, version string, amount int32) {
labels := []metrics.Label{{"type", dbType}, {"version", version}}
gaugeName := fmt.Sprintf("%s-%s", dbType, version)
if _, ok := gauges[gaugeName]; !ok {
createGauge(gaugeName)
}
val := atomic.AddInt32(gauges[gaugeName], amount)
metrics.SetGaugeWithLabels(gaugeKey, float32(val), labels)
}

type dbPluginInstance struct {
sync.RWMutex
database databaseVersionWrapper
@@ -280,10 +309,26 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
id: id,
name: name,
}
b.addConnectionsCounter(dbw, 1)
b.connections[name] = dbi
return dbi, nil
}

// addConnectionsCounter is used to keep metrics on how many and what types of databases we have open
func (b *databaseBackend) addConnectionsCounter(dbw databaseVersionWrapper, amount int) {
// keep track of what databases we open
dbType, err := dbw.Type()
if err != nil {
b.Logger().Debug("Error getting database type", "err", err)
dbType = "unknown"
}
version := "5"
if dbw.isV4() {
version = "4"
}
addToGauge(dbType, version, int32(amount))
}

// invalidateQueue cancels any background queue loading and destroys the queue.
func (b *databaseBackend) invalidateQueue() {
// cancel context before grabbing lock to start closing any open connections
@@ -311,6 +356,7 @@ func (b *databaseBackend) clearConnection(name string) error {
if ok {
// Ignore error here since the database client is always killed
db.Close()
b.addConnectionsCounter(db.database, -1)
Collaborator

Technically counters should only ever increase, I think you want a gauge instead.

Contributor Author

Hmm. I thought about using a Gauge, but then we'll have to do some synchronization and use atomics. I was hoping to keep it simple.

But you're right. I realize when I've done this in the past, I've used two counters -- one for increments and one for decrements, and had Datadog do the diff, or just used a gauge.

I'll switch to gauges.
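
For comparison, the two-counter approach mentioned above can be sketched as follows. Both counters only ever increase, and the number of open connections is derived as their difference, either in the metrics backend or locally as shown here. The type `connCounters` and its methods are hypothetical names for illustration only; `atomic.Int64` assumes Go 1.19 or later.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// connCounters keeps two monotonic counters instead of one gauge.
type connCounters struct {
	opened atomic.Int64 // total connections ever opened
	closed atomic.Int64 // total connections ever closed
}

func (c *connCounters) recordOpen()  { c.opened.Add(1) }
func (c *connCounters) recordClose() { c.closed.Add(1) }

// inFlight derives the gauge value. The two loads are not one atomic
// snapshot, so under concurrent updates the result is approximate.
func (c *connCounters) inFlight() int64 {
	return c.opened.Load() - c.closed.Load()
}

func main() {
	var c connCounters
	c.recordOpen()
	c.recordOpen()
	c.recordClose()
	fmt.Println("open connections:", c.inFlight())
}
```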

delete(b.connections, name)
}
return nil
@@ -331,6 +377,7 @@ func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) {
// Ensure we are deleting the correct connection
mapDB, ok := b.connections[db.name]
if ok && db.id == mapDB.id {
b.addConnectionsCounter(db.database, -1)
delete(b.connections, db.name)
}
}()
@@ -339,18 +386,45 @@ func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) {

// clean closes all connections from all database types
// and cancels any rotation queue loading operation.
// It spawns goroutines to close the databases since these
// are not guaranteed to finish quickly.
func (b *databaseBackend) clean(ctx context.Context) {
// invalidateQueue acquires it's own lock on the backend, removes queue, and
cleanupCtx, cancel := context.WithDeadline(ctx, time.Now().Add(cleanupMaxWaitTime))
defer cancel()

// invalidateQueue acquires its own lock on the backend, removes queue, and
// terminates the background ticker
b.invalidateQueue()

b.Lock()
defer b.Unlock()

for _, db := range b.connections {
db.Close()
}
// copy all connections so we can close asynchronously
connectionsCopy := b.connections
b.connections = make(map[string]*dbPluginInstance)

// we will try to wait for all the connections to close
// ... but not too long
sem := semaphore.NewWeighted(int64(len(connectionsCopy)))
Contributor Author

This seems like a useful pattern that I have seen messed up a lot, and with a lot of bad answers on the internet. I'm tempted to move this out to a library since I don't see one already made (what I want is something like a "TimedWaitGroup" or "WaitGroupWithTimeout").

err := sem.Acquire(cleanupCtx, int64(len(connectionsCopy)))
if err != nil {
b.Logger().Debug("Error acquiring semaphore; ignoring", "error", err)
}

for _, db := range connectionsCopy {
go func(db *dbPluginInstance) {
defer sem.Release(1)
b.addConnectionsCounter(db.database, -1)
err := db.Close()
if err != nil {
b.Logger().Debug("Error closing database while cleaning up plugin; ignoring", "error", err)
}
}(db)
}
err = sem.Acquire(cleanupCtx, int64(len(connectionsCopy)))
if err != nil {
b.Logger().Debug("Error in cleanup semaphore; ignoring", "error", err)
}
}

const backendHelp = `
83 changes: 83 additions & 0 deletions builtin/logical/database/backend_test.go
@@ -20,6 +20,7 @@ import (
"github.com/hashicorp/vault/plugins/database/mongodb"
"github.com/hashicorp/vault/plugins/database/postgresql"
v4 "github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/framework"
@@ -1461,6 +1462,88 @@ func TestBackend_ConnectionURL_redacted(t *testing.T) {
}
}

type hangingPlugin struct{}

func (h hangingPlugin) Initialize(ctx context.Context, req v5.InitializeRequest) (v5.InitializeResponse, error) {
return v5.InitializeResponse{
Config: req.Config,
}, nil
}

func (h hangingPlugin) NewUser(ctx context.Context, req v5.NewUserRequest) (v5.NewUserResponse, error) {
return v5.NewUserResponse{}, nil
}

func (h hangingPlugin) UpdateUser(ctx context.Context, req v5.UpdateUserRequest) (v5.UpdateUserResponse, error) {
return v5.UpdateUserResponse{}, nil
}

func (h hangingPlugin) DeleteUser(ctx context.Context, req v5.DeleteUserRequest) (v5.DeleteUserResponse, error) {
return v5.DeleteUserResponse{}, nil
}

func (h hangingPlugin) Type() (string, error) {
return "hanging", nil
}

func (h hangingPlugin) Close() error {
time.Sleep(1000 * time.Second)
return nil
}

var _ dbplugin.Database = (*hangingPlugin)(nil)

func TestBackend_PluginMain_Hanging(t *testing.T) {
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
return
}
v5.Serve(&hangingPlugin{})
}

func TestBackend_Closes_Cleanly_Even_If_Plugin_Hangs(t *testing.T) {
cleanupMaxWaitTime = 100 * time.Millisecond
cluster, sys := getCluster(t)
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "hanging-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_Hanging", []string{}, "")
t.Cleanup(cluster.Cleanup)

config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys

b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}

// Configure a connection
data := map[string]interface{}{
"connection_url": "doesn't matter",
"plugin_name": "hanging-plugin",
"allowed_roles": []string{"plugin-role-test"},
}
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/hang",
Storage: config.StorageView,
Data: data,
}
_, err = b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
timeout := time.NewTimer(750 * time.Millisecond)
done := make(chan bool)
go func() {
b.Cleanup(context.Background())
done <- true
}()
select {
case <-timeout.C:
t.Error("Hanging plugin caused Close() to time out")
case <-done:
}
}

func testCredsExist(t *testing.T, resp *logical.Response, connURL string) bool {
t.Helper()
var d struct {
2 changes: 1 addition & 1 deletion builtin/logical/database/mocks_test.go
@@ -37,7 +37,7 @@ func (m *mockNewDatabase) DeleteUser(ctx context.Context, req v5.DeleteUserReque

func (m *mockNewDatabase) Type() (string, error) {
args := m.Called()
return args.String(0), args.Error(1)
return args.Get(0).(string), args.Error(1)
}

func (m *mockNewDatabase) Close() error {
1 change: 1 addition & 0 deletions builtin/logical/database/rotation_test.go
@@ -1379,6 +1379,7 @@ func setupMockDB(b *databaseBackend) *mockNewDatabase {
mockDB := &mockNewDatabase{}
mockDB.On("Initialize", mock.Anything, mock.Anything).Return(v5.InitializeResponse{}, nil)
mockDB.On("Close").Return(nil)
mockDB.On("Type").Return("mock", nil)
dbw := databaseVersionWrapper{
v5: mockDB,
}