From 84fa999186d37873b468660187ce4407894831c2 Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 24 Nov 2022 14:33:58 +0800 Subject: [PATCH] txn: support multi-table join in nt-dml (#39139) ref pingcap/tidb#33485 --- metrics/telemetry.go | 3 + parser/ast/dml.go | 25 ++- session/nontransactional.go | 263 +++++++++++++++++++-------- session/nontransactional_test.go | 76 +++++++- telemetry/data_feature_usage_test.go | 6 + 5 files changed, 285 insertions(+), 88 deletions(-) diff --git a/metrics/telemetry.go b/metrics/telemetry.go index e7629bcd76f6a..486a2d43a0c59 100644 --- a/metrics/telemetry.go +++ b/metrics/telemetry.go @@ -332,6 +332,7 @@ func GetTablePartitionCounter() TablePartitionUsageCounter { // NonTransactionalStmtCounter records the usages of non-transactional statements. type NonTransactionalStmtCounter struct { DeleteCount int64 `json:"delete"` + UpdateCount int64 `json:"update"` InsertCount int64 `json:"insert"` } @@ -339,6 +340,7 @@ type NonTransactionalStmtCounter struct { func (n NonTransactionalStmtCounter) Sub(rhs NonTransactionalStmtCounter) NonTransactionalStmtCounter { return NonTransactionalStmtCounter{ DeleteCount: n.DeleteCount - rhs.DeleteCount, + UpdateCount: n.UpdateCount - rhs.UpdateCount, InsertCount: n.InsertCount - rhs.InsertCount, } } @@ -347,6 +349,7 @@ func (n NonTransactionalStmtCounter) Sub(rhs NonTransactionalStmtCounter) NonTra func GetNonTransactionalStmtCounter() NonTransactionalStmtCounter { return NonTransactionalStmtCounter{ DeleteCount: readCounter(NonTransactionalDMLCount.With(prometheus.Labels{LblType: "delete"})), + UpdateCount: readCounter(NonTransactionalDMLCount.With(prometheus.Labels{LblType: "update"})), InsertCount: readCounter(NonTransactionalDMLCount.With(prometheus.Labels{LblType: "insert"})), } } diff --git a/parser/ast/dml.go b/parser/ast/dml.go index 2712a8f7eba51..c711da90d123f 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -2221,8 +2221,8 @@ func (n *InsertStmt) SetWhereExpr(e ExprNode) { s.Where = e } -// TableSource implements ShardableDMLStmt interface. -func (n *InsertStmt) TableSource() (*TableSource, bool) { +// TableRefsJoin implements ShardableDMLStmt interface. +func (n *InsertStmt) TableRefsJoin() (*Join, bool) { if n.Select == nil { return nil, false } @@ -2230,8 +2230,7 @@ func (n *InsertStmt) TableSource() (*TableSource, bool) { if !ok { return nil, false } - table, ok := s.From.TableRefs.Left.(*TableSource) - return table, ok + return s.From.TableRefs, true } // DeleteStmt is a statement to delete rows from table. @@ -2410,10 +2409,9 @@ func (n *DeleteStmt) SetWhereExpr(e ExprNode) { n.Where = e } -// TableSource implements ShardableDMLStmt interface. -func (n *DeleteStmt) TableSource() (*TableSource, bool) { - table, ok := n.TableRefs.TableRefs.Left.(*TableSource) - return table, ok +// TableRefsJoin implements ShardableDMLStmt interface. +func (n *DeleteStmt) TableRefsJoin() (*Join, bool) { + return n.TableRefs.TableRefs, true } const ( @@ -2426,8 +2424,8 @@ type ShardableDMLStmt = interface { StmtNode WhereExpr() ExprNode SetWhereExpr(ExprNode) - // TableSource returns the *only* target table source in the statement. - TableSource() (table *TableSource, ok bool) + // TableRefsJoin returns the table refs in the statement. + TableRefsJoin() (refs *Join, ok bool) } var _ ShardableDMLStmt = &DeleteStmt{} @@ -2649,10 +2647,9 @@ func (n *UpdateStmt) SetWhereExpr(e ExprNode) { n.Where = e } -// TableSource implements ShardableDMLStmt interface. -func (n *UpdateStmt) TableSource() (*TableSource, bool) { - table, ok := n.TableRefs.TableRefs.Left.(*TableSource) - return table, ok +// TableRefsJoin implements ShardableDMLStmt interface. +func (n *UpdateStmt) TableRefsJoin() (*Join, bool) { + return n.TableRefs.TableRefs, true } // Limit is the limit clause. diff --git a/session/nontransactional.go b/session/nontransactional.go index 779b2b7042094..d6fbd8e8bd9fa 100644 --- a/session/nontransactional.go +++ b/session/nontransactional.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" driver "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" @@ -52,6 +53,7 @@ var ErrNonTransactionalJobFailure = dbterror.ClassSession.NewStd(errno.ErrNonTra var ( nonTransactionalDeleteCount = metrics.NonTransactionalDMLCount.With(prometheus.Labels{metrics.LblType: "delete"}) nonTransactionalInsertCount = metrics.NonTransactionalDMLCount.With(prometheus.Labels{metrics.LblType: "insert"}) + nonTransactionalUpdateCount = metrics.NonTransactionalDMLCount.With(prometheus.Labels{metrics.LblType: "update"}) ) // job: handle keys in [start, end] @@ -95,10 +97,16 @@ func HandleNonTransactionalDML(ctx context.Context, stmt *ast.NonTransactionalDM if err := checkConstraint(stmt, se); err != nil { return nil, err } + tableName, selectSQL, shardColumnInfo, err := buildSelectSQL(stmt, se) if err != nil { return nil, err } + + if err := checkConstraintWithShardColumn(stmt, tableName, shardColumnInfo); err != nil { + return nil, err + } + if stmt.DryRun == ast.DryRunQuery { return buildDryRunResults(stmt.DryRun, []string{selectSQL}, se.GetSessionVars().BatchSize.MaxChunkSize) } @@ -124,6 +132,23 @@ func HandleNonTransactionalDML(ctx context.Context, stmt *ast.NonTransactionalDM return buildExecuteResults(ctx, jobs, se.GetSessionVars().BatchSize.MaxChunkSize, se.GetSessionVars().EnableRedactLog) } +func checkConstraintWithShardColumn(stmt *ast.NonTransactionalDMLStmt, tableName *ast.TableName, shardColumnInfo *model.ColumnInfo) error { + switch s := stmt.DMLStmt.(type) { + case *ast.UpdateStmt: + // FIXME: this check is not enough. the table name and schema name of the assignment can be null. But we cannot + // simply rely on the column name to judge it. + for _, assignment := range s.List { + if shardColumnInfo != nil && assignment.Column.Name.L == shardColumnInfo.Name.L && + assignment.Column.Table.L == tableName.Name.L && + assignment.Column.Schema.L == tableName.Schema.L { + return errors.New("Non-transactional DML, shard columns cannot be updated") + } + } + default: + } + return nil +} + func checkConstraint(stmt *ast.NonTransactionalDMLStmt, se Session) error { sessVars := se.GetSessionVars() if !(sessVars.IsAutocommit() && !sessVars.InTxn()) { @@ -143,7 +168,7 @@ func checkConstraint(stmt *ast.NonTransactionalDMLStmt, se Session) error { switch s := stmt.DMLStmt.(type) { case *ast.DeleteStmt: - if err := checkTableRef(s.TableRefs); err != nil { + if err := checkTableRef(s.TableRefs, true); err != nil { return err } if err := checkReadClauses(s.Limit, s.Order); err != nil { @@ -151,11 +176,13 @@ func checkConstraint(stmt *ast.NonTransactionalDMLStmt, se Session) error { } nonTransactionalDeleteCount.Inc() case *ast.UpdateStmt: - // TODO: check: (1) single target table (2) more... - if s.Limit != nil { - return errors.New("Non-transactional update doesn't support limit") + if err := checkTableRef(s.TableRefs, true); err != nil { + return err } - // TODO: metrics + if err := checkReadClauses(s.Limit, s.Order); err != nil { + return err + } + nonTransactionalUpdateCount.Inc() case *ast.InsertStmt: if s.Select == nil { return errors.New("Non-transactional insert supports insert select stmt only") @@ -164,31 +191,12 @@ func checkConstraint(stmt *ast.NonTransactionalDMLStmt, se Session) error { if !ok { return errors.New("Non-transactional insert doesn't support non-select source") } - if err := checkTableRef(selectStmt.From); err != nil { + if err := checkTableRef(selectStmt.From, true); err != nil { return err } if err := checkReadClauses(selectStmt.Limit, selectStmt.OrderBy); err != nil { return err } - sourceTable, ok := selectStmt.From.TableRefs.Left.(*ast.TableSource) - if !ok { - return errors.New("Non-transactional insert must have a source table") - } - sourceName, ok := sourceTable.Source.(*ast.TableName) - if !ok { - return errors.New("Non-transaction insert must have s source table") - } - targetTable, ok := s.Table.TableRefs.Left.(*ast.TableSource) - if !ok { - return errors.New("Non-transactional insert must have a target table") - } - targetName, ok := targetTable.Source.(*ast.TableName) - if !ok { - return errors.New("Non-transactional insert must have a target table") - } - if sourceName.Name.L == targetName.Name.L { - return errors.New("Non-transactional insert doesn't support self-insert") - } nonTransactionalInsertCount.Inc() default: return errors.New("Unsupported DML type for non-transactional DML") @@ -197,11 +205,11 @@ func checkConstraint(stmt *ast.NonTransactionalDMLStmt, se Session) error { return nil } -func checkTableRef(t *ast.TableRefsClause) error { +func checkTableRef(t *ast.TableRefsClause, allowMultipleTables bool) error { if t == nil || t.TableRefs == nil || t.TableRefs.Left == nil { return errors.New("table reference is nil") } - if t.TableRefs.Right != nil { + if !allowMultipleTables && t.TableRefs.Right != nil { return errors.New("Non-transactional statements don't support multiple tables") } return nil @@ -503,23 +511,28 @@ func appendNewJob(jobs []job, id int, start types.Datum, end types.Datum, size i func buildSelectSQL(stmt *ast.NonTransactionalDMLStmt, se Session) (*ast.TableName, string, *model.ColumnInfo, error) { // only use the first table - tableSource, ok := stmt.DMLStmt.TableSource() + join, ok := stmt.DMLStmt.TableRefsJoin() if !ok { return nil, "", nil, errors.New("Non-transactional DML, table source not found") } - tableName, ok := tableSource.Source.(*ast.TableName) + tableSources := make([]*ast.TableSource, 0) + tableSources, err := collectTableSourcesInJoin(join, tableSources) + if err != nil { + return nil, "", nil, err + } + if len(tableSources) == 0 { + return nil, "", nil, errors.New("Non-transactional DML, no tables found in table refs") + } + leftMostTableSource := tableSources[0] + leftMostTableName, ok := leftMostTableSource.Source.(*ast.TableName) if !ok { return nil, "", nil, errors.New("Non-transactional DML, table name not found") } - // the shard column must be indexed - indexed, shardColumnInfo, err := selectShardColumn(stmt, se, tableName, tableSource.AsName) + shardColumnInfo, tableName, err := selectShardColumn(stmt, se, tableSources, leftMostTableName, leftMostTableSource) if err != nil { return nil, "", nil, err } - if !indexed { - return nil, "", nil, errors.Errorf("Non-transactional DML, shard column %s is not indexed", stmt.ShardColumn.Name.L) - } var sb strings.Builder if stmt.DMLStmt.WhereExpr() != nil { @@ -541,54 +554,122 @@ func buildSelectSQL(stmt *ast.NonTransactionalDMLStmt, se Session) (*ast.TableNa return tableName, selectSQL, shardColumnInfo, nil } -// it attempts to auto-select a shard column from handle if not specified, and fills back the corresponding info in the stmt, -// making it transparent to following steps -func selectShardColumn(stmt *ast.NonTransactionalDMLStmt, se Session, tableName *ast.TableName, tableAsName model.CIStr) (indexed bool, shardColumnInfo *model.ColumnInfo, err error) { - tbl, err := domain.GetDomain(se).InfoSchema().TableByName(tableName.Schema, tableName.Name) - if err != nil { - return false, nil, err - } - tableInfo := tbl.Meta() +func selectShardColumn(stmt *ast.NonTransactionalDMLStmt, se Session, tableSources []*ast.TableSource, + leftMostTableName *ast.TableName, leftMostTableSource *ast.TableSource) ( + *model.ColumnInfo, *ast.TableName, error) { + var indexed bool + var shardColumnInfo *model.ColumnInfo + var selectedTableName *ast.TableName - var shardColumnName string - if stmt.ShardColumn == nil { - // auto-detect shard column - if tbl.Meta().PKIsHandle { - shardColumnInfo = tableInfo.GetPkColInfo() - } else if tableInfo.IsCommonHandle { - for _, index := range tableInfo.Indices { - if index.Primary { - if len(index.Columns) == 1 { - shardColumnInfo = tableInfo.Columns[index.Columns[0].Offset] - break - } - // if the clustered index contains multiple columns, we cannot automatically choose a column as the shard column - return false, nil, errors.New("Non-transactional DML, the clustered index contains multiple columns. Please specify a shard column") + if len(tableSources) == 1 { + // single table + leftMostTable, err := domain.GetDomain(se).InfoSchema().TableByName(leftMostTableName.Schema, leftMostTableName.Name) + if err != nil { + return nil, nil, err + } + selectedTableName = leftMostTableName + indexed, shardColumnInfo, err = selectShardColumnFromTheOnlyTable( + stmt, leftMostTableName, leftMostTableSource.AsName, leftMostTable) + if err != nil { + return nil, nil, err + } + } else { + // multi table join + if stmt.ShardColumn == nil { + leftMostTable, err := domain.GetDomain(se).InfoSchema().TableByName(leftMostTableName.Schema, leftMostTableName.Name) + if err != nil { + return nil, nil, err + } + selectedTableName = leftMostTableName + indexed, shardColumnInfo, err = selectShardColumnAutomatically(stmt, leftMostTable, leftMostTableName, leftMostTableSource.AsName) + if err != nil { + return nil, nil, err + } + } else if stmt.ShardColumn.Schema.L != "" && stmt.ShardColumn.Table.L != "" && stmt.ShardColumn.Name.L != "" { + dbName := stmt.ShardColumn.Schema + tableName := stmt.ShardColumn.Table + colName := stmt.ShardColumn.Name + + // the specified table must be in the join + tableInJoin := false + for _, tableSource := range tableSources { + tableSourceName := tableSource.Source.(*ast.TableName) + if tableSourceName.Schema.L == dbName.L && tableSourceName.Name.L == tableName.L { + tableInJoin = true + selectedTableName = tableSourceName + break } } - if shardColumnInfo == nil { - return false, nil, errors.New("Non-transactional DML, the clustered index is not found") + if !tableInJoin { + return nil, nil, + errors.Errorf( + "Non-transactional DML, shard column %s.%s.%s is not in the tables involved in the join", + dbName.L, tableName.L, colName.L, + ) } - } - shardColumnName := model.ExtraHandleName.L - if shardColumnInfo != nil { - shardColumnName = shardColumnInfo.Name.L + tbl, err := domain.GetDomain(se).InfoSchema().TableByName(dbName, tableName) + if err != nil { + return nil, nil, err + } + indexed, shardColumnInfo, err = selectShardColumnByGivenName(colName.L, tbl) + if err != nil { + return nil, nil, err + } + } else { + return nil, nil, errors.New( + "Non-transactional DML, shard column must be fully specified (dbname.tablename.colname) when multiple tables are involved", + ) } + } + if !indexed { + return nil, nil, errors.Errorf("Non-transactional DML, shard column %s is not indexed", stmt.ShardColumn.Name.L) + } + return shardColumnInfo, selectedTableName, nil +} - outputTableName := tableName.Name - if tableAsName.L != "" { - outputTableName = tableAsName +func collectTableSourcesInJoin(node ast.ResultSetNode, tableSources []*ast.TableSource) ([]*ast.TableSource, error) { + if node == nil { + return tableSources, nil + } + switch x := node.(type) { + case *ast.Join: + var err error + tableSources, err = collectTableSourcesInJoin(x.Left, tableSources) + if err != nil { + return nil, err + } + tableSources, err = collectTableSourcesInJoin(x.Right, tableSources) + if err != nil { + return nil, err } - stmt.ShardColumn = &ast.ColumnName{ - Schema: tableName.Schema, - Table: outputTableName, // so that table alias works - Name: model.NewCIStr(shardColumnName), + case *ast.TableSource: + // assert it's a table name + if _, ok := x.Source.(*ast.TableName); !ok { + return nil, errors.New("Non-transactional DML, table name not found in join") } - return true, shardColumnInfo, nil + tableSources = append(tableSources, x) + default: + return nil, errors.Errorf("Non-transactional DML, unknown type %T in table refs", node) + } + return tableSources, nil +} + +// it attempts to auto-select a shard column from handle if not specified, and fills back the corresponding info in the stmt, +// making it transparent to following steps +func selectShardColumnFromTheOnlyTable(stmt *ast.NonTransactionalDMLStmt, tableName *ast.TableName, + tableAsName model.CIStr, tbl table.Table) ( + indexed bool, shardColumnInfo *model.ColumnInfo, err error) { + if stmt.ShardColumn == nil { + return selectShardColumnAutomatically(stmt, tbl, tableName, tableAsName) } - shardColumnName = stmt.ShardColumn.Name.L + return selectShardColumnByGivenName(stmt.ShardColumn.Name.L, tbl) +} + +func selectShardColumnByGivenName(shardColumnName string, tbl table.Table) ( + indexed bool, shardColumnInfo *model.ColumnInfo, err error) { + tableInfo := tbl.Meta() if shardColumnName == model.ExtraHandleName.L && !tableInfo.HasClusteredIndex() { return true, nil, nil } @@ -621,6 +702,46 @@ func selectShardColumn(stmt *ast.NonTransactionalDMLStmt, se Session, tableName return indexed, shardColumnInfo, nil } +func selectShardColumnAutomatically(stmt *ast.NonTransactionalDMLStmt, tbl table.Table, + tableName *ast.TableName, tableAsName model.CIStr) (bool, *model.ColumnInfo, error) { + // auto-detect shard column + var shardColumnInfo *model.ColumnInfo + tableInfo := tbl.Meta() + if tbl.Meta().PKIsHandle { + shardColumnInfo = tableInfo.GetPkColInfo() + } else if tableInfo.IsCommonHandle { + for _, index := range tableInfo.Indices { + if index.Primary { + if len(index.Columns) == 1 { + shardColumnInfo = tableInfo.Columns[index.Columns[0].Offset] + break + } + // if the clustered index contains multiple columns, we cannot automatically choose a column as the shard column + return false, nil, errors.New("Non-transactional DML, the clustered index contains multiple columns. Please specify a shard column") + } + } + if shardColumnInfo == nil { + return false, nil, errors.New("Non-transactional DML, the clustered index is not found") + } + } + + shardColumnName := model.ExtraHandleName.L + if shardColumnInfo != nil { + shardColumnName = shardColumnInfo.Name.L + } + + outputTableName := tableName.Name + if tableAsName.L != "" { + outputTableName = tableAsName + } + stmt.ShardColumn = &ast.ColumnName{ + Schema: tableName.Schema, + Table: outputTableName, // so that table alias works + Name: model.NewCIStr(shardColumnName), + } + return true, shardColumnInfo, nil +} + func buildDryRunResults(dryRunOption int, results []string, maxChunkSize int) (sqlexec.RecordSet, error) { var fieldName string if dryRunOption == ast.DryRunSplitDml { diff --git a/session/nontransactional_test.go b/session/nontransactional_test.go index 3281cd71bd4da..9ae7188cfa7a7 100644 --- a/session/nontransactional_test.go +++ b/session/nontransactional_test.go @@ -722,9 +722,7 @@ func TestNonTransactionalWithCheckConstraint(t *testing.T) { err = tk.ExecToErr("batch limit 1 insert into t select 1, 1") require.EqualError(t, err, "table reference is nil") err = tk.ExecToErr("batch limit 1 insert into t select * from (select 1, 2) tmp") - require.EqualError(t, err, "Non-transaction insert must have s source table") - err = tk.ExecToErr("batch limit 1 insert into t select * from t") - require.EqualError(t, err, "Non-transactional insert doesn't support self-insert") + require.EqualError(t, err, "Non-transactional DML, table name not found in join") } func TestNonTransactionalWithOptimizerHints(t *testing.T) { @@ -879,3 +877,75 @@ func TestNonTransactionalWithShardOnUnsupportedTypes(t *testing.T) { require.Error(t, err) tk.MustQuery("select count(*) from t2").Check(testkit.Rows("1")) } + +func TestNonTransactionalWithJoin(t *testing.T) { + // insert + // BATCH ON t2.id LIMIT 10000 insert into t1 select id, name from t2 inner join t3 on t2.id = t3.id; + // insert into t1 select id, name from t2 inner join t3 on t2.id = t3.id where t2.id between 1 and 10000 + + // update + // BATCH ON t1.id LIMIT 10000 update t1 join t2 on t1.a=t2.a set t1.a=t1.a*1000; + // update t1 join t2 on t1.a=t2.a set t1.a=t1.a*1000 where t1.id between 1 and 10000; + + // delete + // BATCH ON pa.id LIMIT 10000 DELETE pa + // FROM pets_activities pa JOIN pets p ON pa.id = p.pet_id + // WHERE p.order > :order AND p.pet_id = :pet_id + // delete pa + // from pets_activities pa join pets p on pa.id = p.pet_id + // WHERE p.order > :order AND p.pet_id = :pet_id and pa.id between 1 and 10000; + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(id int, v1 int, v2 int, key (id))") + tk.MustExec("create table t2(id int, v int, key i1(id))") + tk.MustExec("create table t3(id int, v int, key i1(id))") + tk.MustExec("insert into t2 values (1, 2), (2, 3), (3, 4)") + tk.MustExec("insert into t3 values (1, 4), (2, 5), (4, 6)") + + tk.MustExec("batch on test.t2.id limit 1 insert into t select t2.id, t2.v, t3.v from t2 join t3 on t2.id=t3.id") + tk.MustQuery("select * from t").Check(testkit.Rows("1 2 4", "2 3 5")) + + tk.MustContainErrMsg( + "batch on id limit 1 insert into t select t2.id, t2.v, t3.v from t2 join t3 on t2.id=t3.id", + "Non-transactional DML, shard column must be fully specified", + ) + tk.MustContainErrMsg( + "batch on test.t1.id limit 1 insert into t select t2.id, t2.v, t3.v from t2 join t3 on t2.id=t3.id", + "shard column test.t1.id is not in the tables involved in the join", + ) + + tk.MustExec("batch on test.t2.id limit 1 update t2 join t3 on t2.id=t3.id set t2.v=t2.v*100, t3.v=t3.v*200") + tk.MustQuery("select * from t2").Check(testkit.Rows("1 200", "2 300", "3 4")) + tk.MustQuery("select * from t3").Check(testkit.Rows("1 800", "2 1000", "4 6")) + + tk.MustExec("batch limit 1 delete t2 from t2 join t3 on t2.id=t3.id") + tk.MustQuery("select * from t2").Check(testkit.Rows("3 4")) +} + +func TestAnomalousNontransactionalDML(t *testing.T) { + // some weird and error-prone behavior + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(id int, v int)") + + // self-insert, this is allowed but can be dangerous + tk.MustExec("insert into t values (1, 1)") + tk.MustExec("batch limit 1 insert into t select * from t") + tk.MustQuery("select * from t").Check(testkit.Rows("1 1", "1 1")) + tk.MustExec("drop table t") + + tk.MustExec("create table t(id int, v int, key(id))") + tk.MustExec("create table t2(id int, v int, key(id))") + tk.MustExec("insert into t values (1, 1), (2, 2), (3, 3)") + tk.MustExec("insert into t2 values (1, 1), (2, 2), (4, 4)") + + // FIXME: we should not allow this, where the shard column is the join key + tk.MustExec("batch on test.t.id limit 1 update t join t2 on t.id=t2.id set t2.id = t2.id+1") + tk.MustQuery("select * from t2").Check(testkit.Rows("4 1", "4 2", "4 4")) + + // FIXME: and this + tk.MustExec("batch on id limit 1 update t set id=id+1") + tk.MustQuery("select * from t").Check(testkit.Rows("4 1", "4 2", "4 3")) +} diff --git a/telemetry/data_feature_usage_test.go b/telemetry/data_feature_usage_test.go index 81b1f2e6a1f01..770f4a7d3d08e 100644 --- a/telemetry/data_feature_usage_test.go +++ b/telemetry/data_feature_usage_test.go @@ -373,12 +373,18 @@ func TestNonTransactionalUsage(t *testing.T) { usage, err := telemetry.GetFeatureUsage(tk.Session()) require.NoError(t, err) require.Equal(t, int64(0), usage.NonTransactionalUsage.DeleteCount) + require.Equal(t, int64(0), usage.NonTransactionalUsage.UpdateCount) + require.Equal(t, int64(0), usage.NonTransactionalUsage.InsertCount) tk.MustExec("create table t(a int);") tk.MustExec("batch limit 1 delete from t") + tk.MustExec("batch limit 1 update t set a = 1") + tk.MustExec("batch limit 1 insert into t select * from t") usage, err = telemetry.GetFeatureUsage(tk.Session()) require.NoError(t, err) require.Equal(t, int64(1), usage.NonTransactionalUsage.DeleteCount) + require.Equal(t, int64(1), usage.NonTransactionalUsage.UpdateCount) + require.Equal(t, int64(1), usage.NonTransactionalUsage.InsertCount) } func TestGlobalKillUsageInfo(t *testing.T) {