diff --git a/pkg/datasource/sql/types/image.go b/pkg/datasource/sql/types/image.go index f755d9c91..77108fae2 100644 --- a/pkg/datasource/sql/types/image.go +++ b/pkg/datasource/sql/types/image.go @@ -18,6 +18,7 @@ package types import ( + "database/sql/driver" "encoding/base64" "encoding/json" "reflect" @@ -117,14 +118,16 @@ type RecordImage struct { // Rows data row Rows []RowImage `json:"rows"` // TableMeta table information schema - TableMeta *TableMeta `json:"-"` + TableMeta *TableMeta `json:"-"` + PrimaryKeyMap map[string][]driver.Value `json:"primaryKeyMap,omitempty"` } func NewEmptyRecordImage(tableMeta *TableMeta, sqlType SQLType) *RecordImage { return &RecordImage{ - TableName: tableMeta.TableName, - TableMeta: tableMeta, - SQLType: sqlType, + TableName: tableMeta.TableName, + TableMeta: tableMeta, + SQLType: sqlType, + PrimaryKeyMap: make(map[string][]driver.Value), } } diff --git a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go index 4a23dedda..5da3d35e7 100644 --- a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go @@ -97,68 +97,136 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil { return "", nil, err } - var selectArgs []driver.Value + + // Reset primary keys map + u.BeforeImageSqlPrimaryKeys = make(map[string]bool) + pkIndexMap := u.getPkIndex(insertStmt, metaData) var pkIndexArray []int for _, val := range pkIndexMap { - tmpVal := val - pkIndexArray = append(pkIndexArray, tmpVal) + pkIndexArray = append(pkIndexArray, val) } insertRows, err := getInsertRows(insertStmt, pkIndexArray) if err != nil { return "", nil, err } - insertNum := len(insertRows) + paramMap, err := u.buildImageParameters(insertStmt, args, insertRows) if err != nil { return "", nil, err } - sql := strings.Builder{} - sql.WriteString("SELECT * FROM " + metaData.TableName + " ") + // 如果没有参数或没有主键索引,直接返回空 + if len(paramMap) == 0 || len(metaData.Indexs) == 0 { + return "", nil, nil + } + + // 检查是否有主键 + hasPK := false + for _, index := range metaData.Indexs { + if strings.EqualFold("PRIMARY", index.Name) { + hasPK = true + break + } + } + if !hasPK { + return "", nil, nil + } + + var sql strings.Builder + sql.WriteString("SELECT * FROM " + metaData.TableName + " ") + + var selectArgs []driver.Value isContainWhere := false - for i := 0; i < insertNum; i++ { - finalI := i - paramAppenderTempList := make([]driver.Value, 0) + hasConditions := false + + for i := 0; i < len(insertRows); i++ { + var rowConditions []string + var rowArgs []driver.Value + usedParams := make(map[string]bool) + + // First try unique indexes for _, index := range metaData.Indexs { - //unique index - if index.NonUnique || isIndexValueNotNull(index, paramMap, finalI) == false { + if index.NonUnique || strings.EqualFold("PRIMARY", index.Name) { continue } - columnIsNull := true - uniqueList := make([]string, 0) - for _, columnMeta := range index.Columns { - columnName := columnMeta.ColumnName - imageParameters, ok := paramMap[columnName] - if !ok && columnMeta.ColumnDef != nil { - if strings.EqualFold("PRIMARY", index.Name) { - u.BeforeImageSqlPrimaryKeys[columnName] = true - } - uniqueList = append(uniqueList, columnName+" = DEFAULT("+columnName+") ") - columnIsNull = false - continue + + if !isIndexValueNotNull(index, paramMap, i) { + continue + } + + var indexConditions []string + var indexArgs []driver.Value + allColumnsPresent := true + for _, colMeta := range index.Columns { + columnName := colMeta.ColumnName + if params, ok := paramMap[columnName]; ok && len(params) > i && params[i] != nil { + indexConditions = append(indexConditions, columnName+" = ? ") + indexArgs = append(indexArgs, params[i]) + usedParams[columnName] = true + } else if colMeta.ColumnDef != nil { + indexConditions = append(indexConditions, columnName+" = DEFAULT("+columnName+")") + } else { + allColumnsPresent = false + break } - if strings.EqualFold("PRIMARY", index.Name) { - u.BeforeImageSqlPrimaryKeys[columnName] = true + } + + if allColumnsPresent && len(indexConditions) > 0 { + rowConditions = append(rowConditions, "("+strings.Join(indexConditions, " and ")+")") + rowArgs = append(rowArgs, indexArgs...) + hasConditions = true + } + } + + // Then try primary key + for _, index := range metaData.Indexs { + if !strings.EqualFold("PRIMARY", index.Name) { + continue + } + + var pkConditions []string + var pkArgs []driver.Value + for _, colMeta := range index.Columns { + columnName := colMeta.ColumnName + u.BeforeImageSqlPrimaryKeys[columnName] = true + if params, ok := paramMap[columnName]; ok && len(params) > i && params[i] != nil && !usedParams[columnName] { + pkConditions = append(pkConditions, columnName+" = ? ") + pkArgs = append(pkArgs, params[i]) } - columnIsNull = false - uniqueList = append(uniqueList, columnName+" = ? ") - paramAppenderTempList = append(paramAppenderTempList, imageParameters[finalI]) } - if !columnIsNull { - if isContainWhere { - sql.WriteString(" OR (" + strings.Join(uniqueList, " and ") + ") ") - } else { - sql.WriteString(" WHERE (" + strings.Join(uniqueList, " and ") + ") ") - isContainWhere = true + if len(pkConditions) > 0 { + rowConditions = append(rowConditions, "("+strings.Join(pkConditions, " and ")+")") + rowArgs = append(rowArgs, pkArgs...) + hasConditions = true + } + } + + if len(rowConditions) > 0 { + if !isContainWhere { + sql.WriteString("WHERE ") + isContainWhere = true + } else { + sql.WriteString(" OR ") + } + for j, condition := range rowConditions { + if j > 0 { + sql.WriteString(" OR ") } + sql.WriteString(condition + " ") } + selectArgs = append(selectArgs, rowArgs...) } - selectArgs = append(selectArgs, paramAppenderTempList...) } - log.Infof("build select sql by insert on duplicate sourceQuery, sql {}", sql.String()) - return sql.String(), selectArgs, nil + + if !hasConditions { + return "", nil, nil + } + + sqlStr := sql.String() + log.Infof("build select sql by insert on duplicate sourceQuery, sql: %s", sqlStr) + return sqlStr, selectArgs, nil } func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { @@ -168,18 +236,22 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e log.Errorf("build prepare stmt: %+v", err) return nil, err } + defer stmt.Close() + + tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O + metaData := execCtx.MetaDataMap[tableName] rows, err := stmt.Query(selectArgs) if err != nil { - log.Errorf("stmt query: %+v", err) return nil, err } - tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O - metaData := execCtx.MetaDataMap[tableName] + defer rows.Close() + image, err := u.buildRecordImages(rows, &metaData) if err != nil { return nil, err } + return []*types.RecordImage{image}, nil } @@ -190,6 +262,13 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co if len(beforeImages) > 0 { beforeImage = beforeImages[0] } + + // 如果没有before image,直接返回原始SQL和参数 + if beforeImage == nil || len(beforeImage.Rows) == 0 { + return selectSQL, selectArgs + } + + // 收集主键值 primaryValueMap := make(map[string][]interface{}) for _, row := range beforeImage.Rows { for _, col := range row.Columns { @@ -200,23 +279,53 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co } var afterImageSql strings.Builder - var primaryValues []driver.Value afterImageSql.WriteString(selectSQL) - for i := 0; i < len(beforeImage.Rows); i++ { - wherePrimaryList := make([]string, 0) - for name, value := range primaryValueMap { - if !u.BeforeImageSqlPrimaryKeys[name] { - wherePrimaryList = append(wherePrimaryList, name+" = ? ") - primaryValues = append(primaryValues, value[i]) + + // 如果原始SQL已经包含了所有需要的条件,直接返回 + if len(primaryValueMap) == 0 || len(selectArgs) == len(beforeImage.Rows)*len(primaryValueMap) { + return selectSQL, selectArgs + } + + // 添加主键条件 + var primaryValues []driver.Value + usedPrimaryKeys := make(map[string]bool) + + for name := range primaryValueMap { + if !u.BeforeImageSqlPrimaryKeys[name] { + usedPrimaryKeys[name] = true + for i := 0; i < len(beforeImage.Rows); i++ { + if value := primaryValueMap[name][i]; value != nil { + if dv, ok := value.(driver.Value); ok { + primaryValues = append(primaryValues, dv) + } else { + primaryValues = append(primaryValues, value) + } + } } } - if len(wherePrimaryList) != 0 { - afterImageSql.WriteString(" OR (" + strings.Join(wherePrimaryList, " and ") + ") ") + } + + if len(primaryValues) > 0 { + afterImageSql.WriteString(" OR (" + strings.Join(u.buildPrimaryKeyConditions(primaryValueMap, usedPrimaryKeys), " and ") + ") ") + } + + finalArgs := make([]driver.Value, len(selectArgs)+len(primaryValues)) + copy(finalArgs, selectArgs) + copy(finalArgs[len(selectArgs):], primaryValues) + + sqlStr := afterImageSql.String() + log.Infof("build after select sql by insert on duplicate sourceQuery, sql %s", sqlStr) + return sqlStr, finalArgs +} + +func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildPrimaryKeyConditions(primaryValueMap map[string][]interface{}, usedPrimaryKeys map[string]bool) []string { + var conditions []string + for name := range primaryValueMap { + if !usedPrimaryKeys[name] { + conditions = append(conditions, name+" = ? ") } } - selectArgs = append(selectArgs, primaryValues...) - log.Infof("build after select sql by insert on duplicate sourceQuery, sql {}", afterImageSql.String()) - return afterImageSql.String(), selectArgs + return conditions } func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) error { @@ -243,11 +352,10 @@ func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) e // build sql params func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.InsertStmt, args []driver.Value, insertRows [][]interface{}) (map[string][]driver.Value, error) { - var ( - parameterMap = make(map[string][]driver.Value) - ) + parameterMap := make(map[string][]driver.Value) insertColumns := getInsertColumns(insert) - var placeHolderIndex = 0 + placeHolderIndex := 0 + for _, row := range insertRows { if len(row) != len(insertColumns) { log.Errorf("insert row's column size not equal to insert column size") @@ -256,13 +364,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast. for i, col := range insertColumns { columnName := executor.DelEscape(col, types.DBTypeMySQL) val := row[i] - rStr, ok := val.(string) - if ok && strings.EqualFold(rStr, SqlPlaceholder) { - objects := args[placeHolderIndex] - parameterMap[columnName] = append(parameterMap[col], objects) + if str, ok := val.(string); ok && strings.EqualFold(str, SqlPlaceholder) { + if placeHolderIndex >= len(args) { + return nil, fmt.Errorf("not enough parameters for placeholders") + } + parameterMap[columnName] = append(parameterMap[columnName], args[placeHolderIndex]) placeHolderIndex++ } else { - parameterMap[columnName] = append(parameterMap[col], val) + parameterMap[columnName] = append(parameterMap[columnName], val) } } }