diff --git a/go.mod b/go.mod index b6390c5af..afbff7d1b 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/dubbogo/gost v1.12.6-0.20220824084206-300e27e9e524 github.com/gin-gonic/gin v1.8.0 github.com/go-sql-driver/mysql v1.6.0 + github.com/goccy/go-json v0.9.7 github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 github.com/mitchellh/copystructure v1.2.0 @@ -67,7 +68,6 @@ require ( github.com/go-playground/locales v0.14.0 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect github.com/go-resty/resty/v2 v2.7.0 // indirect - github.com/goccy/go-json v0.9.7 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect diff --git a/pkg/datasource/sql/datasource/utils.go b/pkg/datasource/sql/datasource/utils.go new file mode 100644 index 000000000..05890b945 --- /dev/null +++ b/pkg/datasource/sql/datasource/utils.go @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package datasource + +import ( + "database/sql" + "reflect" +) + +type nullTime = sql.NullTime + +var ( + ScanTypeFloat32 = reflect.TypeOf(float32(0)) + ScanTypeFloat64 = reflect.TypeOf(float64(0)) + ScanTypeInt8 = reflect.TypeOf(int8(0)) + ScanTypeInt16 = reflect.TypeOf(int16(0)) + ScanTypeInt32 = reflect.TypeOf(int32(0)) + ScanTypeInt64 = reflect.TypeOf(int64(0)) + ScanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + ScanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + ScanTypeNullTime = reflect.TypeOf(nullTime{}) + ScanTypeUint8 = reflect.TypeOf(uint8(0)) + ScanTypeUint16 = reflect.TypeOf(uint16(0)) + ScanTypeUint32 = reflect.TypeOf(uint32(0)) + ScanTypeUint64 = reflect.TypeOf(uint64(0)) + ScanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) + ScanTypeUnknown = reflect.TypeOf(new(interface{})) +) + +func GetScanSlice(types []*sql.ColumnType) []interface{} { + scanSlice := make([]interface{}, 0, len(types)) + for _, tpy := range types { + switch tpy.ScanType() { + case ScanTypeFloat32: + scanVal := float32(0) + scanSlice = append(scanSlice, &scanVal) + case ScanTypeFloat64: + scanVal := float64(0) + scanSlice = append(scanSlice, &scanVal) + case ScanTypeInt8: + scanVal := int8(0) + scanSlice = append(scanSlice, &scanVal) + case ScanTypeInt16: + scanVal := int16(0) + scanSlice = append(scanSlice, &scanVal) + case ScanTypeInt32: + scanVal := int32(0) + scanSlice = append(scanSlice, &scanVal) + case ScanTypeInt64: + scanVal := int64(0) + scanSlice = append(scanSlice, &scanVal) + case ScanTypeNullFloat: + scanVal := sql.NullFloat64{} + scanSlice = append(scanSlice, &scanVal) + case ScanTypeNullInt: + scanVal := sql.NullInt64{} + scanSlice = append(scanSlice, &scanVal) + case ScanTypeNullTime: + scanVal := sql.NullTime{} + scanSlice = append(scanSlice, &scanVal) + case ScanTypeUint8: + scanVal := uint8(0) + scanSlice = append(scanSlice, &scanVal) + case ScanTypeUint16: + scanVal := uint16(0) + scanSlice = append(scanSlice, &scanVal) + case ScanTypeUint32: + scanVal := uint32(0) + scanSlice = append(scanSlice, &scanVal) + case ScanTypeUint64: + scanVal := uint64(0) + scanSlice = append(scanSlice, &scanVal) + case ScanTypeRawBytes: + scanVal := "" + scanSlice = append(scanSlice, &scanVal) + case ScanTypeUnknown: + scanVal := new(interface{}) + scanSlice = append(scanSlice, &scanVal) + } + } + return scanSlice +} + +func DeepEqual(x, y interface{}) bool { + typx := reflect.ValueOf(x) + typy := reflect.ValueOf(y) + + switch typx.Kind() { + case reflect.Ptr: + typx = typx.Elem() + } + + switch typy.Kind() { + case reflect.Ptr: + typy = typy.Elem() + } + + return reflect.DeepEqual(typx.Interface(), typy.Interface()) +} diff --git a/pkg/datasource/sql/tx.go b/pkg/datasource/sql/tx.go index 9901862b3..694932990 100644 --- a/pkg/datasource/sql/tx.go +++ b/pkg/datasource/sql/tx.go @@ -20,9 +20,10 @@ package sql import ( "context" "database/sql/driver" - "errors" "sync" + "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/rm" diff --git a/pkg/datasource/sql/types/image.go b/pkg/datasource/sql/types/image.go index 423d31335..db69a75dd 100644 --- a/pkg/datasource/sql/types/image.go +++ b/pkg/datasource/sql/types/image.go @@ -21,6 +21,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "reflect" "time" ) @@ -93,6 +94,19 @@ func (rs RecordImages) Reserve() { } } +func (rs RecordImages) IsEmptyImage() bool { + if len(rs) == 0 { + return true + } + for _, r := range rs { + if r == nil || len(r.Rows) == 0 { + continue + } + return false + } + return true +} + // RecordImage type RecordImage struct { // index @@ -251,3 +265,16 @@ func (c *ColumnImage) UnmarshalJSON(data []byte) error { func getTypeStr(src interface{}) string { return fmt.Sprintf("%T", src) } + +func (c *ColumnImage) GetActualValue() interface{} { + if c.Value == nil { + return nil + } + value := reflect.ValueOf(c.Value) + kind := reflect.TypeOf(c.Value).Kind() + switch kind { + case reflect.Ptr: + return value.Elem().Interface() + } + return c.Value +} diff --git a/pkg/datasource/sql/undo/base/undo.go b/pkg/datasource/sql/undo/base/undo.go index 595c72553..078137a17 100644 --- a/pkg/datasource/sql/undo/base/undo.go +++ b/pkg/datasource/sql/undo/base/undo.go @@ -162,6 +162,10 @@ func (m *BaseUndoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, con beforeImages := tranCtx.RoundImages.BeofreImages() afterImages := tranCtx.RoundImages.AfterImages() + if beforeImages.IsEmptyImage() && afterImages.IsEmptyImage() { + return nil + } + for i := 0; i < len(beforeImages); i++ { var ( tableName string @@ -474,7 +478,7 @@ func (m *BaseUndoLogManager) DecodeMap(str string) map[string]string { // getRollbackInfo parser rollback info func (m *BaseUndoLogManager) getRollbackInfo(rollbackInfo []byte, undoContext map[string]string) []byte { - // Todo 目前 insert undo log 未实现压缩功能,实现后补齐这块功能 + // Todo use compressor // get compress type /*compressorType, ok := undoContext[constant.CompressorTypeKey] if ok { diff --git a/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go b/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go index bb0c4467c..61d152db2 100644 --- a/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go @@ -35,6 +35,7 @@ import ( type BasicUndoLogBuilder struct{} // GetScanSlice get the column type for scann +// todo to use ColumnInfo get slice func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []driver.Value { scanSlice := make([]driver.Value, 0, len(columnNames)) for _, columnNmae := range columnNames { diff --git a/pkg/datasource/sql/undo/executor/executor.go b/pkg/datasource/sql/undo/executor/executor.go index cde5a3a82..b210f0249 100644 --- a/pkg/datasource/sql/undo/executor/executor.go +++ b/pkg/datasource/sql/undo/executor/executor.go @@ -20,7 +20,11 @@ package executor import ( "context" "database/sql" + "fmt" + "strings" + "github.com/goccy/go-json" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo" "github.com/seata/seata-go/pkg/util/log" @@ -29,7 +33,8 @@ import ( var _ undo.UndoExecutor = (*BaseExecutor)(nil) const ( - selectSQL = "SELECT * FROM %s WHERE %s FOR UPDATE" + checkSQLTemplate = "SELECT * FROM %s WHERE %s FOR UPDATE" + maxInSize = 1000 ) type BaseExecutor struct { @@ -48,33 +53,130 @@ func (b *BaseExecutor) UndoPrepare(undoPST *sql.Stmt, undoValues []types.ColumnI } -func (b *BaseExecutor) dataValidationAndGoOn(conn *sql.Conn) (bool, error) { +func (b *BaseExecutor) dataValidationAndGoOn(ctx context.Context, conn *sql.Conn) (bool, error) { beforeImage := b.sqlUndoLog.BeforeImage afterImage := b.sqlUndoLog.AfterImage - equal, err := IsRecordsEquals(beforeImage, afterImage) + equals, err := IsRecordsEquals(beforeImage, afterImage) if err != nil { return false, err } - if equal { + if equals { log.Infof("Stop rollback because there is no data change between the before data snapshot and the after data snapshot.") return false, nil } - // todo compare from current db data to old image data + // Validate if data is dirty. + currentImage, err := b.queryCurrentRecords(ctx, conn) + if err != nil { + return false, err + } + // compare with current data and after image. + equals, err = IsRecordsEquals(afterImage, currentImage) + if err != nil { + return false, err + } + if !equals { + // If current data is not equivalent to the after data, then compare the current data with the before + // data, too. No need continue to undo if current data is equivalent to the before data snapshot + equals, err = IsRecordsEquals(beforeImage, currentImage) + if err != nil { + return false, err + } + if equals { + log.Infof("Stop rollback because there is no data change between the before data snapshot and the current data snapshot.") + // no need continue undo. + return false, nil + } else { + oldRowJson, _ := json.Marshal(afterImage.Rows) + newRowJson, _ := json.Marshal(currentImage.Rows) + log.Infof("check dirty data failed, old and new data are not equal, "+ + "tableName:[%s], oldRows:[%s],newRows:[%s].", afterImage.TableName, oldRowJson, newRowJson) + return false, fmt.Errorf("Has dirty records when undo.") + } + } return true, nil } -// todo -//func (b *BaseExecutor) queryCurrentRecords(conn *sql.Conn) *types.RecordImage { -// tableMeta := b.undoImage.TableMeta -// pkNameList := tableMeta.GetPrimaryKeyOnlyName() -// -// b.undoImage.Rows -// -//} -// -//func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) { -// -//} +func (b *BaseExecutor) queryCurrentRecords(ctx context.Context, conn *sql.Conn) (*types.RecordImage, error) { + if b.undoImage == nil { + return nil, fmt.Errorf("undo image is nil") + } + tableMeta := b.undoImage.TableMeta + pkNameList := tableMeta.GetPrimaryKeyOnlyName() + pkValues := b.parsePkValues(b.undoImage.Rows, pkNameList) + + if len(pkValues) == 0 { + return nil, nil + } + + var rowSize int + for _, images := range pkValues { + rowSize = len(images) + break + } + + where := buildWhereConditionByPKs(pkNameList, rowSize, maxInSize) + checkSQL := fmt.Sprintf(checkSQLTemplate, b.undoImage.TableName, where) + params := buildPKParams(b.undoImage.Rows, pkNameList) + + rows, err := conn.QueryContext(ctx, checkSQL, params...) + if err != nil { + return nil, err + } + + image := types.RecordImage{ + TableName: b.undoImage.TableName, + TableMeta: tableMeta, + SQLType: types.SQLTypeSelect, + } + rowImages := make([]types.RowImage, 0) + for rows.Next() { + columnTypes, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + slice := datasource.GetScanSlice(columnTypes) + if err = rows.Scan(slice...); err != nil { + return nil, err + } + + colNames, err := rows.Columns() + if err != nil { + return nil, err + } + + columns := make([]types.ColumnImage, 0) + for i, val := range slice { + columns = append(columns, types.ColumnImage{ + ColumnName: colNames[i], + Value: val, + }) + } + rowImages = append(rowImages, types.RowImage{Columns: columns}) + } + + image.Rows = rowImages + return &image, nil +} + +func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) map[string][]types.ColumnImage { + pkValues := make(map[string][]types.ColumnImage) + // todo optimize 3 fors + for _, row := range rows { + for _, column := range row.Columns { + for _, pk := range pkNameList { + if strings.EqualFold(pk, column.ColumnName) { + values := pkValues[strings.ToUpper(pk)] + if values == nil { + values = make([]types.ColumnImage, 0) + } + values = append(values, column) + pkValues[pk] = values + } + } + } + } + return pkValues +} diff --git a/pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go b/pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go index 86b22c74b..4c0a3e548 100644 --- a/pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go +++ b/pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go @@ -36,12 +36,12 @@ type mySQLUndoUpdateExecutor struct { func newMySQLUndoUpdateExecutor(sqlUndoLog undo.SQLUndoLog) *mySQLUndoUpdateExecutor { return &mySQLUndoUpdateExecutor{ sqlUndoLog: sqlUndoLog, - baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog}, + baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog, undoImage: sqlUndoLog.AfterImage}, } } func (m *mySQLUndoUpdateExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, conn *sql.Conn) error { - ok, err := m.baseExecutor.dataValidationAndGoOn(conn) + ok, err := m.baseExecutor.dataValidationAndGoOn(ctx, conn) if err != nil { return err } diff --git a/pkg/datasource/sql/undo/executor/utils.go b/pkg/datasource/sql/undo/executor/utils.go index ca7d9b626..0a641e490 100644 --- a/pkg/datasource/sql/undo/executor/utils.go +++ b/pkg/datasource/sql/undo/executor/utils.go @@ -19,10 +19,11 @@ package executor import ( "fmt" - "reflect" "strings" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/util/log" ) // IsRecordsEquals check before record and after record if equal @@ -51,14 +52,16 @@ func compareRows(tableMeta types.TableMeta, oldRows []types.RowImage, newRows [] for key, oldRow := range oldRowMap { newRow := newRowMap[key] if newRow == nil { - return false, fmt.Errorf("compare row failed, rowKey %s, reason [newField is null]", key) + log.Errorf("compare row failed, rowKey %s, reason new field is null", key) + return false, fmt.Errorf("compare image failed for new row is null") } for fieldName, oldValue := range oldRow { newValue := newRow[fieldName] if newValue == nil { - return false, fmt.Errorf("compare row failed, rowKey %s, fieldName %s, reason [newField is null]", key, fieldName) + log.Errorf("compare row failed, rowKey %s, fieldName %s, reason new value is null", key, fieldName) + return false, fmt.Errorf("compare image failed for new value is null") } - if !reflect.DeepEqual(newValue, oldValue) { + if !datasource.DeepEqual(newValue, oldValue) { return false, nil } } @@ -80,7 +83,7 @@ func rowListToMap(rows []types.RowImage, primaryKeyList []string) map[string]map rowKey += "_##$$_" } // todo make value more accurate - rowKey = fmt.Sprintf("%v%v", rowKey, column.Value) + rowKey = fmt.Sprintf("%v%v", rowKey, column.GetActualValue()) firstUnderline = true } } @@ -90,3 +93,74 @@ func rowListToMap(rows []types.RowImage, primaryKeyList []string) map[string]map } return rowMap } + +// buildWhereConditionByPKs build where condition by primary keys +// each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?) ) or (id,userCode) in ((?,?))" +func buildWhereConditionByPKs(pkNameList []string, rowSize int, maxInSize int) string { + var ( + whereStr = &strings.Builder{} + batchSize = rowSize/maxInSize + 1 + ) + + if rowSize%maxInSize == 0 { + batchSize = rowSize / maxInSize + } + + for batch := 0; batch < batchSize; batch++ { + if batch > 0 { + whereStr.WriteString(" OR ") + } + whereStr.WriteString("(") + + for i := 0; i < len(pkNameList); i++ { + if i > 0 { + whereStr.WriteString(",") + } + // todo add escape + whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i])) + } + whereStr.WriteString(") IN (") + + var eachSize int + + if batch == batchSize-1 { + if rowSize%maxInSize == 0 { + eachSize = maxInSize + } else { + eachSize = rowSize % maxInSize + } + } else { + eachSize = maxInSize + } + + for i := 0; i < eachSize; i++ { + if i > 0 { + whereStr.WriteString(",") + } + whereStr.WriteString("(") + for j := 0; j < len(pkNameList); j++ { + if j > 0 { + whereStr.WriteString(",") + } + whereStr.WriteString("?") + } + whereStr.WriteString(")") + } + whereStr.WriteString(")") + } + return whereStr.String() +} + +func buildPKParams(rows []types.RowImage, pkNameList []string) []interface{} { + params := make([]interface{}, 0) + for _, row := range rows { + coumnMap := row.GetColumnMap() + for _, pk := range pkNameList { + col := coumnMap[pk] + if col != nil { + params = append(params, col.Value) + } + } + } + return params +}