diff --git a/dm/syncer/expr_filter_group.go b/dm/syncer/expr_filter_group.go index eea33b47baa..a858b95783c 100644 --- a/dm/syncer/expr_filter_group.go +++ b/dm/syncer/expr_filter_group.go @@ -36,10 +36,9 @@ type ExprFilterGroup struct { updateNewExprs map[string][]expression.Expression // tableName -> expr deleteExprs map[string][]expression.Expression // tableName -> expr - hasInsertFilter map[string]struct{} // set(tableName) - hasUpdateOldFilter map[string]struct{} // set(tableName) - hasUpdateNewFilter map[string]struct{} // set(tableName) - hasDeleteFilter map[string]struct{} // set(tableName) + hasInsertFilter map[string]struct{} // set(tableName) + hasUpdateFilter map[string]struct{} // set(tableName) + hasDeleteFilter map[string]struct{} // set(tableName) tidbCtx sessionctx.Context logCtx *tcontext.Context @@ -48,17 +47,16 @@ type ExprFilterGroup struct { // NewExprFilterGroup creates an ExprFilterGroup. func NewExprFilterGroup(logCtx *tcontext.Context, tidbCtx sessionctx.Context, exprConfig []*config.ExpressionFilter) *ExprFilterGroup { ret := &ExprFilterGroup{ - configs: map[string][]*config.ExpressionFilter{}, - insertExprs: map[string][]expression.Expression{}, - updateOldExprs: map[string][]expression.Expression{}, - updateNewExprs: map[string][]expression.Expression{}, - deleteExprs: map[string][]expression.Expression{}, - hasInsertFilter: map[string]struct{}{}, - hasUpdateOldFilter: map[string]struct{}{}, - hasUpdateNewFilter: map[string]struct{}{}, - hasDeleteFilter: map[string]struct{}{}, - tidbCtx: tidbCtx, - logCtx: logCtx, + configs: map[string][]*config.ExpressionFilter{}, + insertExprs: map[string][]expression.Expression{}, + updateOldExprs: map[string][]expression.Expression{}, + updateNewExprs: map[string][]expression.Expression{}, + deleteExprs: map[string][]expression.Expression{}, + hasInsertFilter: map[string]struct{}{}, + hasUpdateFilter: map[string]struct{}{}, + hasDeleteFilter: map[string]struct{}{}, + tidbCtx: tidbCtx, + logCtx: logCtx, } for _, c := range exprConfig { tableName := dbutil.TableName(c.Schema, c.Table) @@ -67,11 +65,8 @@ func NewExprFilterGroup(logCtx *tcontext.Context, tidbCtx sessionctx.Context, ex if c.InsertValueExpr != "" { ret.hasInsertFilter[tableName] = struct{}{} } - if c.UpdateOldValueExpr != "" { - ret.hasUpdateOldFilter[tableName] = struct{}{} - } - if c.UpdateNewValueExpr != "" { - ret.hasUpdateNewFilter[tableName] = struct{}{} + if c.UpdateOldValueExpr != "" || c.UpdateNewValueExpr != "" { + ret.hasUpdateFilter[tableName] = struct{}{} } if c.DeleteValueExpr != "" { ret.hasDeleteFilter[tableName] = struct{}{} @@ -117,7 +112,7 @@ func (g *ExprFilterGroup) GetUpdateExprs(table *filter.Table, ti *model.TableInf return retOld, retNew, nil } - if _, ok := g.hasUpdateOldFilter[tableID]; ok { + if _, ok := g.hasUpdateFilter[tableID]; ok { for _, c := range g.configs[tableID] { if c.UpdateOldValueExpr != "" { expr, err := getSimpleExprOfTable(g.tidbCtx, c.UpdateOldValueExpr, ti, g.logCtx.L()) @@ -129,11 +124,7 @@ func (g *ExprFilterGroup) GetUpdateExprs(table *filter.Table, ti *model.TableInf } else { g.updateOldExprs[tableID] = append(g.updateOldExprs[tableID], expression.NewOne()) } - } - } - if _, ok := g.hasUpdateNewFilter[tableID]; ok { - for _, c := range g.configs[tableID] { if c.UpdateNewValueExpr != "" { expr, err := getSimpleExprOfTable(g.tidbCtx, c.UpdateNewValueExpr, ti, g.logCtx.L()) if err != nil { diff --git a/dm/syncer/expr_filter_group_test.go b/dm/syncer/expr_filter_group_test.go index 413562a3b28..5501c450b13 100644 --- a/dm/syncer/expr_filter_group_test.go +++ b/dm/syncer/expr_filter_group_test.go @@ -15,18 +15,20 @@ package syncer import ( "context" + "testing" - . "github.com/pingcap/check" + ddl2 "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/util/filter" "github.com/pingcap/tiflow/dm/config" tcontext "github.com/pingcap/tiflow/dm/pkg/context" "github.com/pingcap/tiflow/dm/pkg/log" "github.com/pingcap/tiflow/dm/pkg/schema" "github.com/pingcap/tiflow/dm/pkg/utils" - "github.com/pingcap/tiflow/dm/syncer/dbconn" + "github.com/stretchr/testify/require" ) -func (s *testFilterSuite) TestSkipDMLByExpression(c *C) { +func TestSkipDMLByExpression(t *testing.T) { cases := []struct { exprStr string tableStr string @@ -91,19 +93,18 @@ create table t ( Name: tblName, } ) - c.Assert(log.InitLogger(&log.Config{Level: "debug"}), IsNil) + require.NoError(t, log.InitLogger(&log.Config{Level: "debug"})) - dbConn := dbconn.NewDBConn(&config.SubTaskConfig{}, s.baseConn) for _, ca := range cases { - schemaTracker, err := schema.NewTestTracker(ctx, "unit-test", dbConn, log.L()) - c.Assert(err, IsNil) - c.Assert(schemaTracker.CreateSchemaIfNotExists(dbName), IsNil) + schemaTracker, err := schema.NewTestTracker(ctx, "unit-test", nil, log.L()) + require.NoError(t, err) + require.NoError(t, schemaTracker.CreateSchemaIfNotExists(dbName)) stmt, err := parseSQL(ca.tableStr) - c.Assert(err, IsNil) - c.Assert(schemaTracker.Exec(ctx, dbName, stmt), IsNil) + require.NoError(t, err) + require.NoError(t, schemaTracker.Exec(ctx, dbName, stmt)) ti, err := schemaTracker.GetTableInfo(table) - c.Assert(err, IsNil) + require.NoError(t, err) exprConfig := []*config.ExpressionFilter{ { @@ -115,26 +116,26 @@ create table t ( sessCtx := utils.NewSessionCtx(map[string]string{"time_zone": "UTC"}) g := NewExprFilterGroup(tcontext.Background(), sessCtx, exprConfig) exprs, err := g.GetInsertExprs(table, ti) - c.Assert(err, IsNil) - c.Assert(exprs, HasLen, 1) + require.NoError(t, err) + require.Len(t, exprs, 1) expr := exprs[0] ca.skippedRow = extractValueFromData(ca.skippedRow, ti.Columns, ti) ca.passedRow = extractValueFromData(ca.passedRow, ti.Columns, ti) skip, err := SkipDMLByExpression(sessCtx, ca.skippedRow, expr, ti.Columns) - c.Assert(err, IsNil) - c.Assert(skip, Equals, true) + require.NoError(t, err) + require.True(t, skip) skip, err = SkipDMLByExpression(sessCtx, ca.passedRow, expr, ti.Columns) - c.Assert(err, IsNil) - c.Assert(skip, Equals, false) + require.NoError(t, err) + require.False(t, skip) schemaTracker.Close() } } -func (s *testFilterSuite) TestAllBinaryProtocolTypes(c *C) { +func TestAllBinaryProtocolTypes(t *testing.T) { cases := []struct { exprStr string tableStr string @@ -355,20 +356,19 @@ create table t ( Name: tblName, } ) - c.Assert(log.InitLogger(&log.Config{Level: "debug"}), IsNil) + require.NoError(t, log.InitLogger(&log.Config{Level: "debug"})) - dbConn := dbconn.NewDBConn(&config.SubTaskConfig{}, s.baseConn) for _, ca := range cases { - c.Log(ca.tableStr) - schemaTracker, err := schema.NewTestTracker(ctx, "unit-test", dbConn, log.L()) - c.Assert(err, IsNil) - c.Assert(schemaTracker.CreateSchemaIfNotExists(dbName), IsNil) + t.Log(ca.tableStr) + schemaTracker, err := schema.NewTestTracker(ctx, "unit-test", nil, log.L()) + require.NoError(t, err) + require.NoError(t, schemaTracker.CreateSchemaIfNotExists(dbName)) stmt, err := parseSQL(ca.tableStr) - c.Assert(err, IsNil) - c.Assert(schemaTracker.Exec(ctx, dbName, stmt), IsNil) + require.NoError(t, err) + require.NoError(t, schemaTracker.Exec(ctx, dbName, stmt)) ti, err := schemaTracker.GetTableInfo(table) - c.Assert(err, IsNil) + require.NoError(t, err) exprConfig := []*config.ExpressionFilter{ { @@ -380,26 +380,26 @@ create table t ( sessCtx := utils.NewSessionCtx(map[string]string{"time_zone": "UTC"}) g := NewExprFilterGroup(tcontext.Background(), sessCtx, exprConfig) exprs, err := g.GetInsertExprs(table, ti) - c.Assert(err, IsNil) - c.Assert(exprs, HasLen, 1) + require.NoError(t, err) + require.Len(t, exprs, 1) expr := exprs[0] ca.skippedRow = extractValueFromData(ca.skippedRow, ti.Columns, ti) ca.passedRow = extractValueFromData(ca.passedRow, ti.Columns, ti) skip, err := SkipDMLByExpression(sessCtx, ca.skippedRow, expr, ti.Columns) - c.Assert(err, IsNil) - c.Assert(skip, Equals, true) + require.NoError(t, err) + require.True(t, skip) skip, err = SkipDMLByExpression(sessCtx, ca.passedRow, expr, ti.Columns) - c.Assert(err, IsNil) - c.Assert(skip, Equals, false) + require.NoError(t, err) + require.False(t, skip) schemaTracker.Close() } } -func (s *testFilterSuite) TestExpressionContainsNonExistColumn(c *C) { +func TestExpressionContainsNonExistColumn(t *testing.T) { var ( ctx = context.Background() dbName = "test" @@ -415,16 +415,15 @@ create table t ( exprStr = "d > 1" ) - dbConn := dbconn.NewDBConn(&config.SubTaskConfig{}, s.baseConn) - schemaTracker, err := schema.NewTestTracker(ctx, "unit-test", dbConn, log.L()) - c.Assert(err, IsNil) - c.Assert(schemaTracker.CreateSchemaIfNotExists(dbName), IsNil) + schemaTracker, err := schema.NewTestTracker(ctx, "unit-test", nil, log.L()) + require.NoError(t, err) + require.NoError(t, schemaTracker.CreateSchemaIfNotExists(dbName)) stmt, err := parseSQL(tableStr) - c.Assert(err, IsNil) - c.Assert(schemaTracker.Exec(ctx, dbName, stmt), IsNil) + require.NoError(t, err) + require.NoError(t, schemaTracker.Exec(ctx, dbName, stmt)) ti, err := schemaTracker.GetTableInfo(table) - c.Assert(err, IsNil) + require.NoError(t, err) exprConfig := []*config.ExpressionFilter{ { @@ -436,16 +435,70 @@ create table t ( sessCtx := utils.NewSessionCtx(map[string]string{"time_zone": "UTC"}) g := NewExprFilterGroup(tcontext.Background(), sessCtx, exprConfig) exprs, err := g.GetInsertExprs(table, ti) - c.Assert(err, IsNil) - c.Assert(exprs, HasLen, 1) + require.NoError(t, err) + require.Len(t, exprs, 1) expr := exprs[0] - c.Assert(expr.String(), Equals, "0") + require.Equal(t, "0", expr.String()) // skip nothing skip, err := SkipDMLByExpression(sessCtx, []interface{}{0}, expr, ti.Columns) - c.Assert(err, IsNil) - c.Assert(skip, Equals, false) + require.NoError(t, err) + require.False(t, skip) skip, err = SkipDMLByExpression(sessCtx, []interface{}{2}, expr, ti.Columns) - c.Assert(err, IsNil) - c.Assert(skip, Equals, false) + require.NoError(t, err) + require.False(t, skip) +} + +func TestGetUpdateExprsSameLength(t *testing.T) { + var ( + dbName = "test" + tblName = "t" + table = &filter.Table{ + Schema: dbName, + Name: tblName, + } + tableStr = ` +create table t ( + c varchar(20) +);` + exprStr = "c > 1" + sessCtx = utils.NewSessionCtx(map[string]string{"time_zone": "UTC"}) + ) + + cases := []*config.ExpressionFilter{ + { + Schema: dbName, + Table: tblName, + InsertValueExpr: exprStr, + }, + { + Schema: dbName, + Table: tblName, + UpdateOldValueExpr: exprStr, + }, + { + Schema: dbName, + Table: tblName, + UpdateNewValueExpr: exprStr, + }, + { + Schema: dbName, + Table: tblName, + UpdateOldValueExpr: exprStr, + UpdateNewValueExpr: exprStr, + }, + } + + stmt, err := parseSQL(tableStr) + require.NoError(t, err) + tableInfo, err := ddl2.BuildTableInfoFromAST(stmt.(*ast.CreateTableStmt)) + require.NoError(t, err) + + for i, c := range cases { + t.Logf("case #%d", i) + g := NewExprFilterGroup(tcontext.Background(), sessCtx, []*config.ExpressionFilter{c}) + oldExprs, newExprs, err := g.GetUpdateExprs(table, tableInfo) + require.NoError(t, err) + require.Equal(t, len(oldExprs), len(newExprs)) + } } diff --git a/dm/tests/expression_filter/conf/dm-task2.yaml b/dm/tests/expression_filter/conf/dm-task2.yaml index 817638d5de8..d64c40bb3ee 100644 --- a/dm/tests/expression_filter/conf/dm-task2.yaml +++ b/dm/tests/expression_filter/conf/dm-task2.yaml @@ -19,6 +19,7 @@ mysql-instances: - "update_new_lt_100" - "update_old_and_new" - "only_muller" + - "e02" expression-filter: even_c: @@ -50,6 +51,10 @@ expression-filter: schema: "expr_filter" table: "t6" insert-value-expr: "name != 'Müller'" + e02: + schema: expr_filter + table: t7 + update-new-value-expr: "r = 'a'" black-white-list: # compatible with deprecated config instance: diff --git a/dm/tests/expression_filter/data/db1.increment2.sql b/dm/tests/expression_filter/data/db1.increment2.sql index 70eaceaa5e6..152ed7e6076 100644 --- a/dm/tests/expression_filter/data/db1.increment2.sql +++ b/dm/tests/expression_filter/data/db1.increment2.sql @@ -44,5 +44,8 @@ update t5 set should_skip = 0, c = 3 where c = 1; insert into t6 (id, name, msg) values (1, 'Müller', 'Müller'), (2, 'X Æ A-12', 'X Æ A-12'); alter table t6 add column name2 varchar(20) character set latin1 default 'Müller'; +-- test https://github.com/pingcap/tiflow/issues/7774 +UPDATE t7 SET s = s + 1 WHERE a = 1; + -- trigger a flush alter table t5 add column dummy int; diff --git a/dm/tests/expression_filter/data/db1.prepare2.sql b/dm/tests/expression_filter/data/db1.prepare2.sql index 6cc6832fcd2..be88910616b 100644 --- a/dm/tests/expression_filter/data/db1.prepare2.sql +++ b/dm/tests/expression_filter/data/db1.prepare2.sql @@ -10,3 +10,5 @@ create table t2 (id int primary key, create table t6 (id int, name varchar(20), msg text, primary key(`id`)) character set latin1; insert into t6 (id, name, msg) values (0, 'Müller', 'Müller'); +CREATE TABLE t7 (a BIGINT PRIMARY KEY, r VARCHAR(10), s INT); +INSERT INTO t7 VALUES (1, 'a', 2); diff --git a/dm/tests/expression_filter/run.sh b/dm/tests/expression_filter/run.sh index 732a005bd4a..52136ac5fc5 100755 --- a/dm/tests/expression_filter/run.sh +++ b/dm/tests/expression_filter/run.sh @@ -65,10 +65,15 @@ function complex_behaviour() { run_sql_tidb "select count(10) from expr_filter.t6 where name != 'Müller'" check_contains "count(10): 0" + run_sql_tidb "select count(11) from expr_filter.t7 where r = 'a' and s = 2" + check_contains "count(11): 1" + run_sql_tidb "select count(12) from expr_filter.t7 where r = 'a' and s = 3" + check_contains "count(12): 0" + insert_num=$(grep -o '"number of filtered insert"=[0-9]\+' $WORK_DIR/worker1/log/dm-worker.log | grep -o '[0-9]\+' | awk '{n += $1}; END{print n}') [ $insert_num -eq 6 ] update_num=$(grep -o '"number of filtered update"=[0-9]\+' $WORK_DIR/worker1/log/dm-worker.log | grep -o '[0-9]\+' | awk '{n += $1}; END{print n}') - [ $update_num -eq 3 ] + [ $update_num -eq 4 ] run_dm_ctl $WORK_DIR "127.0.0.1:$MASTER_PORT" \ "stop-task test"