Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ROX-13666: Split dbclient.EnsureDBProvisioned() into two functions #742

Merged
merged 2 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@
"filename": "fleetshard/pkg/central/cloudprovider/dbclient_moq.go",
"hashed_secret": "80519927d0f3ce1efe933f46ca9e05e68e491adc",
"is_verified": false,
"line_number": 106
"line_number": 118
}
],
"internal/dinosaur/pkg/api/public/api/openapi.yaml": [
Expand Down Expand Up @@ -546,5 +546,5 @@
}
]
},
"generated_at": "2023-01-13T14:02:09Z"
"generated_at": "2023-01-18T17:58:26Z"
}
56 changes: 31 additions & 25 deletions fleetshard/pkg/central/cloudprovider/awsclient/rds.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,66 +55,82 @@ type RDS struct {
}

// EnsureDBProvisioned is a blocking function that makes sure that an RDS database was provisioned for a Central
func (r *RDS) EnsureDBProvisioned(ctx context.Context, databaseID, masterPassword string) (*postgres.DBConnection, error) {
func (r *RDS) EnsureDBProvisioned(ctx context.Context, databaseID, masterPassword string) error {
clusterID := getClusterID(databaseID)
instanceID := getInstanceID(databaseID)

if err := r.ensureDBClusterCreated(clusterID, masterPassword); err != nil {
return nil, fmt.Errorf("ensuring DB cluster %s exists: %w", clusterID, err)
return fmt.Errorf("ensuring DB cluster %s exists: %w", clusterID, err)
}

if err := r.ensureDBInstanceCreated(instanceID, clusterID); err != nil {
return nil, fmt.Errorf("ensuring DB instance %s exists in cluster %s: %w", instanceID, clusterID, err)
return fmt.Errorf("ensuring DB instance %s exists in cluster %s: %w", instanceID, clusterID, err)
}

return r.waitForInstanceToBeAvailable(ctx, instanceID, clusterID)
}

// EnsureDBDeprovisioned is a function that initiates the deprovisioning of the RDS database of a Central
// Unlike EnsureDBProvisioned, this function does not block until the DB is deprovisioned
func (r *RDS) EnsureDBDeprovisioned(databaseID string) (bool, error) {
func (r *RDS) EnsureDBDeprovisioned(databaseID string) error {
clusterID := getClusterID(databaseID)
instanceID := getInstanceID(databaseID)

instanceExists, err := r.instanceExists(instanceID)
if err != nil {
return false, fmt.Errorf("checking if DB instance exists: %w", err)
return fmt.Errorf("checking if DB instance exists: %w", err)
}
if instanceExists {
status, err := r.instanceStatus(instanceID)
if err != nil {
return false, fmt.Errorf("getting DB instance status: %w", err)
return fmt.Errorf("getting DB instance status: %w", err)
}
if status != dbDeletingStatus {
glog.Infof("Initiating deprovisioning of RDS database instance %s.", instanceID)
// TODO(ROX-13692): do not skip taking a final DB snapshot
_, err := r.rdsClient.DeleteDBInstance(newDeleteCentralDBInstanceInput(instanceID, true))
if err != nil {
return false, fmt.Errorf("deleting DB instance: %w", err)
return fmt.Errorf("deleting DB instance: %w", err)
}
}
}

clusterExists, err := r.clusterExists(clusterID)
if err != nil {
return false, fmt.Errorf("checking if DB cluster exists: %w", err)
return fmt.Errorf("checking if DB cluster exists: %w", err)
}
if clusterExists {
status, err := r.clusterStatus(clusterID)
if err != nil {
return false, fmt.Errorf("getting DB cluster status: %w", err)
return fmt.Errorf("getting DB cluster status: %w", err)
}
if status != dbDeletingStatus {
glog.Infof("Initiating deprovisioning of RDS database cluster %s.", clusterID)
// TODO(ROX-13692): do not skip taking a final DB snapshot
_, err := r.rdsClient.DeleteDBCluster(newDeleteCentralDBClusterInput(clusterID, true))
if err != nil {
return false, fmt.Errorf("deleting DB cluster: %w", err)
return fmt.Errorf("deleting DB cluster: %w", err)
}
}
}

return true, nil
return nil
}

// GetDBConnection returns a postgres.DBConnection struct, which contains the data necessary
// to construct a PostgreSQL connection string. It expects that the database was already provisioned.
func (r *RDS) GetDBConnection(databaseID string) (postgres.DBConnection, error) {
dbCluster, err := r.describeDBCluster(getClusterID(databaseID))
if err != nil {
return postgres.DBConnection{}, err
}

connection, err := postgres.NewDBConnection(*dbCluster.Endpoint, dbPostgresPort, dbUser, dbName)
if err != nil {
return postgres.DBConnection{}, fmt.Errorf("incorrect DB connection parameters: %w", err)
}

return connection, nil
}

func (r *RDS) ensureDBClusterCreated(clusterID, masterPassword string) error {
Expand Down Expand Up @@ -234,25 +250,15 @@ func (r *RDS) describeDBCluster(clusterID string) (*rds.DBCluster, error) {
return result.DBClusters[0], nil
}

func (r *RDS) waitForInstanceToBeAvailable(ctx context.Context, instanceID string, clusterID string) (*postgres.DBConnection, error) {
func (r *RDS) waitForInstanceToBeAvailable(ctx context.Context, instanceID string, clusterID string) error {
for {
dbInstanceStatus, err := r.instanceStatus(instanceID)
if err != nil {
return nil, err
return err
}

if dbInstanceStatus == dbAvailableStatus {
dbCluster, err := r.describeDBCluster(clusterID)
if err != nil {
return nil, err
}

connection, err := postgres.NewDBConnection(*dbCluster.Endpoint, dbPostgresPort, dbUser, dbName)
if err != nil {
return nil, fmt.Errorf("incorrect DB connection parameters: %w", err)
}

return &connection, nil
return nil
}

glog.Infof("RDS instance status: %s (instance ID: %s)", dbInstanceStatus, instanceID)
Expand All @@ -261,7 +267,7 @@ func (r *RDS) waitForInstanceToBeAvailable(ctx context.Context, instanceID strin
case <-ticker.C:
continue
case <-ctx.Done():
return nil, fmt.Errorf("waiting for RDS instance to be available: %w", ctx.Err())
return fmt.Errorf("waiting for RDS instance to be available: %w", ctx.Err())
}
}
}
Expand Down
23 changes: 20 additions & 3 deletions fleetshard/pkg/central/cloudprovider/awsclient/rds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/google/uuid"
Expand Down Expand Up @@ -89,7 +90,10 @@ func TestRDSProvisioning(t *testing.T) {
require.NoError(t, err)
require.False(t, instanceExists)

_, err = rdsClient.EnsureDBProvisioned(ctx, dbID, dbMasterPassword)
err = rdsClient.EnsureDBProvisioned(ctx, dbID, dbMasterPassword)
assert.NoError(t, err)

_, err = rdsClient.GetDBConnection(dbID)
assert.NoError(t, err)

clusterExists, err = rdsClient.clusterExists(clusterID)
Expand All @@ -108,9 +112,8 @@ func TestRDSProvisioning(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, instanceStatus, dbAvailableStatus)

deletionStarted, err := rdsClient.EnsureDBDeprovisioned(dbID)
err = rdsClient.EnsureDBDeprovisioned(dbID)
assert.NoError(t, err)
assert.True(t, deletionStarted)

deleteCtx, deleteCancel := context.WithTimeout(context.TODO(), 10*time.Minute)
defer deleteCancel()
Expand All @@ -119,3 +122,17 @@ func TestRDSProvisioning(t *testing.T) {
require.NoError(t, err)
assert.True(t, clusterDeleted)
}

func TestGetDBConnection(t *testing.T) {
if os.Getenv("RUN_RDS_TESTS") != "true" {
t.Skip("Skip RDS tests. Set RUN_RDS_TESTS=true env variable to enable RDS tests.")
}

rdsClient, err := newTestRDS()
require.NoError(t, err)

_, err = rdsClient.GetDBConnection("test-" + uuid.New().String())
var awsErr awserr.Error
require.ErrorAs(t, err, &awsErr)
assert.Equal(t, awsErr.Code(), rds.ErrCodeDBClusterNotFoundFault)
}
7 changes: 5 additions & 2 deletions fleetshard/pkg/central/cloudprovider/dbclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ import (
type DBClient interface {
// EnsureDBProvisioned is a blocking function that makes sure that a database with the given databaseID was provisioned,
// using the master password given as parameter
EnsureDBProvisioned(ctx context.Context, databaseID, passwordSecretName string) (*postgres.DBConnection, error)
EnsureDBProvisioned(ctx context.Context, databaseID, passwordSecretName string) error
// EnsureDBDeprovisioned is a non-blocking function that makes sure that a managed DB is deprovisioned (more
// specifically, that its deletion was initiated)
EnsureDBDeprovisioned(databaseID string) (bool, error)
EnsureDBDeprovisioned(databaseID string) error
// GetDBConnection returns a postgres.DBConnection struct, which contains the data necessary
// to construct a PostgreSQL connection string. It expects that the database was already provisioned.
GetDBConnection(databaseID string) (postgres.DBConnection, error)
}
56 changes: 50 additions & 6 deletions fleetshard/pkg/central/cloudprovider/dbclient_moq.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 7 additions & 3 deletions fleetshard/pkg/central/reconciler/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,16 @@ func (r *CentralReconciler) Reconcile(ctx context.Context, remoteCentral private
return nil, fmt.Errorf("getting DB password from secret: %w", err)
}

dbConnection, err := r.managedDBProvisioningClient.EnsureDBProvisioned(ctx, remoteCentral.Id, dbMasterPassword)
err = r.managedDBProvisioningClient.EnsureDBProvisioned(ctx, remoteCentral.Id, dbMasterPassword)
if err != nil {
return nil, fmt.Errorf("provisioning RDS DB: %w", err)
}

dbConnection, err := r.managedDBProvisioningClient.GetDBConnection(remoteCentral.Id)
if err != nil {
return nil, fmt.Errorf("getting RDS DB connection data: %w", err)
}

central.Spec.Central.DB = &v1alpha1.CentralDBSpec{
IsEnabled: v1alpha1.CentralDBEnabledPtr(v1alpha1.CentralDBEnabledTrue),
ConnectionStringOverride: pointer.String(dbConnection.AsConnectionString()),
Expand Down Expand Up @@ -376,11 +381,10 @@ func (r *CentralReconciler) ensureCentralDeleted(ctx context.Context, remoteCent
globalDeleted = globalDeleted && centralDeleted

if r.managedDBEnabled {
dbDeleted, err := r.managedDBProvisioningClient.EnsureDBDeprovisioned(remoteCentral.Id)
err = r.managedDBProvisioningClient.EnsureDBDeprovisioned(remoteCentral.Id)
if err != nil {
return false, fmt.Errorf("deprovisioning DB: %v", err)
}
globalDeleted = globalDeleted && dbDeleted

secretDeleted, err := r.ensureCentralDBSecretDeleted(ctx, central.GetNamespace())
if err != nil {
Expand Down
Loading