Skip to content
This repository has been archived by the owner on Nov 24, 2023. It is now read-only.

syncer: fix generated column in where condition #60

Merged
merged 9 commits into from
Mar 2, 2019
150 changes: 91 additions & 59 deletions syncer/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,22 @@ type GenColCache struct {
// `schema`.`table` -> column list
columns map[string][]*column

// `schema`.`table` -> tableIndex information
indexes map[string]map[string][]*column

// `schema`.`table` -> a bool slice representing whether it is generated for each column
isGenColumn map[string][]bool
}

// genDMLParam stores pruned columns, data as well as the original columns, data, index
type genDMLParam struct {
schema string
table string
safeMode bool // only used in update
data [][]interface{} // pruned data
originalData [][]interface{} // all data
columns []*column // pruned columns
originalColumns []*column // all columns
originalIndexColumns map[string][]*column // all index information
}

// NewGenColCache creates a GenColCache.
func NewGenColCache() *GenColCache {
c := &GenColCache{}
Expand All @@ -74,35 +83,54 @@ func (c *GenColCache) clearTable(schema, table string) {
key := dbutil.TableName(schema, table)
delete(c.hasGenColumn, key)
delete(c.columns, key)
delete(c.indexes, key)
delete(c.isGenColumn, key)
}

func (c *GenColCache) reset() {
c.hasGenColumn = make(map[string]bool)
c.columns = make(map[string][]*column)
c.indexes = make(map[string]map[string][]*column)
c.isGenColumn = make(map[string][]bool)
}

func genInsertSQLs(schema string, table string, dataSeq [][]interface{}, columns []*column, indexColumns map[string][]*column) ([]string, [][]string, [][]interface{}, error) {
func extractValueFromData(data []interface{}, columns []*column) []interface{} {
value := make([]interface{}, 0, len(data))
for i := range data {
value = append(value, castUnsigned(data[i], columns[i].unsigned, columns[i].tp))
}
return value
}

func genInsertSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, error) {
var (
schema = param.schema
table = param.table
dataSeq = param.data
originalDataSeq = param.originalData
columns = param.columns
originalColumns = param.originalColumns
originalIndexColumns = param.originalIndexColumns
)
sqls := make([]string, 0, len(dataSeq))
keys := make([][]string, 0, len(dataSeq))
values := make([][]interface{}, 0, len(dataSeq))
columnList := genColumnList(columns)
columnPlaceholders := genColumnPlaceholders(len(columns))
for _, data := range dataSeq {
for dataIdx, data := range dataSeq {
if len(data) != len(columns) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also check originalColumns and originalDataSeq ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed, data/columns and origianlData/originalColumns have the same length distance ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add column and data length check in generated column prune

return nil, nil, nil, errors.Errorf("insert columns and data mismatch in length: %d (columns) vs %d (data)", len(columns), len(data))
}

value := make([]interface{}, 0, len(data))
for i := range data {
value = append(value, castUnsigned(data[i], columns[i].unsigned, columns[i].tp))
value := extractValueFromData(data, columns)
originalData := originalDataSeq[dataIdx]
var originalValue []interface{}
if len(columns) == len(originalColumns) {
originalValue = value
} else {
originalValue = extractValueFromData(originalData, originalColumns)
}

sql := fmt.Sprintf("REPLACE INTO `%s`.`%s` (%s) VALUES (%s);", schema, table, columnList, columnPlaceholders)
ks := genMultipleKeys(columns, value, indexColumns)
ks := genMultipleKeys(originalColumns, originalValue, originalIndexColumns)
sqls = append(sqls, sql)
values = append(values, value)
keys = append(keys, ks)
Expand All @@ -111,17 +139,29 @@ func genInsertSQLs(schema string, table string, dataSeq [][]interface{}, columns
return sqls, keys, values, nil
}

func genUpdateSQLs(schema string, table string, data [][]interface{}, columns []*column, indexColumns map[string][]*column, safeMode bool) ([]string, [][]string, [][]interface{}, error) {
func genUpdateSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, error) {
var (
schema = param.schema
table = param.table
safeMode = param.safeMode
data = param.data
originalData = param.originalData
columns = param.columns
originalColumns = param.originalColumns
originalIndexColumns = param.originalIndexColumns
)
sqls := make([]string, 0, len(data)/2)
keys := make([][]string, 0, len(data)/2)
values := make([][]interface{}, 0, len(data)/2)
columnList := genColumnList(columns)
columnPlaceholders := genColumnPlaceholders(len(columns))
defaultIndexColumns := findFitIndex(indexColumns)
defaultIndexColumns := findFitIndex(originalIndexColumns)

for i := 0; i < len(data); i += 2 {
oldData := data[i]
changedData := data[i+1]
oriOldData := originalData[i]
oriChangedData := originalData[i+1]

if len(oldData) != len(changedData) {
return nil, nil, nil, errors.Errorf("update data mismatch in length: %d (columns) vs %d (data)", len(oldData), len(changedData))
Expand All @@ -131,25 +171,28 @@ func genUpdateSQLs(schema string, table string, data [][]interface{}, columns []
return nil, nil, nil, errors.Errorf("update columns and data mismatch in length: %d (columns) vs %d (data)", len(columns), len(oldData))
}

oldValues := make([]interface{}, 0, len(oldData))
for i := range oldData {
oldValues = append(oldValues, castUnsigned(oldData[i], columns[i].unsigned, columns[i].tp))
}
changedValues := make([]interface{}, 0, len(changedData))
for i := range changedData {
changedValues = append(changedValues, castUnsigned(changedData[i], columns[i].unsigned, columns[i].tp))
oldValues := extractValueFromData(oldData, columns)
changedValues := extractValueFromData(changedData, columns)

var oriOldValues, oriChangedValues []interface{}
if len(columns) == len(originalColumns) {
oriOldValues = oldValues
oriChangedValues = changedValues
} else {
oriOldValues = extractValueFromData(oriOldData, originalColumns)
oriChangedValues = extractValueFromData(oriChangedData, originalColumns)
}

if len(defaultIndexColumns) == 0 {
defaultIndexColumns = getAvailableIndexColumn(indexColumns, oldValues)
defaultIndexColumns = getAvailableIndexColumn(originalIndexColumns, oriOldValues)
}

ks := genMultipleKeys(columns, oldValues, indexColumns)
ks = append(ks, genMultipleKeys(columns, changedValues, indexColumns)...)
ks := genMultipleKeys(originalColumns, oriOldValues, originalIndexColumns)
ks = append(ks, genMultipleKeys(originalColumns, oriChangedValues, originalIndexColumns)...)

if safeMode {
// generate delete sql from old data
sql, value := genDeleteSQL(schema, table, oldValues, columns, defaultIndexColumns)
sql, value := genDeleteSQL(schema, table, oriOldValues, originalColumns, defaultIndexColumns)
sqls = append(sqls, sql)
values = append(values, value)
keys = append(keys, ks)
Expand Down Expand Up @@ -177,9 +220,9 @@ func genUpdateSQLs(schema string, table string, data [][]interface{}, columns []
kvs := genKVs(updateColumns)
value = append(value, updateValues...)

whereColumns, whereValues := columns, oldValues
whereColumns, whereValues := originalColumns, oriOldValues
if len(defaultIndexColumns) > 0 {
whereColumns, whereValues = getColumnData(columns, defaultIndexColumns, oldValues)
whereColumns, whereValues = getColumnData(originalColumns, defaultIndexColumns, oriOldValues)
}

where := genWhere(whereColumns, whereValues)
Expand All @@ -194,7 +237,14 @@ func genUpdateSQLs(schema string, table string, data [][]interface{}, columns []
return sqls, keys, values, nil
}

func genDeleteSQLs(schema string, table string, dataSeq [][]interface{}, columns []*column, indexColumns map[string][]*column) ([]string, [][]string, [][]interface{}, error) {
func genDeleteSQLs(param *genDMLParam) ([]string, [][]string, [][]interface{}, error) {
var (
schema = param.schema
table = param.table
dataSeq = param.originalData
columns = param.originalColumns
indexColumns = param.originalIndexColumns
)
sqls := make([]string, 0, len(dataSeq))
keys := make([][]string, 0, len(dataSeq))
values := make([][]interface{}, 0, len(dataSeq))
Expand All @@ -205,10 +255,7 @@ func genDeleteSQLs(schema string, table string, dataSeq [][]interface{}, columns
return nil, nil, nil, errors.Errorf("delete columns and data mismatch in length: %d (columns) vs %d (data)", len(columns), len(data))
}

value := make([]interface{}, 0, len(data))
for i := range data {
value = append(value, castUnsigned(data[i], columns[i].unsigned, columns[i].tp))
}
value := extractValueFromData(data, columns)

if len(defaultIndexColumns) == 0 {
defaultIndexColumns = getAvailableIndexColumn(indexColumns, value)
Expand Down Expand Up @@ -488,29 +535,25 @@ func (s *Syncer) mappingDML(schema, table string, columns []string, data [][]int
// pruneGeneratedColumnDML filters columns list, data and index removing all
// generated column. because generated column is not support setting value
// directly in DML, we must remove generated column from DML, including column
// list, data list and all indexes including generated columns.
func pruneGeneratedColumnDML(columns []*column, data [][]interface{}, index map[string][]*column, schema, table string, cache *GenColCache) ([]*column, [][]interface{}, map[string][]*column, error) {
// list and data list including generated columns.
func pruneGeneratedColumnDML(columns []*column, data [][]interface{}, schema, table string, cache *GenColCache) ([]*column, [][]interface{}, error) {
var (
cacheKey = dbutil.TableName(schema, table)
cacheStatus = cache.status(cacheKey)
)

if cacheStatus == noGenColumn {
return columns, data, index, nil
return columns, data, nil
}
if cacheStatus == hasGenColumn {
rows := make([][]interface{}, 0, len(data))
filters, ok1 := cache.isGenColumn[cacheKey]
if !ok1 {
return nil, nil, nil, errors.NotFoundf("cache key %s in isGenColumn", cacheKey)
return nil, nil, errors.NotFoundf("cache key %s in isGenColumn", cacheKey)
}
cols, ok2 := cache.columns[cacheKey]
if !ok2 {
return nil, nil, nil, errors.NotFoundf("cache key %s in columns", cacheKey)
}
idxes, ok3 := cache.indexes[cacheKey]
if !ok3 {
return nil, nil, nil, errors.NotFoundf("cache key %s in indexes", cacheKey)
return nil, nil, errors.NotFoundf("cache key %s in columns", cacheKey)
}
for _, row := range data {
value := make([]interface{}, 0, len(row))
Expand All @@ -521,7 +564,7 @@ func pruneGeneratedColumnDML(columns []*column, data [][]interface{}, index map[
}
rows = append(rows, value)
}
return cols, rows, idxes, nil
return cols, rows, nil
}

var (
Expand All @@ -541,13 +584,12 @@ func pruneGeneratedColumnDML(columns []*column, data [][]interface{}, index map[

if !needPrune {
cache.hasGenColumn[cacheKey] = false
return columns, data, index, nil
return columns, data, nil
}

var (
cols = make([]*column, 0, len(columns))
rows = make([][]interface{}, 0, len(data))
idxes = make(map[string][]*column)
cols = make([]*column, 0, len(columns))
rows = make([][]interface{}, 0, len(data))
)

for i := range columns {
Expand All @@ -556,6 +598,9 @@ func pruneGeneratedColumnDML(columns []*column, data [][]interface{}, index map[
}
}
for _, row := range data {
if len(row) != len(columns) {
return nil, nil, errors.Errorf("prune DML columns and data mismatch in length: %d (columns) %d (data)", len(columns), len(data))
}
value := make([]interface{}, 0, len(row))
for i := range row {
if !colIndexfilters[i] {
Expand All @@ -564,22 +609,9 @@ func pruneGeneratedColumnDML(columns []*column, data [][]interface{}, index map[
}
rows = append(rows, value)
}
for key, keyCols := range index {
hasGenColumn := false
for _, col := range keyCols {
if _, ok := genColumnNames[col.name]; ok {
hasGenColumn = true
break
}
}
if !hasGenColumn {
idxes[key] = keyCols
}
}
cache.hasGenColumn[cacheKey] = true
cache.columns[cacheKey] = cols
cache.indexes[cacheKey] = idxes
cache.isGenColumn[cacheKey] = colIndexfilters

return cols, rows, idxes, nil
return cols, rows, nil
}
19 changes: 14 additions & 5 deletions syncer/syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ func (s *Syncer) Run(ctx context.Context) (err error) {
if err != nil {
return errors.Trace(err)
}
tblColumns, rowData, tblIndexColumns, err := pruneGeneratedColumnDML(table.columns, rows, table.indexColumns, schemaName, tableName, s.genColsCache)
prunedColumns, prunedRows, err := pruneGeneratedColumnDML(table.columns, rows, schemaName, tableName, s.genColsCache)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -1168,11 +1168,19 @@ func (s *Syncer) Run(ctx context.Context) (err error) {
if err != nil {
return errors.Trace(err)
}

param := &genDMLParam{
schema: table.schema,
table: table.name,
data: prunedRows,
originalData: rows,
columns: prunedColumns,
originalColumns: table.columns,
originalIndexColumns: table.indexColumns,
}
switch e.Header.EventType {
case replication.WRITE_ROWS_EVENTv0, replication.WRITE_ROWS_EVENTv1, replication.WRITE_ROWS_EVENTv2:
if !applied {
sqls, keys, args, err = genInsertSQLs(table.schema, table.name, rowData, tblColumns, tblIndexColumns)
sqls, keys, args, err = genInsertSQLs(param)
if err != nil {
return errors.Errorf("gen insert sqls failed: %v, schema: %s, table: %s", errors.Trace(err), table.schema, table.name)
}
Expand All @@ -1195,7 +1203,8 @@ func (s *Syncer) Run(ctx context.Context) (err error) {
}
case replication.UPDATE_ROWS_EVENTv0, replication.UPDATE_ROWS_EVENTv1, replication.UPDATE_ROWS_EVENTv2:
if !applied {
sqls, keys, args, err = genUpdateSQLs(table.schema, table.name, rowData, tblColumns, tblIndexColumns, safeMode.Enable())
param.safeMode = safeMode.Enable()
sqls, keys, args, err = genUpdateSQLs(param)
if err != nil {
return errors.Errorf("gen update sqls failed: %v, schema: %s, table: %s", err, table.schema, table.name)
}
Expand All @@ -1219,7 +1228,7 @@ func (s *Syncer) Run(ctx context.Context) (err error) {
}
case replication.DELETE_ROWS_EVENTv0, replication.DELETE_ROWS_EVENTv1, replication.DELETE_ROWS_EVENTv2:
if !applied {
sqls, keys, args, err = genDeleteSQLs(table.schema, table.name, rowData, tblColumns, tblIndexColumns)
sqls, keys, args, err = genDeleteSQLs(param)
if err != nil {
return errors.Errorf("gen delete sqls failed: %v, schema: %s, table: %s", err, table.schema, table.name)
}
Expand Down
Loading