diff --git a/cdc/sink/mysql/mysql.go b/cdc/sink/mysql/mysql.go index a5d1f3a4598..c1aab3beeb1 100644 --- a/cdc/sink/mysql/mysql.go +++ b/cdc/sink/mysql/mysql.go @@ -652,8 +652,6 @@ func (s *mysqlSink) execDMLWithMaxRetries(pctx context.Context, dmls *preparedDM return retry.Do(pctx, func() error { writeTimeout, _ := time.ParseDuration(s.params.writeTimeout) writeTimeout += networkDriftDuration - ctx, cancelFunc := context.WithTimeout(pctx, writeTimeout) - defer cancelFunc() failpoint.Inject("MySQLSinkTxnRandomError", func() { failpoint.Return( @@ -665,7 +663,7 @@ func (s *mysqlSink) execDMLWithMaxRetries(pctx context.Context, dmls *preparedDM time.Sleep(time.Hour) }) err := s.statistics.RecordBatchExecution(func() (int, error) { - tx, err := s.db.BeginTx(ctx, nil) + tx, err := s.db.BeginTx(pctx, nil) if err != nil { return 0, logDMLTxnErr( cerror.WrapError(cerror.ErrMySQLTxnError, err), @@ -675,6 +673,7 @@ func (s *mysqlSink) execDMLWithMaxRetries(pctx context.Context, dmls *preparedDM for i, query := range dmls.sqls { args := dmls.values[i] log.Debug("exec row", zap.String("sql", query), zap.Any("args", args)) + ctx, cancelFunc := context.WithTimeout(pctx, writeTimeout) if _, err := tx.ExecContext(ctx, query, args...); err != nil { if rbErr := tx.Rollback(); rbErr != nil { if errors.Cause(rbErr) != context.Canceled { @@ -684,10 +683,12 @@ func (s *mysqlSink) execDMLWithMaxRetries(pctx context.Context, dmls *preparedDM start, s.params.changefeedID, query, dmls.rowCount, dmls.startTs) } } + cancelFunc() return 0, logDMLTxnErr( cerror.WrapError(cerror.ErrMySQLTxnError, err), start, s.params.changefeedID, query, dmls.rowCount, dmls.startTs) } + cancelFunc() } if err = tx.Commit(); err != nil { diff --git a/cdc/sinkv2/eventsink/txn/mysql/mysql.go b/cdc/sinkv2/eventsink/txn/mysql/mysql.go index 4374a181a4d..8403d004b70 100644 --- a/cdc/sinkv2/eventsink/txn/mysql/mysql.go +++ b/cdc/sinkv2/eventsink/txn/mysql/mysql.go @@ -557,8 +557,6 @@ func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *prepare return retry.Do(pctx, func() error { writeTimeout, _ := time.ParseDuration(s.cfg.WriteTimeout) writeTimeout += networkDriftDuration - ctx, cancelFunc := context.WithTimeout(pctx, writeTimeout) - defer cancelFunc() failpoint.Inject("MySQLSinkTxnRandomError", func() { fmt.Printf("start to random error") @@ -566,19 +564,11 @@ func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *prepare failpoint.Return(err) }) failpoint.Inject("MySQLSinkHangLongTime", func() { - timer := time.NewTimer(time.Hour) - select { - case <-timer.C: - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - failpoint.Return(context.Canceled) - } + time.Sleep(time.Hour) }) err := s.statistics.RecordBatchExecution(func() (int, error) { - tx, err := s.db.BeginTx(ctx, nil) + tx, err := s.db.BeginTx(pctx, nil) if err != nil { return 0, logDMLTxnErr( cerror.WrapError(cerror.ErrMySQLTxnError, err), @@ -589,6 +579,7 @@ func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *prepare args := dmls.values[i] log.Debug("exec row", zap.Int("workerID", s.workerID), zap.String("sql", query), zap.Any("args", args)) + ctx, cancelFunc := context.WithTimeout(pctx, writeTimeout) if _, err := tx.ExecContext(ctx, query, args...); err != nil { err := logDMLTxnErr( cerror.WrapError(cerror.ErrMySQLTxnError, err), @@ -598,13 +589,15 @@ func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *prepare log.Warn("failed to rollback txn", zap.Error(rbErr)) } } + cancelFunc() return 0, err } + cancelFunc() } // we set write source for each txn, // so we can use it to trace the data source - if err = s.setWriteSource(ctx, tx); err != nil { + if err = s.setWriteSource(pctx, tx); err != nil { err := logDMLTxnErr( cerror.WrapError(cerror.ErrMySQLTxnError, err), start, s.changefeed, diff --git a/cdc/sinkv2/eventsink/txn/mysql/mysql_test.go b/cdc/sinkv2/eventsink/txn/mysql/mysql_test.go index 1f3a86065ce..b16ea5b02f3 100644 --- a/cdc/sinkv2/eventsink/txn/mysql/mysql_test.go +++ b/cdc/sinkv2/eventsink/txn/mysql/mysql_test.go @@ -27,7 +27,6 @@ import ( "github.com/DATA-DOG/go-sqlmock" dmysql "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" - "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/parser/mysql" @@ -1447,17 +1446,6 @@ func TestPrepareBatchDMLs(t *testing.T) { } } -func TestNetworkPartition(t *testing.T) { - ctx := context.Background() - ms := newMySQLBackendWithoutDB(ctx) - ms.cfg.WriteTimeout = "1s" - _ = failpoint.Enable("github.com/pingcap/tiflow/cdc/sinkv2/eventsink/txn/mysql/MySQLSinkHangLongTime", "return") - defer failpoint.Disable("github.com/pingcap/tiflow/cdc/sinkv2/eventsink/txn/mysql/MySQLSinkHangLongTime") - - err := ms.execDMLWithMaxRetries(ctx, &preparedDMLs{}) - require.Equal(t, context.Canceled, err) -} - func TestGroupRowsByType(t *testing.T) { ctx := context.Background() ms := newMySQLBackendWithoutDB(ctx)