Skip to content

Commit

Permalink
add data check before rollbeck (apache#366)
Browse files Browse the repository at this point in the history
* add data check before rollbeck
  • Loading branch information
luky116 authored Nov 23, 2022
1 parent 1dc20df commit 6880911
Show file tree
Hide file tree
Showing 9 changed files with 350 additions and 27 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/dubbogo/gost v1.12.6-0.20220824084206-300e27e9e524
github.com/gin-gonic/gin v1.8.0
github.com/go-sql-driver/mysql v1.6.0
github.com/goccy/go-json v0.9.7
github.com/golang/mock v1.6.0
github.com/google/uuid v1.3.0
github.com/mitchellh/copystructure v1.2.0
Expand Down Expand Up @@ -67,7 +68,6 @@ require (
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-resty/resty/v2 v2.7.0 // indirect
github.com/goccy/go-json v0.9.7 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect
Expand Down
114 changes: 114 additions & 0 deletions pkg/datasource/sql/datasource/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package datasource

import (
"database/sql"
"reflect"
)

type nullTime = sql.NullTime

var (
ScanTypeFloat32 = reflect.TypeOf(float32(0))
ScanTypeFloat64 = reflect.TypeOf(float64(0))
ScanTypeInt8 = reflect.TypeOf(int8(0))
ScanTypeInt16 = reflect.TypeOf(int16(0))
ScanTypeInt32 = reflect.TypeOf(int32(0))
ScanTypeInt64 = reflect.TypeOf(int64(0))
ScanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
ScanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
ScanTypeNullTime = reflect.TypeOf(nullTime{})
ScanTypeUint8 = reflect.TypeOf(uint8(0))
ScanTypeUint16 = reflect.TypeOf(uint16(0))
ScanTypeUint32 = reflect.TypeOf(uint32(0))
ScanTypeUint64 = reflect.TypeOf(uint64(0))
ScanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
ScanTypeUnknown = reflect.TypeOf(new(interface{}))
)

func GetScanSlice(types []*sql.ColumnType) []interface{} {
scanSlice := make([]interface{}, 0, len(types))
for _, tpy := range types {
switch tpy.ScanType() {
case ScanTypeFloat32:
scanVal := float32(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeFloat64:
scanVal := float64(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt8:
scanVal := int8(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt16:
scanVal := int16(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt32:
scanVal := int32(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt64:
scanVal := int64(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeNullFloat:
scanVal := sql.NullFloat64{}
scanSlice = append(scanSlice, &scanVal)
case ScanTypeNullInt:
scanVal := sql.NullInt64{}
scanSlice = append(scanSlice, &scanVal)
case ScanTypeNullTime:
scanVal := sql.NullTime{}
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint8:
scanVal := uint8(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint16:
scanVal := uint16(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint32:
scanVal := uint32(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint64:
scanVal := uint64(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeRawBytes:
scanVal := ""
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUnknown:
scanVal := new(interface{})
scanSlice = append(scanSlice, &scanVal)
}
}
return scanSlice
}

func DeepEqual(x, y interface{}) bool {
typx := reflect.ValueOf(x)
typy := reflect.ValueOf(y)

switch typx.Kind() {
case reflect.Ptr:
typx = typx.Elem()
}

switch typy.Kind() {
case reflect.Ptr:
typy = typy.Elem()
}

return reflect.DeepEqual(typx.Interface(), typy.Interface())
}
3 changes: 2 additions & 1 deletion pkg/datasource/sql/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ package sql
import (
"context"
"database/sql/driver"
"errors"
"sync"

"github.com/pkg/errors"

"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/rm"
Expand Down
27 changes: 27 additions & 0 deletions pkg/datasource/sql/types/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"reflect"
"time"
)

Expand Down Expand Up @@ -93,6 +94,19 @@ func (rs RecordImages) Reserve() {
}
}

func (rs RecordImages) IsEmptyImage() bool {
if len(rs) == 0 {
return true
}
for _, r := range rs {
if r == nil || len(r.Rows) == 0 {
continue
}
return false
}
return true
}

// RecordImage
type RecordImage struct {
// index
Expand Down Expand Up @@ -251,3 +265,16 @@ func (c *ColumnImage) UnmarshalJSON(data []byte) error {
func getTypeStr(src interface{}) string {
return fmt.Sprintf("%T", src)
}

func (c *ColumnImage) GetActualValue() interface{} {
if c.Value == nil {
return nil
}
value := reflect.ValueOf(c.Value)
kind := reflect.TypeOf(c.Value).Kind()
switch kind {
case reflect.Ptr:
return value.Elem().Interface()
}
return c.Value
}
6 changes: 5 additions & 1 deletion pkg/datasource/sql/undo/base/undo.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ func (m *BaseUndoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, con
beforeImages := tranCtx.RoundImages.BeofreImages()
afterImages := tranCtx.RoundImages.AfterImages()

if beforeImages.IsEmptyImage() && afterImages.IsEmptyImage() {
return nil
}

for i := 0; i < len(beforeImages); i++ {
var (
tableName string
Expand Down Expand Up @@ -474,7 +478,7 @@ func (m *BaseUndoLogManager) DecodeMap(str string) map[string]string {

// getRollbackInfo parser rollback info
func (m *BaseUndoLogManager) getRollbackInfo(rollbackInfo []byte, undoContext map[string]string) []byte {
// Todo 目前 insert undo log 未实现压缩功能,实现后补齐这块功能
// Todo use compressor
// get compress type
/*compressorType, ok := undoContext[constant.CompressorTypeKey]
if ok {
Expand Down
1 change: 1 addition & 0 deletions pkg/datasource/sql/undo/builder/basic_undo_log_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
type BasicUndoLogBuilder struct{}

// GetScanSlice get the column type for scann
// todo to use ColumnInfo get slice
func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []driver.Value {
scanSlice := make([]driver.Value, 0, len(columnNames))
for _, columnNmae := range columnNames {
Expand Down
136 changes: 119 additions & 17 deletions pkg/datasource/sql/undo/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ package executor
import (
"context"
"database/sql"
"fmt"
"strings"

"github.com/goccy/go-json"
"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/undo"
"github.com/seata/seata-go/pkg/util/log"
Expand All @@ -29,7 +33,8 @@ import (
var _ undo.UndoExecutor = (*BaseExecutor)(nil)

const (
selectSQL = "SELECT * FROM %s WHERE %s FOR UPDATE"
checkSQLTemplate = "SELECT * FROM %s WHERE %s FOR UPDATE"
maxInSize = 1000
)

type BaseExecutor struct {
Expand All @@ -48,33 +53,130 @@ func (b *BaseExecutor) UndoPrepare(undoPST *sql.Stmt, undoValues []types.ColumnI

}

func (b *BaseExecutor) dataValidationAndGoOn(conn *sql.Conn) (bool, error) {
func (b *BaseExecutor) dataValidationAndGoOn(ctx context.Context, conn *sql.Conn) (bool, error) {
beforeImage := b.sqlUndoLog.BeforeImage
afterImage := b.sqlUndoLog.AfterImage

equal, err := IsRecordsEquals(beforeImage, afterImage)
equals, err := IsRecordsEquals(beforeImage, afterImage)
if err != nil {
return false, err
}
if equal {
if equals {
log.Infof("Stop rollback because there is no data change between the before data snapshot and the after data snapshot.")
return false, nil
}

// todo compare from current db data to old image data
// Validate if data is dirty.
currentImage, err := b.queryCurrentRecords(ctx, conn)
if err != nil {
return false, err
}
// compare with current data and after image.
equals, err = IsRecordsEquals(afterImage, currentImage)
if err != nil {
return false, err
}
if !equals {
// If current data is not equivalent to the after data, then compare the current data with the before
// data, too. No need continue to undo if current data is equivalent to the before data snapshot
equals, err = IsRecordsEquals(beforeImage, currentImage)
if err != nil {
return false, err
}

if equals {
log.Infof("Stop rollback because there is no data change between the before data snapshot and the current data snapshot.")
// no need continue undo.
return false, nil
} else {
oldRowJson, _ := json.Marshal(afterImage.Rows)
newRowJson, _ := json.Marshal(currentImage.Rows)
log.Infof("check dirty data failed, old and new data are not equal, "+
"tableName:[%s], oldRows:[%s],newRows:[%s].", afterImage.TableName, oldRowJson, newRowJson)
return false, fmt.Errorf("Has dirty records when undo.")
}
}
return true, nil
}

// todo
//func (b *BaseExecutor) queryCurrentRecords(conn *sql.Conn) *types.RecordImage {
// tableMeta := b.undoImage.TableMeta
// pkNameList := tableMeta.GetPrimaryKeyOnlyName()
//
// b.undoImage.Rows
//
//}
//
//func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) {
//
//}
func (b *BaseExecutor) queryCurrentRecords(ctx context.Context, conn *sql.Conn) (*types.RecordImage, error) {
if b.undoImage == nil {
return nil, fmt.Errorf("undo image is nil")
}
tableMeta := b.undoImage.TableMeta
pkNameList := tableMeta.GetPrimaryKeyOnlyName()
pkValues := b.parsePkValues(b.undoImage.Rows, pkNameList)

if len(pkValues) == 0 {
return nil, nil
}

var rowSize int
for _, images := range pkValues {
rowSize = len(images)
break
}

where := buildWhereConditionByPKs(pkNameList, rowSize, maxInSize)
checkSQL := fmt.Sprintf(checkSQLTemplate, b.undoImage.TableName, where)
params := buildPKParams(b.undoImage.Rows, pkNameList)

rows, err := conn.QueryContext(ctx, checkSQL, params...)
if err != nil {
return nil, err
}

image := types.RecordImage{
TableName: b.undoImage.TableName,
TableMeta: tableMeta,
SQLType: types.SQLTypeSelect,
}
rowImages := make([]types.RowImage, 0)
for rows.Next() {
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
slice := datasource.GetScanSlice(columnTypes)
if err = rows.Scan(slice...); err != nil {
return nil, err
}

colNames, err := rows.Columns()
if err != nil {
return nil, err
}

columns := make([]types.ColumnImage, 0)
for i, val := range slice {
columns = append(columns, types.ColumnImage{
ColumnName: colNames[i],
Value: val,
})
}
rowImages = append(rowImages, types.RowImage{Columns: columns})
}

image.Rows = rowImages
return &image, nil
}

func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) map[string][]types.ColumnImage {
pkValues := make(map[string][]types.ColumnImage)
// todo optimize 3 fors
for _, row := range rows {
for _, column := range row.Columns {
for _, pk := range pkNameList {
if strings.EqualFold(pk, column.ColumnName) {
values := pkValues[strings.ToUpper(pk)]
if values == nil {
values = make([]types.ColumnImage, 0)
}
values = append(values, column)
pkValues[pk] = values
}
}
}
}
return pkValues
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ type mySQLUndoUpdateExecutor struct {
func newMySQLUndoUpdateExecutor(sqlUndoLog undo.SQLUndoLog) *mySQLUndoUpdateExecutor {
return &mySQLUndoUpdateExecutor{
sqlUndoLog: sqlUndoLog,
baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog},
baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog, undoImage: sqlUndoLog.AfterImage},
}
}

func (m *mySQLUndoUpdateExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, conn *sql.Conn) error {
ok, err := m.baseExecutor.dataValidationAndGoOn(conn)
ok, err := m.baseExecutor.dataValidationAndGoOn(ctx, conn)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 6880911

Please sign in to comment.