diff --git a/bindinfo/session_handle_test.go b/bindinfo/session_handle_test.go index a9b05c03eda18..f38ab0696bc97 100644 --- a/bindinfo/session_handle_test.go +++ b/bindinfo/session_handle_test.go @@ -16,6 +16,7 @@ package bindinfo_test import ( "context" + "fmt" "strconv" "testing" "time" @@ -459,6 +460,29 @@ func TestDropSingleBindings(t *testing.T) { require.Len(t, rows, 0) } +func TestIssue53834(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t (a varchar(1024))`) + tk.MustExec(`insert into t values (space(1024))`) + for i := 0; i < 12; i++ { + tk.MustExec(`insert into t select * from t`) + } + oomAction := tk.MustQuery(`select @@tidb_mem_oom_action`).Rows()[0][0].(string) + defer func() { + tk.MustExec(fmt.Sprintf(`set global tidb_mem_oom_action='%v'`, oomAction)) + }() + + tk.MustExec(`set global tidb_mem_oom_action='cancel'`) + err := tk.ExecToErr(`replace into t select /*+ memory_quota(1 mb) */ * from t`) + require.ErrorContains(t, err, "cancelled due to exceeding the allowed memory limit") + + tk.MustExec(`create binding for replace into t select * from t using replace into t select /*+ memory_quota(1 mb) */ * from t`) + err = tk.ExecToErr(`replace into t select * from t`) + require.ErrorContains(t, err, "cancelled due to exceeding the allowed memory limit") +} + func TestPreparedStmt(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/planner/optimize.go b/planner/optimize.go index fbef265e32aa6..b1af46e2e73a1 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -206,7 +206,7 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in } metrics.BindUsageCounter.WithLabelValues(scope).Inc() hint.BindHint(stmtNode, binding.Hint) - curStmtHints, _, curWarns := handleStmtHints(binding.Hint.GetFirstTableHints()) + curStmtHints, _, curWarns := handleStmtHints(binding.Hint.GetStmtHints()) sessVars.StmtCtx.StmtHints = curStmtHints // update session var by hint /set_var/ for name, val := range sessVars.StmtCtx.StmtHints.SetVars { diff --git a/session/session_test/session_test.go b/session/session_test/session_test.go index 4e8fb87069529..42e27758e6ea6 100644 --- a/session/session_test/session_test.go +++ b/session/session_test/session_test.go @@ -2185,11 +2185,11 @@ func TestStmtHints(t *testing.T) { val = int64(1) * 1024 * 1024 require.True(t, tk.Session().GetSessionVars().MemTracker.CheckBytesLimit(val)) - tk.MustExec("insert /*+ MEMORY_QUOTA(1 MB) */ into t1 select /*+ MEMORY_QUOTA(3 MB) */ * from t1;") + tk.MustExec("insert /*+ MEMORY_QUOTA(1 MB) */ into t1 select /*+ MEMORY_QUOTA(1 MB) */ * from t1;") val = int64(1) * 1024 * 1024 require.True(t, tk.Session().GetSessionVars().MemTracker.CheckBytesLimit(val)) - require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 1) - require.EqualError(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings()[0].Err, "[util:3126]Hint MEMORY_QUOTA(`3145728`) is ignored as conflicting/duplicated.") + require.Len(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings(), 2) + require.EqualError(t, tk.Session().GetSessionVars().StmtCtx.GetWarnings()[0].Err, "[util:3126]Hint MEMORY_QUOTA(`1048576`) is ignored as conflicting/duplicated.") // Test NO_INDEX_MERGE hint tk.Session().GetSessionVars().SetEnableIndexMerge(true) diff --git a/util/hint/hint_processor.go b/util/hint/hint_processor.go index 6051f01b2489c..316cab025726c 100644 --- a/util/hint/hint_processor.go +++ b/util/hint/hint_processor.go @@ -44,12 +44,20 @@ type HintsSet struct { indexHints [][]*ast.IndexHint // Slice offset is the traversal order of `TableName` in the ast. } -// GetFirstTableHints gets the first table hints. -func (hs *HintsSet) GetFirstTableHints() []*ast.TableOptimizerHint { +// GetStmtHints gets all statement-level hints. +func (hs *HintsSet) GetStmtHints() []*ast.TableOptimizerHint { + var result []*ast.TableOptimizerHint if len(hs.tableHints) > 0 { - return hs.tableHints[0] + result = append(result, hs.tableHints[0]...) // keep the same behavior with prior implementation } - return nil + for _, tHints := range hs.tableHints[1:] { + for _, h := range tHints { + if isStmtHint(h) { + result = append(result, h) + } + } + } + return result } // ContainTableHint checks whether the table hint set contains a hint. @@ -88,7 +96,18 @@ func ExtractTableHintsFromStmtNode(node ast.Node, sctx sessionctx.Context) []*as case *ast.InsertStmt: // check duplicated hints checkInsertStmtHintDuplicated(node, sctx) - return x.TableHints + result := make([]*ast.TableOptimizerHint, 0, len(x.TableHints)) + result = append(result, x.TableHints...) + if x.Select != nil { + // support statement-level hint in sub-select: "insert into t select /* ... */ ..." + // TODO: support this for Update and Delete as well + for _, h := range ExtractTableHintsFromStmtNode(x.Select, sctx) { + if isStmtHint(h) { + result = append(result, h) + } + } + } + return result case *ast.ExplainStmt: return ExtractTableHintsFromStmtNode(x.Stmt, sctx) case *ast.SetOprStmt: @@ -634,3 +653,13 @@ func GenerateQBName(nodeType NodeType, blockOffset int) (model.CIStr, error) { } return model.NewCIStr(fmt.Sprintf("%s%d", defaultSelectBlockPrefix, blockOffset)), nil } + +// isStmtHint checks whether this hint is a statement-level hint. +func isStmtHint(h *ast.TableOptimizerHint) bool { + switch h.HintName.L { + case "max_execution_time", "memory_quota", "resource_group": + return true + default: + return false + } +}