From 7ee2d1140f377f31f25a9a441e3964998eddbd5c Mon Sep 17 00:00:00 2001
From: Saurav Malani
Date: Thu, 25 Aug 2022 15:34:38 +0530
Subject: [PATCH] feat: perform ADD_DS operation in a single transaction
 (#2324)

---
 jobsdb/jobsdb.go       | 119 +++++++++++++++++++++++------------------
 jobsdb/jobsdb_test.go  |  49 +++++++++++++++++
 jobsdb/jobsdb_utils.go |  22 +-------
 3 files changed, 117 insertions(+), 73 deletions(-)

diff --git a/jobsdb/jobsdb.go b/jobsdb/jobsdb.go
index 063e24db26..4e0a1dcb2b 100644
--- a/jobsdb/jobsdb.go
+++ b/jobsdb/jobsdb.go
@@ -1042,28 +1042,6 @@ func (jd *HandleT) Close() {
     jd.dbHandle.Close()
 }
 
-// removeExtraKey : removes extra key present in map1 and not in map2
-// Assumption is keys in map1 and map2 are same, except that map1 has one key more than map2
-func removeExtraKey(map1, map2 map[string]string) string {
-    var deleteKey, key string
-    for key = range map1 {
-        if _, ok := map2[key]; !ok {
-            deleteKey = key
-            break
-        }
-    }
-
-    if deleteKey != "" {
-        delete(map1, deleteKey)
-    }
-
-    return deleteKey
-}
-
-func remove(slice []string, idx int) []string {
-    return append(slice[:idx], slice[idx+1:]...)
-}
-
 /*
 Function to return an ordered list of datasets and datasetRanges
 Most callers use the in-memory list of dataset and datasetRanges
@@ -1107,6 +1085,7 @@ func (jd *HandleT) refreshDSRangeList(l lock.DSListLockToken) []dataSetRangeT {
     // At this point we must have write-locked dsListLock
     dsList := jd.refreshDSList(l)
+
     jd.datasetRangeList = nil
 
     for idx, ds := range dsList {
@@ -1142,7 +1121,8 @@ func (jd *HandleT) refreshDSRangeList(l lock.DSListLockToken) []dataSetRangeT {
             jd.datasetRangeList = append(jd.datasetRangeList,
                 dataSetRangeT{
                     minJobID: minID.Int64,
-                    maxJobID: maxID.Int64, ds: ds,
+                    maxJobID: maxID.Int64,
+                    ds:       ds,
                 })
             prevMax = maxID.Int64
         }
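
The removeExtraKey and remove helpers deleted above existed only to repair a DS list in which a crash between the two CREATE TABLE statements had left a jobs table without its job_status table. With the whole ADD_DS operation moved into a single transaction (the createDS hunk below), that half-created state can no longer be observed. The patch leans on a begin/commit/rollback wrapper for this; a minimal sketch of such a wrapper, assuming a plain *sql.DB handle (jobsdb's actual WithTx helper is not shown in this diff and may differ):

    package sketch

    import "database/sql"

    // withTx runs fn inside a transaction: commit on success, roll back on
    // error. If fn fails, its error takes precedence over the rollback's.
    func withTx(db *sql.DB, fn func(tx *sql.Tx) error) error {
        tx, err := db.Begin()
        if err != nil {
            return err
        }
        if err := fn(tx); err != nil {
            _ = tx.Rollback() // best effort; the database discards the partial work
            return err
        }
        return tx.Commit()
    }
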
@@ -1533,11 +1513,31 @@ type transactionHandler interface {
 }
 
 func (jd *HandleT) createDS(newDS dataSetT, l lock.DSListLockToken) {
+    err := jd.WithTx(func(tx *sql.Tx) error {
+        return jd.createDSInTx(tx, newDS, l)
+    })
+    jd.assertError(err)
+
+    // In case of a migration, we don't yet update the in-memory list till we finish the migration
+    if l != nil {
+        // Refresh the cached DS list and ranges now that the createDS transaction has been committed.
+        _ = jd.refreshDSList(l)
+        _ = jd.refreshDSRangeList(l)
+    }
+}
+
+func (jd *HandleT) createDSInTx(tx *sql.Tx, newDS dataSetT, l lock.DSListLockToken) error {
     // Mark the start of operation. If we crash somewhere here, we delete the
     // DS being added
     opPayload, err := json.Marshal(&journalOpPayloadT{To: newDS})
-    jd.assertError(err)
-    opID := jd.JournalMarkStart(addDSOperation, opPayload)
+    if err != nil {
+        return err
+    }
+
+    opID, err := jd.JournalMarkStartInTx(tx, addDSOperation, opPayload)
+    if err != nil {
+        return err
+    }
 
     // Create the jobs and job_status tables
     sqlStatement := fmt.Sprintf(`CREATE TABLE %q (
@@ -1552,14 +1552,18 @@ func (jd *HandleT) createDS(newDS dataSetT, l lock.DSListLockToken) {
         created_at TIMESTAMP NOT NULL DEFAULT NOW(),
         expire_at TIMESTAMP NOT NULL DEFAULT NOW());`, newDS.JobTable)
-    _, err = jd.dbHandle.Exec(sqlStatement)
-    jd.assertError(err)
+    _, err = tx.ExecContext(context.TODO(), sqlStatement)
+    if err != nil {
+        return err
+    }
 
     // TODO : Evaluate a way to handle indexes only for particular tables
     if jd.tablePrefix == "rt" {
         sqlStatement = fmt.Sprintf(`CREATE INDEX IF NOT EXISTS "customval_workspace_%s" ON %q (custom_val,workspace_id)`, newDS.Index, newDS.JobTable)
-        _, err = jd.dbHandle.Exec(sqlStatement)
-        jd.assertError(err)
+        _, err = tx.ExecContext(context.TODO(), sqlStatement)
+        if err != nil {
+            return err
+        }
     }
 
     sqlStatement = fmt.Sprintf(`CREATE TABLE %q (
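
Switching these statements from jd.dbHandle.Exec to tx.ExecContext is what makes the operation atomic: PostgreSQL runs DDL transactionally, so if any later statement fails, the tables (and the journal entry) created earlier in the same transaction are rolled back together. A condensed sketch of the idea, using an illustrative two-table schema rather than jobsdb's real columns:

    package sketch

    import (
        "context"
        "database/sql"
        "fmt"
    )

    // createDSPair creates a jobs table and its status table on one
    // transaction. If the second statement fails, the first is rolled
    // back too, so no orphaned jobs table can survive a crash.
    func createDSPair(ctx context.Context, tx *sql.Tx, jobTable, statusTable string) error {
        stmts := []string{
            fmt.Sprintf(`CREATE TABLE %q (job_id BIGSERIAL PRIMARY KEY)`, jobTable),
            fmt.Sprintf(`CREATE TABLE %q (
                id BIGSERIAL,
                job_id BIGINT REFERENCES %q (job_id))`, statusTable, jobTable),
        }
        for _, stmt := range stmts {
            if _, err := tx.ExecContext(ctx, stmt); err != nil {
                return err // the caller rolls back the whole transaction
            }
        }
        return nil
    }
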
@@ -1573,38 +1577,48 @@ func (jd *HandleT) createDS(newDS dataSetT, l lock.DSListLockToken) {
         error_response JSONB DEFAULT '{}'::JSONB,
         parameters JSONB DEFAULT '{}'::JSONB,
         PRIMARY KEY (job_id, job_state, id));`, newDS.JobStatusTable, newDS.JobTable)
-    _, err = jd.dbHandle.Exec(sqlStatement)
-    jd.assertError(err)
-    // In case of a migration, we don't yet update the in-memory list till
-    // we finish the migration
+    _, err = tx.ExecContext(context.TODO(), sqlStatement)
+    if err != nil {
+        return err
+    }
+
     if l != nil {
-        jd.setSequenceNumber(l, newDS.Index)
+        err = jd.setSequenceNumberInTx(tx, l, newDS.Index)
+        if err != nil {
+            return err
+        }
     }
-    jd.JournalMarkDone(opID)
-}
 
-func (jd *HandleT) setSequenceNumber(l lock.DSListLockToken, newDSIdx string) dataSetT {
-    // Refresh the in-memory list. We only need to refresh the
-    // last DS, not the entire but we do it anyway.
-    // For the range list, we use the cached data. Internally
-    // it queries the new dataset which was added.
-    dList := jd.refreshDSList(l)
-    dRangeList := jd.refreshDSRangeList(l)
+    err = jd.journalMarkDoneInTx(tx, opID)
+    if err != nil {
+        return err
+    }
 
-    // We should not have range values for the last element (the new DS) and migrationTargetDS (if found)
-    jd.assert(len(dList) == len(dRangeList)+1 || len(dList) == len(dRangeList)+2, fmt.Sprintf("len(dList):%d != len(dRangeList):%d (+1 || +2)", len(dList), len(dRangeList)))
+
+    return nil
+}
+
+func (jd *HandleT) setSequenceNumberInTx(tx *sql.Tx, l lock.DSListLockToken, newDSIdx string) error {
+    dList := jd.getDSList()
+    var maxID sql.NullInt64
 
     // Now set the min JobID for the new DS just added to be 1 more than previous max
-    if len(dRangeList) > 0 {
-        newDSMin := dRangeList[len(dRangeList)-1].maxJobID + 1
-        // jd.assert(newDSMin > 0, fmt.Sprintf("newDSMin:%d <= 0", newDSMin))
-        sqlStatement := fmt.Sprintf(`ALTER SEQUENCE "%[1]s_jobs_%[2]s_job_id_seq" MINVALUE %[3]d START %[3]d RESTART %[3]d`,
+    if len(dList) > 0 {
+        sqlStatement := fmt.Sprintf(`SELECT MAX(job_id) FROM %q`, dList[len(dList)-1].JobTable)
+        err := tx.QueryRowContext(context.TODO(), sqlStatement).Scan(&maxID)
+        if err != nil {
+            return err
+        }
+
+        newDSMin := maxID.Int64 + 1
+        sqlStatement = fmt.Sprintf(`ALTER SEQUENCE "%[1]s_jobs_%[2]s_job_id_seq" MINVALUE %[3]d START %[3]d RESTART %[3]d`,
             jd.tablePrefix, newDSIdx, newDSMin)
-        _, err := jd.dbHandle.Exec(sqlStatement)
-        jd.assertError(err)
+        _, err = tx.ExecContext(context.TODO(), sqlStatement)
+        if err != nil {
+            return err
+        }
     }
-    return dList[len(dList)-1]
+
+    return nil
 }
 
 /*
@@ -3039,7 +3053,6 @@ func (jd *HandleT) migrateDSLoop(ctx context.Context) {
         var migrateDSProbeCount int
         // we don't want `maxDSSize` value to change, during dsList loop
         maxDSSize := *jd.MaxDSSize
-
         for idx, ds := range dsList {
             var idxCheck bool
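
Note the design change in setSequenceNumberInTx: the old code refreshed both in-memory lists just to learn the previous dataset's maximum job_id, while the new code reads it with a single SELECT MAX(job_id) inside the same transaction and restarts the new table's sequence right past it. A standalone sketch of that logic; the sequence name follows PostgreSQL's default <table>_<column>_seq naming for the BIGSERIAL job_id column:

    package sketch

    import (
        "context"
        "database/sql"
        "fmt"
    )

    // alignSequence restarts the new dataset's job_id sequence at one more
    // than the highest job_id in the previous dataset's jobs table.
    func alignSequence(ctx context.Context, tx *sql.Tx, prevJobTable, tablePrefix, newDSIdx string) error {
        var maxID sql.NullInt64
        query := fmt.Sprintf(`SELECT MAX(job_id) FROM %q`, prevJobTable)
        if err := tx.QueryRowContext(ctx, query).Scan(&maxID); err != nil {
            return err
        }
        // MAX over an empty table scans as NULL, leaving Int64 at 0, so an
        // empty previous dataset yields a starting job_id of 1.
        newDSMin := maxID.Int64 + 1
        alter := fmt.Sprintf(`ALTER SEQUENCE "%[1]s_jobs_%[2]s_job_id_seq" MINVALUE %[3]d START %[3]d RESTART %[3]d`,
            tablePrefix, newDSIdx, newDSMin)
        _, err := tx.ExecContext(ctx, alter)
        return err
    }
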
"github.com/onsi/gomega" + "github.com/ory/dockertest/v3" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "github.com/rudderlabs/rudder-server/admin" "github.com/rudderlabs/rudder-server/config" + "github.com/rudderlabs/rudder-server/jobsdb/internal/lock" "github.com/rudderlabs/rudder-server/jobsdb/prebackup" "github.com/rudderlabs/rudder-server/services/stats" + "github.com/rudderlabs/rudder-server/testhelper" + "github.com/rudderlabs/rudder-server/testhelper/destination" rsRand "github.com/rudderlabs/rudder-server/testhelper/rand" "github.com/rudderlabs/rudder-server/utils/logger" ) @@ -542,3 +547,47 @@ func sanitizedJsonUsingRegexp(input json.RawMessage) json.RawMessage { func setSkipZeroAssertionForMultitenant(b bool) { skipZeroAssertionForMultitenant = b } + +func TestRefreshDSList(t *testing.T) { + pool, err := dockertest.NewPool("") + require.NoError(t, err, "Failed to create docker pool") + cleanup := &testhelper.Cleanup{} + defer cleanup.Run() + + postgresResource, err := destination.SetupPostgres(pool, cleanup) + require.NoError(t, err) + + { + t.Setenv("JOBS_DB_DB_NAME", postgresResource.Database) + t.Setenv("JOBS_DB_NAME", postgresResource.Database) + t.Setenv("JOBS_DB_HOST", postgresResource.Host) + t.Setenv("JOBS_DB_PORT", postgresResource.Port) + t.Setenv("JOBS_DB_USER", postgresResource.User) + t.Setenv("JOBS_DB_PASSWORD", postgresResource.Password) + initJobsDB() + stats.Setup() + } + + migrationMode := "" + + triggerAddNewDS := make(chan time.Time) + jobsDB := &HandleT{ + TriggerAddNewDS: func() <-chan time.Time { + return triggerAddNewDS + }, + } + queryFilters := QueryFiltersT{ + CustomVal: true, + } + + err = jobsDB.Setup(ReadWrite, false, "batch_rt", migrationMode, true, queryFilters, []prebackup.Handler{}) + require.NoError(t, err) + + require.Equal(t, 1, len(jobsDB.getDSList()), "jobsDB should start with a ds list size of 1") + // this will throw error if refreshDSList is called without lock + jobsDB.addDS(newDataSet("batch_rt", "2")) + require.Equal(t, 1, len(jobsDB.getDSList()), "addDS should not refresh the ds list") + jobsDB.dsListLock.WithLock(func(l lock.DSListLockToken) { + require.Equal(t, 2, len(jobsDB.refreshDSList(l)), "after refreshing the ds list jobsDB should have a ds list size of 2") + }) +} diff --git a/jobsdb/jobsdb_utils.go b/jobsdb/jobsdb_utils.go index f70e149d65..0ec8d2ce6c 100644 --- a/jobsdb/jobsdb_utils.go +++ b/jobsdb/jobsdb_utils.go @@ -45,25 +45,6 @@ func getDSList(jd assertInterface, dbHandle *sql.DB, tablePrefix string) []dataS sortDnumList(jd, dnumList) - // If any service has crashed while creating DS, this may happen. Handling such case gracefully. 
diff --git a/jobsdb/jobsdb_utils.go b/jobsdb/jobsdb_utils.go
index f70e149d65..0ec8d2ce6c 100644
--- a/jobsdb/jobsdb_utils.go
+++ b/jobsdb/jobsdb_utils.go
@@ -45,25 +45,6 @@ func getDSList(jd assertInterface, dbHandle *sql.DB, tablePrefix string) []dataS
 
     sortDnumList(jd, dnumList)
 
-    // If any service has crashed while creating DS, this may happen. Handling such case gracefully.
-    if len(jobNameMap) != len(jobStatusNameMap) {
-        jd.assert(len(jobNameMap) == len(jobStatusNameMap)+1, fmt.Sprintf("Length of jobNameMap(%d) - length of jobStatusNameMap(%d) is more than 1", len(jobNameMap), len(jobStatusNameMap)))
-        deletedDNum := removeExtraKey(jobNameMap, jobStatusNameMap)
-        // remove deletedDNum from dnumList
-        var idx int
-        var dnum string
-        var foundDeletedDNum bool
-        for idx, dnum = range dnumList {
-            if dnum == deletedDNum {
-                foundDeletedDNum = true
-                break
-            }
-        }
-        if foundDeletedDNum {
-            dnumList = remove(dnumList, idx)
-        }
-    }
-
     // Create the structure
     for _, dnum := range dnumList {
         jobName, ok := jobNameMap[dnum]
@@ -73,7 +54,8 @@ func getDSList(jd assertInterface, dbHandle *sql.DB, tablePrefix string) []dataS
 
         datasetList = append(datasetList,
             dataSetT{
                 JobTable: jobName,
-                JobStatusTable: jobStatusName, Index: dnum,
+                JobStatusTable: jobStatusName,
+                Index:          dnum,
             })
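
With ADD_DS atomic, the crash-repair branch deleted above becomes dead code: a jobs table now either exists together with its job_status table or not at all, so getDSList can treat a missing entry as a hard failure. A condensed sketch of the invariant the simplified pairing loop relies on; returned errors stand in for jd.assert:

    package sketch

    import "fmt"

    // pairTables matches each dataset number with its jobs and job_status
    // tables. After the atomic ADD_DS change, a missing entry in either map
    // indicates a real bug, not a crash artifact to be repaired.
    func pairTables(jobNameMap, jobStatusNameMap map[string]string, dnumList []string) ([][2]string, error) {
        pairs := make([][2]string, 0, len(dnumList))
        for _, dnum := range dnumList {
            jobName, ok := jobNameMap[dnum]
            if !ok {
                return nil, fmt.Errorf("jobs table missing for dnum %s", dnum)
            }
            statusName, ok := jobStatusNameMap[dnum]
            if !ok {
                return nil, fmt.Errorf("job_status table missing for dnum %s", dnum)
            }
            pairs = append(pairs, [2]string{jobName, statusName})
        }
        return pairs, nil
    }
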