Skip to content

Commit

Permalink
add pitr checking (#20838)
Browse files Browse the repository at this point in the history
add pitr checking when creating cdc task

Approved by: @daviszhen, @sukki37
  • Loading branch information
ck89119 authored Dec 19, 2024
1 parent 9f2b7bc commit 6114b9d
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pkg/cdc/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ func AesCFBDecode(ctx context.Context, data string) (string, error) {
return AesCFBDecodeWithKey(ctx, data, []byte(AesKey))
}

func AesCFBDecodeWithKey(ctx context.Context, data string, aesKey []byte) (string, error) {
var AesCFBDecodeWithKey = func(ctx context.Context, data string, aesKey []byte) (string, error) {
if len(aesKey) == 0 {
return "", moerr.NewInternalErrorNoCtx("AesKey is not initialized")
}
Expand Down
77 changes: 68 additions & 9 deletions pkg/frontend/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,13 @@ func doCreateCdc(ctx context.Context, ses *Session, create *tree.CreateCDC) (err
if err != nil {
return err
}

bh := ses.GetBackgroundExec(ctx)
defer bh.Close()
if err = checkPitr(ctx, bh, ses.GetTenantName(), tablePts); err != nil {
return
}

jsonTables, err := cdc2.JsonEncode(tablePts)
if err != nil {
return
Expand Down Expand Up @@ -398,7 +405,7 @@ func doCreateCdc(ctx context.Context, ses *Session, create *tree.CreateCDC) (err
var encodedSinkPwd string
if !useConsole {
// TODO replace with creatorAccountId
if err = initAesKeyWrapper(ctx, tx, catalog.System_Account, service); err != nil {
if err = initAesKeyBySqlExecutor(ctx, tx, catalog.System_Account, service); err != nil {
return
}

Expand Down Expand Up @@ -470,7 +477,7 @@ func cdcTaskMetadata(cdcId string) task.TaskMetadata {
}
}

func queryTable(
var queryTable = func(
ctx context.Context,
tx taskservice.SqlExecutor,
query string,
Expand Down Expand Up @@ -501,6 +508,60 @@ func queryTable(
return false, nil
}

var checkPitr = func(ctx context.Context, bh BackgroundExec, accName string, pts *cdc2.PatternTuples) error {
// TODO min length
minPitrLen := int64(2)
checkPitrByLevel := func(level, dbName, tblName string) (bool, error) {
length, unit, ok, err := getPitrLengthAndUnit(ctx, bh, level, accName, dbName, tblName)
if err != nil {
return false, err
}
if !ok {
return false, nil
}
return !(unit == "h" && length < minPitrLen), nil
}

for _, pt := range pts.Pts {
dbName := pt.Source.Database
tblName := pt.Source.Table
level := cdc2.TableLevel
if dbName == cdc2.MatchAll && tblName == cdc2.MatchAll { // account level
level = cdc2.AccountLevel
} else if tblName == cdc2.MatchAll { // db level
level = cdc2.DbLevel
}

if ok, err := checkPitrByLevel(cdc2.AccountLevel, dbName, tblName); err != nil {
return err
} else if ok {
// covered by account level pitr
continue
}

if level == cdc2.DbLevel || level == cdc2.TableLevel {
if ok, err := checkPitrByLevel(cdc2.DbLevel, dbName, tblName); err != nil {
return err
} else if ok {
// covered by db level pitr
continue
}
}

if level == cdc2.TableLevel {
if ok, err := checkPitrByLevel(cdc2.TableLevel, dbName, tblName); err != nil {
return err
} else if ok {
// covered by table level pitr
continue
}
}

return moerr.NewInternalErrorf(ctx, "no account/db/table level pitr with enough length found for pattern: %s, min pitr length: %d h", pt.OriginString, minPitrLen)
}
return nil
}

// getPatternTuple pattern example:
//
// db1
Expand Down Expand Up @@ -890,7 +951,7 @@ func (cdc *CdcTask) initAesKeyByInternalExecutor(ctx context.Context, accountId
return err
}

cdc2.AesKey, err = decrypt(ctx, encryptedKey, []byte(getGlobalPuWrapper(cdc.cnUUID).SV.KeyEncryptionKey))
cdc2.AesKey, err = cdc2.AesCFBDecodeWithKey(ctx, encryptedKey, []byte(getGlobalPuWrapper(cdc.cnUUID).SV.KeyEncryptionKey))
return
}

Expand Down Expand Up @@ -1140,6 +1201,7 @@ func handleDropCdc(ses *Session, execCtx *ExecCtx, st *tree.DropCDC) error {
}

func handlePauseCdc(ses *Session, execCtx *ExecCtx, st *tree.PauseCDC) error {
ses.GetResponser()
return updateCdc(execCtx.reqCtx, ses, st)
}

Expand Down Expand Up @@ -1554,13 +1616,10 @@ func getTaskCkp(ctx context.Context, bh BackgroundExec, accountId uint32, taskId
}

var (
queryTableWrapper = queryTable
decrypt = cdc2.AesCFBDecodeWithKey
getGlobalPuWrapper = getPu
initAesKeyWrapper = initAesKeyBySqlExecutor
)

func initAesKeyBySqlExecutor(ctx context.Context, executor taskservice.SqlExecutor, accountId uint32, service string) (err error) {
var initAesKeyBySqlExecutor = func(ctx context.Context, executor taskservice.SqlExecutor, accountId uint32, service string) (err error) {
if len(cdc2.AesKey) > 0 {
return nil
}
Expand All @@ -1569,7 +1628,7 @@ func initAesKeyBySqlExecutor(ctx context.Context, executor taskservice.SqlExecut
var ret bool
querySql := fmt.Sprintf(getDataKeyFormat, accountId, cdc2.InitKeyId)

ret, err = queryTableWrapper(ctx, executor, querySql, func(ctx context.Context, rows *sql.Rows) (bool, error) {
ret, err = queryTable(ctx, executor, querySql, func(ctx context.Context, rows *sql.Rows) (bool, error) {
if err = rows.Scan(&encryptedKey); err != nil {
return false, err
}
Expand All @@ -1581,6 +1640,6 @@ func initAesKeyBySqlExecutor(ctx context.Context, executor taskservice.SqlExecut
return moerr.NewInternalError(ctx, "no data key")
}

cdc2.AesKey, err = decrypt(ctx, encryptedKey, []byte(getGlobalPuWrapper(service).SV.KeyEncryptionKey))
cdc2.AesKey, err = cdc2.AesCFBDecodeWithKey(ctx, encryptedKey, []byte(getGlobalPuWrapper(service).SV.KeyEncryptionKey))
return
}
52 changes: 46 additions & 6 deletions pkg/frontend/cdc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ func Test_handleCreateCdc(t *testing.T) {

cdc2.AesKey = "test-aes-key-not-use-it-in-cloud"
defer func() { cdc2.AesKey = "" }()
stub := gostub.Stub(&initAesKeyWrapper, func(context.Context, taskservice.SqlExecutor, uint32, string) (err error) {
stub := gostub.Stub(&initAesKeyBySqlExecutor, func(context.Context, taskservice.SqlExecutor, uint32, string) (err error) {
return nil
})
defer stub.Reset()
Expand All @@ -421,6 +421,11 @@ func Test_handleCreateCdc(t *testing.T) {
})
defer stubOpenDbConn.Reset()

stubCheckPitr := gostub.Stub(&checkPitr, func(ctx context.Context, bh BackgroundExec, accName string, pts *cdc2.PatternTuples) error {
return nil
})
defer stubCheckPitr.Reset()

tests := []struct {
name string
args args
Expand Down Expand Up @@ -2472,7 +2477,7 @@ func Test_initAesKey(t *testing.T) {

{
e := moerr.NewInternalErrorNoCtx("error")
queryTableStub := gostub.Stub(&queryTableWrapper, func(context.Context, taskservice.SqlExecutor, string, func(ctx context.Context, rows *sql.Rows) (bool, error)) (bool, error) {
queryTableStub := gostub.Stub(&queryTable, func(context.Context, taskservice.SqlExecutor, string, func(ctx context.Context, rows *sql.Rows) (bool, error)) (bool, error) {
return true, e
})
defer queryTableStub.Reset()
Expand All @@ -2482,7 +2487,7 @@ func Test_initAesKey(t *testing.T) {
}

{
queryTableStub := gostub.Stub(&queryTableWrapper, func(context.Context, taskservice.SqlExecutor, string, func(ctx context.Context, rows *sql.Rows) (bool, error)) (bool, error) {
queryTableStub := gostub.Stub(&queryTable, func(context.Context, taskservice.SqlExecutor, string, func(ctx context.Context, rows *sql.Rows) (bool, error)) (bool, error) {
return false, nil
})
defer queryTableStub.Reset()
Expand All @@ -2492,12 +2497,12 @@ func Test_initAesKey(t *testing.T) {
}

{
queryTableStub := gostub.Stub(&queryTableWrapper, func(context.Context, taskservice.SqlExecutor, string, func(ctx context.Context, rows *sql.Rows) (bool, error)) (bool, error) {
queryTableStub := gostub.Stub(&queryTable, func(context.Context, taskservice.SqlExecutor, string, func(ctx context.Context, rows *sql.Rows) (bool, error)) (bool, error) {
return true, nil
})
defer queryTableStub.Reset()

decryptStub := gostub.Stub(&decrypt, func(context.Context, string, []byte) (string, error) {
decryptStub := gostub.Stub(&cdc2.AesCFBDecodeWithKey, func(context.Context, string, []byte) (string, error) {
return "aesKey", nil
})
defer decryptStub.Reset()
Expand Down Expand Up @@ -2618,7 +2623,7 @@ func TestCdcTask_initAesKeyByInternalExecutor(t *testing.T) {
ie: mie,
}

decryptStub := gostub.Stub(&decrypt, func(context.Context, string, []byte) (string, error) {
decryptStub := gostub.Stub(&cdc2.AesCFBDecodeWithKey, func(context.Context, string, []byte) (string, error) {
return "aesKey", nil
})
defer decryptStub.Reset()
Expand Down Expand Up @@ -2823,3 +2828,38 @@ func TestCdcTask_addExecPipelineForTable(t *testing.T) {

assert.NoError(t, cdc.addExecPipelineForTable(context.Background(), info, txnOperator))
}

func TestCdcTask_checkPitr(t *testing.T) {
stubGetPitrLength := gostub.Stub(&getPitrLengthAndUnit,
func(_ context.Context, _ BackgroundExec, level, _, _, _ string) (int64, string, bool, error) {
return 0, "", level == "table", nil
},
)
defer stubGetPitrLength.Reset()

pts := &cdc2.PatternTuples{
Pts: []*cdc2.PatternTuple{
{
Source: cdc2.PatternTable{
Database: "db1",
Table: "tb1",
},
},
{
Source: cdc2.PatternTable{
Database: "db2",
Table: cdc2.MatchAll,
},
},
{
Source: cdc2.PatternTable{
Database: cdc2.MatchAll,
Table: cdc2.MatchAll,
},
},
},
}

err := checkPitr(context.Background(), nil, "acc1", pts)
assert.Error(t, err)
}
53 changes: 53 additions & 0 deletions pkg/frontend/pitr.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ var (

// update mo_pitr object id
updateMoPitrAccountObjectIdFmt = `update mo_catalog.mo_pitr set obj_id = %d, modified_time = '%s' where account_name = '%s';`

getLengthAndUnitFmt = `select pitr_length, pitr_unit from mo_catalog.mo_pitr where account_id = %d and level = '%s'`
)

type pitrRecord struct {
Expand Down Expand Up @@ -160,6 +162,18 @@ func getSqlForUpdateMoPitrAccountObjectId(accountName string, objId uint64, modi
return fmt.Sprintf(updateMoPitrAccountObjectIdFmt, objId, modifiedTime, accountName)
}

func getSqlForGetLengthAndUnitFmt(accountId uint32, level, accName, dbName, tblName string) string {
sql := fmt.Sprintf(getLengthAndUnitFmt, accountId, level)
if level == "account" {
sql += fmt.Sprintf(" and account_name = '%s'", accName)
} else if level == "database" {
sql += fmt.Sprintf(" and database_name = '%s'", dbName)
} else if level == "table" {
sql += fmt.Sprintf(" and table_name = '%s'", tblName)
}
return sql
}

func checkPitrDup(ctx context.Context, bh BackgroundExec, createAccount string, createAccountId uint64, stmt *tree.CreatePitr) (bool, error) {
sql := getSqlForCheckPitrDup(createAccount, createAccountId, stmt)

Expand Down Expand Up @@ -2378,3 +2392,42 @@ func updatePitrObjectId(ctx context.Context,
}
return
}

var getPitrLengthAndUnit = func(
ctx context.Context,
bh BackgroundExec,
level string,
accName, dbName, tblName string,
) (length int64, unit string, ok bool, err error) {
accountId, err := defines.GetAccountId(ctx)
if err != nil {
return
}

sql := getSqlForGetLengthAndUnitFmt(accountId, level, accName, dbName, tblName)
ctx = defines.AttachAccountId(ctx, sysAccountID)
bh.ClearExecResultSet()
if err = bh.Exec(ctx, sql); err != nil {
return
}

erArray, err := getResultSet(ctx, bh)
if err != nil {
return
}

if !execResultArrayHasData(erArray) {
return
}

if length, err = erArray[0].GetInt64(ctx, 0, 0); err != nil {
return
}

if unit, err = erArray[0].GetString(ctx, 0, 1); err != nil {
return
}

ok = true
return
}
29 changes: 29 additions & 0 deletions pkg/frontend/pitr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3037,3 +3037,32 @@ func Test_restoreViewsWithPitr(t *testing.T) {
assert.NoError(t, err)
})
}

func Test_getPitrLengthAndUnit(t *testing.T) {
ctx := defines.AttachAccountId(context.Background(), sysAccountID)

bh := &backgroundExecTest{}
bh.init()

bhStub := gostub.StubFunc(&NewBackgroundExec, bh)
defer bhStub.Reset()

sql := getSqlForGetLengthAndUnitFmt(0, "account", "acc1", "", "")
bh.sql2result[sql] = newMrsForPitrRecord([][]interface{}{
{1, "h"},
})
length, unit, ok, err := getPitrLengthAndUnit(ctx, bh, "account", "acc1", "", "")
assert.NoError(t, err)
assert.Equal(t, int64(1), length)
assert.Equal(t, "h", unit)
assert.True(t, ok)

sql = getSqlForGetLengthAndUnitFmt(0, "database", "", "db", "")
bh.sql2result[sql] = newMrsForPitrRecord([][]interface{}{})
_, _, ok, err = getPitrLengthAndUnit(ctx, bh, "database", "", "db", "")
assert.NoError(t, err)
assert.False(t, ok)

_, _, _, err = getPitrLengthAndUnit(ctx, bh, "table", "", "", "tbl")
assert.Error(t, err)
}

0 comments on commit 6114b9d

Please sign in to comment.