Skip to content

Commit

Permalink
bugfix apache#704
Browse files Browse the repository at this point in the history
  • Loading branch information
AsterZephyr committed Dec 2, 2024
1 parent ad092d5 commit abdd938
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 65 deletions.
11 changes: 7 additions & 4 deletions pkg/datasource/sql/types/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package types

import (
"database/sql/driver"
"encoding/base64"
"encoding/json"
"reflect"
Expand Down Expand Up @@ -117,14 +118,16 @@ type RecordImage struct {
// Rows data row
Rows []RowImage `json:"rows"`
// TableMeta table information schema
TableMeta *TableMeta `json:"-"`
TableMeta *TableMeta `json:"-"`
PrimaryKeyMap map[string][]driver.Value `json:"primaryKeyMap,omitempty"`
}

func NewEmptyRecordImage(tableMeta *TableMeta, sqlType SQLType) *RecordImage {
return &RecordImage{
TableName: tableMeta.TableName,
TableMeta: tableMeta,
SQLType: sqlType,
TableName: tableMeta.TableName,
TableMeta: tableMeta,
SQLType: sqlType,
PrimaryKeyMap: make(map[string][]driver.Value),
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,68 +97,136 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a
if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil {
return "", nil, err
}
var selectArgs []driver.Value

// Reset primary keys map
u.BeforeImageSqlPrimaryKeys = make(map[string]bool)

pkIndexMap := u.getPkIndex(insertStmt, metaData)
var pkIndexArray []int
for _, val := range pkIndexMap {
tmpVal := val
pkIndexArray = append(pkIndexArray, tmpVal)
pkIndexArray = append(pkIndexArray, val)
}
insertRows, err := getInsertRows(insertStmt, pkIndexArray)
if err != nil {
return "", nil, err
}
insertNum := len(insertRows)

paramMap, err := u.buildImageParameters(insertStmt, args, insertRows)
if err != nil {
return "", nil, err
}

sql := strings.Builder{}
sql.WriteString("SELECT * FROM " + metaData.TableName + " ")
// 如果没有参数或没有主键索引,直接返回空
if len(paramMap) == 0 || len(metaData.Indexs) == 0 {
return "", nil, nil
}

// 检查是否有主键
hasPK := false
for _, index := range metaData.Indexs {
if strings.EqualFold("PRIMARY", index.Name) {
hasPK = true
break
}
}
if !hasPK {
return "", nil, nil
}

var sql strings.Builder
sql.WriteString("SELECT * FROM " + metaData.TableName + " ")

var selectArgs []driver.Value
isContainWhere := false
for i := 0; i < insertNum; i++ {
finalI := i
paramAppenderTempList := make([]driver.Value, 0)
hasConditions := false

for i := 0; i < len(insertRows); i++ {
var rowConditions []string
var rowArgs []driver.Value
usedParams := make(map[string]bool)

// First try unique indexes
for _, index := range metaData.Indexs {
//unique index
if index.NonUnique || isIndexValueNotNull(index, paramMap, finalI) == false {
if index.NonUnique || strings.EqualFold("PRIMARY", index.Name) {
continue
}
columnIsNull := true
uniqueList := make([]string, 0)
for _, columnMeta := range index.Columns {
columnName := columnMeta.ColumnName
imageParameters, ok := paramMap[columnName]
if !ok && columnMeta.ColumnDef != nil {
if strings.EqualFold("PRIMARY", index.Name) {
u.BeforeImageSqlPrimaryKeys[columnName] = true
}
uniqueList = append(uniqueList, columnName+" = DEFAULT("+columnName+") ")
columnIsNull = false
continue

if !isIndexValueNotNull(index, paramMap, i) {
continue
}

var indexConditions []string
var indexArgs []driver.Value
allColumnsPresent := true
for _, colMeta := range index.Columns {
columnName := colMeta.ColumnName
if params, ok := paramMap[columnName]; ok && len(params) > i && params[i] != nil {
indexConditions = append(indexConditions, columnName+" = ? ")
indexArgs = append(indexArgs, params[i])
usedParams[columnName] = true
} else if colMeta.ColumnDef != nil {
indexConditions = append(indexConditions, columnName+" = DEFAULT("+columnName+")")
} else {
allColumnsPresent = false
break
}
if strings.EqualFold("PRIMARY", index.Name) {
u.BeforeImageSqlPrimaryKeys[columnName] = true
}

if allColumnsPresent && len(indexConditions) > 0 {
rowConditions = append(rowConditions, "("+strings.Join(indexConditions, " and ")+")")
rowArgs = append(rowArgs, indexArgs...)
hasConditions = true
}
}

// Then try primary key
for _, index := range metaData.Indexs {
if !strings.EqualFold("PRIMARY", index.Name) {
continue
}

var pkConditions []string
var pkArgs []driver.Value
for _, colMeta := range index.Columns {
columnName := colMeta.ColumnName
u.BeforeImageSqlPrimaryKeys[columnName] = true
if params, ok := paramMap[columnName]; ok && len(params) > i && params[i] != nil && !usedParams[columnName] {
pkConditions = append(pkConditions, columnName+" = ? ")
pkArgs = append(pkArgs, params[i])
}
columnIsNull = false
uniqueList = append(uniqueList, columnName+" = ? ")
paramAppenderTempList = append(paramAppenderTempList, imageParameters[finalI])
}

if !columnIsNull {
if isContainWhere {
sql.WriteString(" OR (" + strings.Join(uniqueList, " and ") + ") ")
} else {
sql.WriteString(" WHERE (" + strings.Join(uniqueList, " and ") + ") ")
isContainWhere = true
if len(pkConditions) > 0 {
rowConditions = append(rowConditions, "("+strings.Join(pkConditions, " and ")+")")
rowArgs = append(rowArgs, pkArgs...)
hasConditions = true
}
}

if len(rowConditions) > 0 {
if !isContainWhere {
sql.WriteString("WHERE ")
isContainWhere = true
} else {
sql.WriteString(" OR ")
}
for j, condition := range rowConditions {
if j > 0 {
sql.WriteString(" OR ")
}
sql.WriteString(condition + " ")
}
selectArgs = append(selectArgs, rowArgs...)
}
selectArgs = append(selectArgs, paramAppenderTempList...)
}
log.Infof("build select sql by insert on duplicate sourceQuery, sql {}", sql.String())
return sql.String(), selectArgs, nil

if !hasConditions {
return "", nil, nil
}

sqlStr := sql.String()
log.Infof("build select sql by insert on duplicate sourceQuery, sql: %s", sqlStr)
return sqlStr, selectArgs, nil
}

func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
Expand All @@ -168,18 +236,22 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e
log.Errorf("build prepare stmt: %+v", err)
return nil, err
}
defer stmt.Close()

tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O
metaData := execCtx.MetaDataMap[tableName]

rows, err := stmt.Query(selectArgs)
if err != nil {
log.Errorf("stmt query: %+v", err)
return nil, err
}
tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O
metaData := execCtx.MetaDataMap[tableName]
defer rows.Close()

image, err := u.buildRecordImages(rows, &metaData)
if err != nil {
return nil, err
}

return []*types.RecordImage{image}, nil
}

Expand All @@ -190,6 +262,13 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co
if len(beforeImages) > 0 {
beforeImage = beforeImages[0]
}

// 如果没有before image,直接返回原始SQL和参数
if beforeImage == nil || len(beforeImage.Rows) == 0 {
return selectSQL, selectArgs
}

// 收集主键值
primaryValueMap := make(map[string][]interface{})
for _, row := range beforeImage.Rows {
for _, col := range row.Columns {
Expand All @@ -200,23 +279,53 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co
}

var afterImageSql strings.Builder
var primaryValues []driver.Value
afterImageSql.WriteString(selectSQL)
for i := 0; i < len(beforeImage.Rows); i++ {
wherePrimaryList := make([]string, 0)
for name, value := range primaryValueMap {
if !u.BeforeImageSqlPrimaryKeys[name] {
wherePrimaryList = append(wherePrimaryList, name+" = ? ")
primaryValues = append(primaryValues, value[i])

// 如果原始SQL已经包含了所有需要的条件,直接返回
if len(primaryValueMap) == 0 || len(selectArgs) == len(beforeImage.Rows)*len(primaryValueMap) {
return selectSQL, selectArgs
}

// 添加主键条件
var primaryValues []driver.Value
usedPrimaryKeys := make(map[string]bool)

for name := range primaryValueMap {
if !u.BeforeImageSqlPrimaryKeys[name] {
usedPrimaryKeys[name] = true
for i := 0; i < len(beforeImage.Rows); i++ {
if value := primaryValueMap[name][i]; value != nil {
if dv, ok := value.(driver.Value); ok {
primaryValues = append(primaryValues, dv)
} else {
primaryValues = append(primaryValues, value)
}
}
}
}
if len(wherePrimaryList) != 0 {
afterImageSql.WriteString(" OR (" + strings.Join(wherePrimaryList, " and ") + ") ")
}

if len(primaryValues) > 0 {
afterImageSql.WriteString(" OR (" + strings.Join(u.buildPrimaryKeyConditions(primaryValueMap, usedPrimaryKeys), " and ") + ") ")
}

finalArgs := make([]driver.Value, len(selectArgs)+len(primaryValues))
copy(finalArgs, selectArgs)
copy(finalArgs[len(selectArgs):], primaryValues)

sqlStr := afterImageSql.String()
log.Infof("build after select sql by insert on duplicate sourceQuery, sql %s", sqlStr)
return sqlStr, finalArgs
}

func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildPrimaryKeyConditions(primaryValueMap map[string][]interface{}, usedPrimaryKeys map[string]bool) []string {
var conditions []string
for name := range primaryValueMap {
if !usedPrimaryKeys[name] {
conditions = append(conditions, name+" = ? ")
}
}
selectArgs = append(selectArgs, primaryValues...)
log.Infof("build after select sql by insert on duplicate sourceQuery, sql {}", afterImageSql.String())
return afterImageSql.String(), selectArgs
return conditions
}

func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) error {
Expand All @@ -243,11 +352,10 @@ func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) e

// build sql params
func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.InsertStmt, args []driver.Value, insertRows [][]interface{}) (map[string][]driver.Value, error) {
var (
parameterMap = make(map[string][]driver.Value)
)
parameterMap := make(map[string][]driver.Value)
insertColumns := getInsertColumns(insert)
var placeHolderIndex = 0
placeHolderIndex := 0

for _, row := range insertRows {
if len(row) != len(insertColumns) {
log.Errorf("insert row's column size not equal to insert column size")
Expand All @@ -256,13 +364,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.
for i, col := range insertColumns {
columnName := executor.DelEscape(col, types.DBTypeMySQL)
val := row[i]
rStr, ok := val.(string)
if ok && strings.EqualFold(rStr, SqlPlaceholder) {
objects := args[placeHolderIndex]
parameterMap[columnName] = append(parameterMap[col], objects)
if str, ok := val.(string); ok && strings.EqualFold(str, SqlPlaceholder) {
if placeHolderIndex >= len(args) {
return nil, fmt.Errorf("not enough parameters for placeholders")
}
parameterMap[columnName] = append(parameterMap[columnName], args[placeHolderIndex])
placeHolderIndex++
} else {
parameterMap[columnName] = append(parameterMap[col], val)
parameterMap[columnName] = append(parameterMap[columnName], val)
}
}
}
Expand Down

0 comments on commit abdd938

Please sign in to comment.