diff --git a/expression/expression.go b/expression/expression.go index 30897624b82f3..360c545eff87f 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -1567,6 +1567,8 @@ func Args2Expressions4Test(args ...interface{}) []Expression { ft = types.NewFieldType(mysql.TypeVarString) case types.KindMysqlTime: ft = types.NewFieldType(mysql.TypeTimestamp) + case types.KindBytes: + ft = types.NewFieldType(mysql.TypeBlob) default: exprs[i] = nil continue diff --git a/parser/mysql/errname.go b/parser/mysql/errname.go index 5e49360ff3104..bffeae064e848 100644 --- a/parser/mysql/errname.go +++ b/parser/mysql/errname.go @@ -269,7 +269,7 @@ var MySQLErrName = map[uint16]*ErrMessage{ ErrKeyRefDoNotMatchTableRef: Message("Key reference and table reference don't match", nil), ErrOperandColumns: Message("Operand should contain %d column(s)", nil), ErrSubqueryNo1Row: Message("Subquery returns more than 1 row", nil), - ErrUnknownStmtHandler: Message("Unknown prepared statement handler (%.*s) given to %s", nil), + ErrUnknownStmtHandler: Message("Unknown prepared statement handler %s given to %s", nil), ErrCorruptHelpDB: Message("Help database is corrupt or does not exist", nil), ErrCyclicReference: Message("Cyclic reference on subqueries", nil), ErrAutoConvert: Message("Converting column '%s' from %s to %s", nil), diff --git a/server/conn.go b/server/conn.go index 8d13186872a6b..83758683806fb 100644 --- a/server/conn.go +++ b/server/conn.go @@ -2277,7 +2277,12 @@ func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary b cc.initResultEncoder(ctx) defer cc.rsEncoder.clean() if mysql.HasCursorExistsFlag(serverStatus) { - if err := cc.writeChunksWithFetchSize(ctx, rs, serverStatus, fetchSize); err != nil { + crs, ok := rs.(cursorResultSet) + if !ok { + // this branch is actually unreachable + return false, errors.New("this cursor is not a resultSet") + } + if err := cc.writeChunksWithFetchSize(ctx, crs, serverStatus, fetchSize); err != nil { return 
false, err } return false, cc.flush(ctx) @@ -2411,43 +2416,27 @@ func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool // binary specifies the way to dump data. It throws any error while dumping data. // serverStatus, a flag bit represents server information. // fetchSize, the desired number of rows to be fetched each time when client uses cursor. -func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet, serverStatus uint16, fetchSize int) error { - fetchedRows := rs.GetFetchedRows() - - // tell the client COM_STMT_FETCH has finished by setting proper serverStatus, - // and close ResultSet. - if len(fetchedRows) == 0 { - serverStatus &^= mysql.ServerStatusCursorExists - serverStatus |= mysql.ServerStatusLastRowSend - return cc.writeEOF(ctx, serverStatus) - } - - // construct the rows sent to the client according to fetchSize. - var curRows []chunk.Row - if fetchSize < len(fetchedRows) { - curRows = fetchedRows[:fetchSize] - fetchedRows = fetchedRows[fetchSize:] - } else { - curRows = fetchedRows - fetchedRows = fetchedRows[:0] - } - rs.StoreFetchedRows(fetchedRows) - +func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs cursorResultSet, serverStatus uint16, fetchSize int) error { + var ( + stmtDetail *execdetails.StmtExecDetails + err error + start time.Time + ) data := cc.alloc.AllocWithLen(4, 1024) - var stmtDetail *execdetails.StmtExecDetails stmtDetailRaw := ctx.Value(execdetails.StmtExecDetailKey) if stmtDetailRaw != nil { //nolint:forcetypeassert stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) } - var ( - err error - start time.Time - ) if stmtDetail != nil { start = time.Now() } - for _, row := range curRows { + + iter := rs.GetRowContainerReader() + // send the rows to the client according to fetchSize. 
+ for i := 0; i < fetchSize && iter.Current() != iter.End(); i++ { + row := iter.Current() + data = data[0:4] data, err = dumpBinaryRow(data, rs.Columns(), row, cc.rsEncoder) if err != nil { @@ -2456,16 +2445,30 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet if err = cc.writePacket(data); err != nil { return err } + + iter.Next() + } + if iter.Error() != nil { + return iter.Error() + } + + // tell the client COM_STMT_FETCH has finished by setting proper serverStatus, + // and close ResultSet. + if iter.Current() == iter.End() { + serverStatus &^= mysql.ServerStatusCursorExists + serverStatus |= mysql.ServerStatusLastRowSend } + + // don't include the time consumed by `cl.OnFetchReturned()` in the `WriteSQLRespDuration` if stmtDetail != nil { stmtDetail.WriteSQLRespDuration += time.Since(start) } + if cl, ok := rs.(fetchNotifier); ok { cl.OnFetchReturned() } - if stmtDetail != nil { - start = time.Now() - } + + start = time.Now() err = cc.writeEOF(ctx, serverStatus) if stmtDetail != nil { stmtDetail.WriteSQLRespDuration += time.Since(start) diff --git a/server/conn_stmt.go b/server/conn_stmt.go index 145e83aee82b1..eb9a53ece3958 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -45,6 +45,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/ast" @@ -53,15 +54,19 @@ import ( "github.com/pingcap/tidb/parser/terror" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/sessiontxn" storeerr "github.com/pingcap/tidb/store/driver/error" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/hack" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/memory" 
"github.com/pingcap/tidb/util/topsql" topsqlstate "github.com/pingcap/tidb/util/topsql/state" "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" ) func (cc *clientConn) handleStmtPrepare(ctx context.Context, sql string) error { @@ -203,7 +208,11 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e } err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder) - stmt.Reset() + // This `.Reset` resets the arguments, so it's fine to just ignore the error (and the it'll be reset again in the following routine) + errReset := stmt.Reset() + if errReset != nil { + logutil.Logger(ctx).Warn("fail to reset statement in EXECUTE command", zap.Error(errReset)) + } if err != nil { return errors.Annotate(err, cc.preparedStmt2String(stmtID)) } @@ -265,6 +274,26 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm PrepStmt: prepStmt, } + // first, try to clear the left cursor if there is one + if useCursor && stmt.GetCursorActive() { + if stmt.GetResultSet() != nil && stmt.GetResultSet().GetRowContainerReader() != nil { + stmt.GetResultSet().GetRowContainerReader().Close() + } + if stmt.GetRowContainer() != nil { + stmt.GetRowContainer().GetMemTracker().Detach() + stmt.GetRowContainer().GetDiskTracker().Detach() + err := stmt.GetRowContainer().Close() + if err != nil { + logutil.Logger(ctx).Error( + "Fail to close rowContainer before executing statement. May cause resource leak", + zap.Error(err)) + } + stmt.StoreRowContainer(nil) + } + stmt.StoreResultSet(nil) + stmt.SetCursorActive(false) + } + // For the combination of `ComPrepare` and `ComExecute`, the statement name is stored in the client side, and the // TiDB only has the ID, so don't try to construct an `EXECUTE SOMETHING`. Use the original prepared statement here // instead. 
@@ -306,42 +335,83 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm // we should hold the ResultSet in PreparedStatement for next stmt_fetch, and only send back ColumnInfo. // Tell the client cursor exists in server by setting proper serverStatus. if useCursor { + crs := wrapWithCursor(rs) + cc.initResultEncoder(ctx) defer cc.rsEncoder.clean() // fetch all results of the resultSet, and stored them locally, so that the future `FETCH` command can read // the rows directly to avoid running executor and accessing shared params/variables in the session // NOTE: chunk should not be allocated from the connection allocator, which will reset after executing this command // but the rows are still needed in the following FETCH command. - // - // TODO: trace the memory used here - chk := rs.NewChunk(nil) - var rows []chunk.Row + + // create the row container to manage spill + // this `rowContainer` will be released when the statement (or the connection) is closed. 
+ rowContainer := chunk.NewRowContainer(crs.FieldTypes(), vars.MaxChunkSize) + rowContainer.GetMemTracker().AttachTo(vars.MemTracker) + rowContainer.GetMemTracker().SetLabel(memory.LabelForCursorFetch) + rowContainer.GetDiskTracker().AttachTo(vars.DiskTracker) + rowContainer.GetDiskTracker().SetLabel(memory.LabelForCursorFetch) + if variable.EnableTmpStorageOnOOM.Load() { + failpoint.Inject("testCursorFetchSpill", func(val failpoint.Value) { + if val, ok := val.(bool); val && ok { + actionSpill := rowContainer.ActionSpillForTest() + defer actionSpill.WaitForTest() + } + }) + action := memory.NewActionWithPriority(rowContainer.ActionSpill(), memory.DefCursorFetchSpillPriority) + vars.MemTracker.FallbackOldAndSetNewAction(action) + } + defer func() { + if err != nil { + rowContainer.GetMemTracker().Detach() + rowContainer.GetDiskTracker().Detach() + errCloseRowContainer := rowContainer.Close() + if errCloseRowContainer != nil { + logutil.Logger(ctx).Error("Fail to close rowContainer in error handler. 
May cause resource leak", + zap.NamedError("original-error", err), zap.NamedError("close-error", errCloseRowContainer)) + } + } + }() + for { - if err = rs.Next(ctx, chk); err != nil { + chk := crs.NewChunk(nil) + + if err = crs.Next(ctx, chk); err != nil { return false, err } rowCount := chk.NumRows() if rowCount == 0 { break } - // filling fetchedRows with chunk - for i := 0; i < rowCount; i++ { - row := chk.GetRow(i) - rows = append(rows, row) + + err = rowContainer.Add(chk) + if err != nil { + return false, err } - chk = chunk.Renew(chk, vars.MaxChunkSize) } - rs.StoreFetchedRows(rows) - stmt.StoreResultSet(rs) - if err = cc.writeColumnInfo(rs.Columns()); err != nil { - return false, err - } - if cl, ok := rs.(fetchNotifier); ok { + reader := chunk.NewRowContainerReader(rowContainer) + crs.StoreRowContainerReader(reader) + stmt.StoreResultSet(crs) + stmt.StoreRowContainer(rowContainer) + if cl, ok := crs.(fetchNotifier); ok { cl.OnFetchReturned() } - stmt.SetCursorActive(true) + defer func() { + if err != nil { + reader.Close() + + // the resultSet and rowContainer have been closed in former "defer" statement. + stmt.StoreResultSet(nil) + stmt.StoreRowContainer(nil) + stmt.SetCursorActive(false) + } + }() + + if err = cc.writeColumnInfo(crs.Columns()); err != nil { + return false, err + } // explicitly flush columnInfo to client. err = cc.writeEOF(ctx, cc.ctx.Status()) @@ -368,6 +438,12 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err cc.ctx.GetSessionVars().ClearAlloc(nil, false) cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true) defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false) + // Reset the warn count. TODO: consider whether it's better to reset the whole session context/statement context. 
+ if cc.ctx.GetSessionVars().StmtCtx != nil { + cc.ctx.GetSessionVars().StmtCtx.SetWarnings(nil) + } + cc.ctx.GetSessionVars().SysErrorCount = 0 + cc.ctx.GetSessionVars().SysWarningCount = 0 stmtID, fetchSize, err := parseStmtFetchCmd(data) if err != nil { @@ -379,6 +455,21 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err return errors.Annotate(mysql.NewErr(mysql.ErrUnknownStmtHandler, strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch"), cc.preparedStmt2String(stmtID)) } + if !stmt.GetCursorActive() { + return errors.Annotate(mysql.NewErr(mysql.ErrSpCursorNotOpen), cc.preparedStmt2String(stmtID)) + } + // from now on, we have made sure: the statement has an active cursor + // then if facing any error, this cursor should be reset + defer func() { + if err != nil { + errReset := stmt.Reset() + if errReset != nil { + logutil.Logger(ctx).Error("Fail to reset statement in error handler. May cause resource leak.", + zap.NamedError("original-error", err), zap.NamedError("reset-error", errReset)) + } + } + }() + if topsqlstate.TopSQLEnabled() { prepareObj, _ := cc.preparedStmtID2CachePreparedStmt(stmtID) if prepareObj != nil && prepareObj.SQLDigest != nil { @@ -391,23 +482,22 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err } cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute, 0) rs := stmt.GetResultSet() - if rs == nil { - return errors.Annotate(mysql.NewErr(mysql.ErrUnknownStmtHandler, - strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch_rs"), cc.preparedStmt2String(stmtID)) - } - sendingEOF := false - // if the `fetchedRows` are empty before writing result, we could say the `FETCH` command will send EOF - if len(rs.GetFetchedRows()) == 0 { - sendingEOF = true - } _, err = cc.writeResultset(ctx, rs, true, cc.ctx.Status(), int(fetchSize)) + // if the iterator has reached the end after writing the result, the `FETCH` command has sent the last row + if rs.GetRowContainerReader().Current() ==
rs.GetRowContainerReader().End() { + // also reset the statement when the cursor reaches the end + // don't overwrite the `err` in outer scope, to avoid redundant `Reset()` in `defer` statement (though, it's not + // a big problem, as the `Reset()` function call is idempotent.) + err := stmt.Reset() + if err != nil { + logutil.Logger(ctx).Error("Fail to reset statement when FETCH command reaches the end. May cause resource leak", + zap.NamedError("error", err)) + } + } if err != nil { return errors.Annotate(err, cc.preparedStmt2String(stmtID)) } - if sendingEOF { - stmt.SetCursorActive(false) - } return nil } @@ -768,6 +858,10 @@ func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) { } func (cc *clientConn) handleStmtReset(ctx context.Context, data []byte) (err error) { + // A reset command should reset the statement to the state when it was right after prepare + // Then the following state should be cleared: + // 1.The opened cursor, including the rowContainer (and its cursor/memTracker). + // 2.The argument sent through `SEND_LONG_DATA`. if len(data) < 4 { return mysql.ErrMalformPacket } @@ -778,8 +872,16 @@ func (cc *clientConn) handleStmtReset(ctx context.Context, data []byte) (err err return mysql.NewErr(mysql.ErrUnknownStmtHandler, strconv.Itoa(stmtID), "stmt_reset") } - stmt.Reset() - stmt.StoreResultSet(nil) + err = stmt.Reset() + if err != nil { + // Both server and client cannot handle the error case well, so just left an error and return OK. + // It's fine to receive further `EXECUTE` command even the `Reset` function call failed. + logutil.Logger(ctx).Error("Fail to close statement in error handler of RESET command. 
May cause resource leak", + zap.NamedError("original-error", err), zap.NamedError("close-error", err)) + + return cc.writeOK(ctx) + } + return cc.writeOK(ctx) } diff --git a/server/conn_stmt_test.go b/server/conn_stmt_test.go index 4383aacbe3ab1..6d91c5ae4071a 100644 --- a/server/conn_stmt_test.go +++ b/server/conn_stmt_test.go @@ -17,10 +17,19 @@ package server import ( "bytes" "context" + "crypto/rand" "encoding/binary" "fmt" + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + "syscall" "testing" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -298,12 +307,7 @@ func TestCursorExistsFlag(t *testing.T) { require.NoError(t, c.Dispatch(ctx, append([]byte{mysql.ComQuery}, "select * from t"...))) require.False(t, mysql.HasCursorExistsFlag(getLastStatus())) - // fetch last 3 - require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5))) - require.True(t, mysql.HasCursorExistsFlag(getLastStatus())) - - // final fetch with no row retured - // (tidb doesn't unset cursor-exists flag in the previous response like mysql, one more fetch is needed) + // fetch last 3, the `CursorExist` flag should have been unset and the `LastRowSend` flag should have been set require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5))) require.False(t, mysql.HasCursorExistsFlag(getLastStatus())) require.True(t, getLastStatus()&mysql.ServerStatusLastRowSend > 0) @@ -311,6 +315,24 @@ func TestCursorExistsFlag(t *testing.T) { // COM_QUERY after fetch require.NoError(t, c.Dispatch(ctx, append([]byte{mysql.ComQuery}, "select * from t"...))) require.False(t, mysql.HasCursorExistsFlag(getLastStatus())) + + // try another query without response + stmt, _, _, err = c.Context().Prepare("select * from t where a = 100") + require.NoError(t, err) + + 
require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), + mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + ))) + require.True(t, mysql.HasCursorExistsFlag(getLastStatus())) + + // fetch 5 rows, it will return no data with the `CursorExist` unset and `LastRowSend` set. + require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5))) + require.False(t, mysql.HasCursorExistsFlag(getLastStatus())) + require.True(t, getLastStatus()&mysql.ServerStatusLastRowSend > 0) + + // the following FETCH should fail, as the cursor has been automatically closed + require.Error(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5))) } func TestCursorWithParams(t *testing.T) { @@ -341,10 +363,11 @@ func TestCursorWithParams(t *testing.T) { 0x0, 0x1, 0x3, 0x0, 0x3, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, ))) - rows := c.Context().stmts[stmt1.ID()].GetResultSet().GetFetchedRows() - require.Len(t, rows, 1) - require.Equal(t, int64(1), rows[0].GetInt64(0)) - require.Equal(t, int64(2), rows[0].GetInt64(1)) + rows := c.Context().stmts[stmt1.ID()].GetResultSet().GetRowContainerReader() + require.Equal(t, int64(1), rows.Current().GetInt64(0)) + require.Equal(t, int64(2), rows.Current().GetInt64(1)) + rows.Next() + require.Equal(t, rows.End(), rows.Current()) // `execute stmt2 using 1` with cursor require.NoError(t, c.Dispatch(ctx, append( @@ -353,12 +376,13 @@ func TestCursorWithParams(t *testing.T) { 0x0, 0x1, 0x3, 0x0, 0x1, 0x0, 0x0, 0x0, ))) - rows = c.Context().stmts[stmt2.ID()].GetResultSet().GetFetchedRows() - require.Len(t, rows, 2) - require.Equal(t, int64(1), rows[0].GetInt64(0)) - require.Equal(t, int64(1), rows[0].GetInt64(1)) - require.Equal(t, int64(1), rows[1].GetInt64(0)) - require.Equal(t, int64(2), rows[1].GetInt64(1)) + rows = c.Context().stmts[stmt2.ID()].GetResultSet().GetRowContainerReader() + require.Equal(t, int64(1), 
rows.Current().GetInt64(0)) + require.Equal(t, int64(1), rows.Current().GetInt64(1)) + require.Equal(t, int64(1), rows.Next().GetInt64(0)) + require.Equal(t, int64(2), rows.Current().GetInt64(1)) + rows.Next() + require.Equal(t, rows.End(), rows.Current()) // fetch stmt2 with fetch size 256 require.NoError(t, c.Dispatch(ctx, append( @@ -407,7 +431,9 @@ func TestCursorDetachMemTracker(t *testing.T) { // testkit also uses `PREPARE` related calls to run statement with arguments. // format the SQL to avoid the interference from testkit. tk.MustExec(fmt.Sprintf("set tidb_mem_quota_query=%d", maxConsumed/2)) - require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 0) + // there is one memTracker for the resultSet spill-disk + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 1) + // This query should exceed the memory limitation during `openExecutor` require.Error(t, c.Dispatch(ctx, append( appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), @@ -423,7 +449,8 @@ func TestCursorDetachMemTracker(t *testing.T) { appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, ))) - require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 0) + // there is one memTracker for the resultSet spill-disk + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 1) } func TestMemoryTrackForPrepareBinaryProtocol(t *testing.T) { @@ -445,3 +472,114 @@ func TestMemoryTrackForPrepareBinaryProtocol(t *testing.T) { } require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 0) } + +func TestCursorFetchShouldSpill(t *testing.T) { + restore := config.RestoreFunc() + defer restore() + config.UpdateGlobal(func(conf *config.Config) { + conf.TempStoragePath = t.TempDir() + }) + + store, dom := testkit.CreateMockStoreAndDomain(t) + srv := CreateMockServer(t, store) + srv.SetDomain(dom) + defer srv.Close() + + require.NoError(t, 
failpoint.Enable("github.com/pingcap/tidb/server/testCursorFetchSpill", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/testCursorFetchSpill")) + }() + + appendUint32 := binary.LittleEndian.AppendUint32 + ctx := context.Background() + c := CreateMockConn(t, srv).(*mockConn) + + tk := testkit.NewTestKitWithSession(t, store, c.Context().Session) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id_1 int, id_2 int)") + tk.MustExec("insert into t values (1, 1), (1, 2)") + tk.MustExec("set global tidb_enable_tmp_storage_on_oom = ON") + tk.MustExec("set global tidb_mem_oom_action = 'CANCEL'") + defer tk.MustExec("set global tidb_mem_oom_action= DEFAULT") + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 1) + + // execute a normal statement, it'll spill to disk + stmt, _, _, err := c.Context().Prepare("select * from t") + require.NoError(t, err) + + tk.MustExec(fmt.Sprintf("set tidb_mem_quota_query=%d", 1)) + + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), + mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + ))) +} + +func TestCursorFetchErrorInFetch(t *testing.T) { + tmpStoragePath := t.TempDir() + restore := config.RestoreFunc() + defer restore() + config.UpdateGlobal(func(conf *config.Config) { + conf.TempStoragePath = tmpStoragePath + }) + + store, dom := testkit.CreateMockStoreAndDomain(t) + srv := CreateMockServer(t, store) + srv.SetDomain(dom) + defer srv.Close() + + appendUint32 := binary.LittleEndian.AppendUint32 + ctx := context.Background() + c := CreateMockConn(t, srv).(*mockConn) + + tk := testkit.NewTestKitWithSession(t, store, c.Context().Session) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id int, payload BLOB)") + payload := make([]byte, 512) + for i := 0; i < 2048; i++ { + rand.Read(payload) + tk.MustExec("insert 
into t values (?, ?)", i, payload) + } + + tk.MustExec("set global tidb_enable_tmp_storage_on_oom = ON") + tk.MustExec("set global tidb_mem_oom_action = 'CANCEL'") + defer tk.MustExec("set global tidb_mem_oom_action= DEFAULT") + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 1) + + // execute a normal statement, it'll spill to disk + stmt, _, _, err := c.Context().Prepare("select * from t") + require.NoError(t, err) + + tk.MustExec(fmt.Sprintf("set tidb_mem_quota_query=%d", 1)) + + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), + mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + ))) + + // close these disk files to produce error + filepath.Walk("/proc/self/fd", func(path string, info fs.FileInfo, err error) error { + if err != nil { + return nil + } + target, err := os.Readlink(path) + if err != nil { + return nil + } + if strings.HasPrefix(target, tmpStoragePath) { + fd, err := strconv.Atoi(filepath.Base(path)) + require.NoError(t, err) + require.NoError(t, syscall.Close(fd)) + } + return nil + }) + + // it'll get "bad file descriptor", as it has been closed in the test. 
+ require.Error(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 1024))) + // after getting a failed FETCH, the cursor should have been reset + require.False(t, stmt.GetCursorActive()) + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 0) + require.Len(t, tk.Session().GetSessionVars().DiskTracker.GetChildrenForTest(), 0) +} diff --git a/server/driver.go b/server/driver.go index 7fdebdd2739ff..486490b3c2ec3 100644 --- a/server/driver.go +++ b/server/driver.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/extension" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" ) @@ -53,13 +54,13 @@ type PreparedStatement interface { GetParamsType() []byte // StoreResultSet stores ResultSet for subsequent stmt fetching - StoreResultSet(rs ResultSet) + StoreResultSet(rs cursorResultSet) // GetResultSet gets ResultSet associated this statement - GetResultSet() ResultSet + GetResultSet() cursorResultSet - // Reset removes all bound parameters. - Reset() + // Reset removes all bound parameters and opened resultSet/rowContainer. + Reset() error // Close closes the statement. Close() error @@ -69,6 +70,13 @@ type PreparedStatement interface { // SetCursorActive sets whether the statement has active cursor SetCursorActive(active bool) + + // StoreRowContainer stores a row container into the prepared statement. The `rowContainer` is kept so it can be closed at + // the appropriate time. It's not used for reading, because an iterator over it has been stored in the result set. + StoreRowContainer(container *chunk.RowContainer) + + // GetRowContainer returns the row container of the statement + GetRowContainer() *chunk.RowContainer } // ResultSet is the result set of an query.
@@ -76,11 +84,18 @@ type ResultSet interface { Columns() []*ColumnInfo NewChunk(chunk.Allocator) *chunk.Chunk Next(context.Context, *chunk.Chunk) error - StoreFetchedRows(rows []chunk.Row) - GetFetchedRows() []chunk.Row Close() error // IsClosed checks whether the result set is closed. IsClosed() bool + FieldTypes() []*types.FieldType +} + +// cursorResultSet extends the `ResultSet` to provide the ability to store an iterator +type cursorResultSet interface { + ResultSet + + StoreRowContainerReader(reader chunk.RowContainerReader) + GetRowContainerReader() chunk.RowContainerReader } // fetchNotifier represents notifier will be called in COM_FETCH. @@ -89,3 +104,9 @@ type fetchNotifier interface { // it will be used in server-side cursor. OnFetchReturned() } + +func wrapWithCursor(rs ResultSet) cursorResultSet { + return &tidbCursorResultSet{ + rs, nil, + } +} diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 5b3980f0b8d37..7f4a8e3460531 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -66,11 +66,14 @@ type TiDBStatement struct { boundParams [][]byte paramsType []byte ctx *TiDBContext - // this result set should have been closed before stored here. Only the `fetchedRows` are used here. This field is + // this result set should have been closed before being stored here. Only its row container reader is used here. This field is // not moved out to reuse the logic inside functions `writeResultSet...` // TODO: move the `fetchedRows` into the statement, and remove the `ResultSet` from statement. - rs ResultSet - sql string + rs cursorResultSet + // the `rowContainer` should contain all pre-fetched results of the statement in `EXECUTE` command.
+ // it's stored here to be closed by the RESET and CLOSE commands + rowContainer *chunk.RowContainer + sql string hasActiveCursor bool } @@ -131,32 +134,61 @@ func (ts *TiDBStatement) GetParamsType() []byte { } // StoreResultSet stores ResultSet for stmt fetching -func (ts *TiDBStatement) StoreResultSet(rs ResultSet) { - // refer to https://dev.mysql.com/doc/refman/5.7/en/cursor-restrictions.html - // You can have open only a single cursor per prepared statement. - // closing previous ResultSet before associating a new ResultSet with this statement - // if it exists - if ts.rs != nil { - terror.Call(ts.rs.Close) - } +func (ts *TiDBStatement) StoreResultSet(rs cursorResultSet) { + // the original result set should have been closed, and it's only used to store the iterator through the rowContainer + // so it's fine to just overwrite it. ts.rs = rs } // GetResultSet gets ResultSet associated this statement -func (ts *TiDBStatement) GetResultSet() ResultSet { +func (ts *TiDBStatement) GetResultSet() cursorResultSet { return ts.rs } // Reset implements PreparedStatement Reset method. -func (ts *TiDBStatement) Reset() { +func (ts *TiDBStatement) Reset() error { for i := range ts.boundParams { ts.boundParams[i] = nil } ts.hasActiveCursor = false + + if ts.rs != nil && ts.rs.GetRowContainerReader() != nil { + ts.rs.GetRowContainerReader().Close() + } + ts.rs = nil + + if ts.rowContainer != nil { + ts.rowContainer.GetMemTracker().Detach() + ts.rowContainer.GetDiskTracker().Detach() + + rc := ts.rowContainer + ts.rowContainer = nil + + err := rc.Close() + if err != nil { + return err + } + } + + return nil } // Close implements PreparedStatement Close method.
func (ts *TiDBStatement) Close() error { + if ts.rs != nil && ts.rs.GetRowContainerReader() != nil { + ts.rs.GetRowContainerReader().Close() + } + + if ts.rowContainer != nil { + ts.rowContainer.GetMemTracker().Detach() + ts.rowContainer.GetDiskTracker().Detach() + + err := ts.rowContainer.Close() + if err != nil { + return err + } + } + // TODO close at tidb level if ts.ctx.GetSessionVars().TxnCtx != nil && ts.ctx.GetSessionVars().TxnCtx.CouldRetry { err := ts.ctx.DropPreparedStmt(ts.id) @@ -196,6 +228,16 @@ func (ts *TiDBStatement) SetCursorActive(fetchEnd bool) { ts.hasActiveCursor = fetchEnd } +// StoreRowContainer stores a row container into the prepared statement +func (ts *TiDBStatement) StoreRowContainer(c *chunk.RowContainer) { + ts.rowContainer = c +} + +// GetRowContainer returns the row container of the statement +func (ts *TiDBStatement) GetRowContainer() *chunk.RowContainer { + return ts.rowContainer +} + // OpenCtx implements IDriver. func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error) { se, err := session.CreateSession(qd.store) @@ -415,7 +457,6 @@ func (tc *TiDBContext) DecodeSessionStates(ctx context.Context, sctx sessionctx. 
type tidbResultSet struct { recordSet sqlexec.RecordSet columns []*ColumnInfo - rows []chunk.Row closed int32 preparedStmt *core.PlanCacheStmt } @@ -428,17 +469,6 @@ func (trs *tidbResultSet) Next(ctx context.Context, req *chunk.Chunk) error { return trs.recordSet.Next(ctx, req) } -func (trs *tidbResultSet) StoreFetchedRows(rows []chunk.Row) { - trs.rows = rows -} - -func (trs *tidbResultSet) GetFetchedRows() []chunk.Row { - if trs.rows == nil { - trs.rows = make([]chunk.Row, 0, 1024) - } - return trs.rows -} - func (trs *tidbResultSet) Close() error { if !atomic.CompareAndSwapInt32(&trs.closed, 0, 1) { return nil @@ -485,6 +515,30 @@ func (trs *tidbResultSet) Columns() []*ColumnInfo { return trs.columns } +func (trs *tidbResultSet) FieldTypes() []*types.FieldType { + fts := make([]*types.FieldType, 0, len(trs.recordSet.Fields())) + for _, f := range trs.recordSet.Fields() { + fts = append(fts, &f.Column.FieldType) + } + return fts +} + +var _ cursorResultSet = &tidbCursorResultSet{} + +type tidbCursorResultSet struct { + ResultSet + + reader chunk.RowContainerReader +} + +func (tcrs *tidbCursorResultSet) StoreRowContainerReader(reader chunk.RowContainerReader) { + tcrs.reader = reader +} + +func (tcrs *tidbCursorResultSet) GetRowContainerReader() chunk.RowContainerReader { + return tcrs.reader +} + func convertColumnInfo(fld *ast.ResultField) (ci *ColumnInfo) { ci = &ColumnInfo{ Name: fld.ColumnAsName.O, diff --git a/sessionctx/sessionstates/session_states_test.go b/sessionctx/sessionstates/session_states_test.go index d0b2b5e9ec83f..9eafc7b577812 100644 --- a/sessionctx/sessionstates/session_states_test.go +++ b/sessionctx/sessionstates/session_states_test.go @@ -926,9 +926,9 @@ func TestPreparedStatements(t *testing.T) { require.NoError(t, conn.Dispatch(context.Background(), cmd)) cmd = getFetchBytes(1, 10) require.NoError(t, conn.Dispatch(context.Background(), cmd)) - // This COM_STMT_FETCH returns EOF. 
+ // This COM_STMT_FETCH returns error, because the cursor has been automatically closed. cmd = getFetchBytes(1, 10) - require.NoError(t, conn.Dispatch(context.Background(), cmd)) + require.Error(t, conn.Dispatch(context.Background(), cmd)) return uint32(1) }, checkFunc: func(tk *testkit.TestKit, conn server.MockConn, param any) { @@ -1289,7 +1289,7 @@ func TestShowStateFail(t *testing.T) { }, }, { - // fetched all the data but the EOF packet is not sent + // fetched all the data and `ServerStatusLastRowSend` is set, then the cursor should have been closed setFunc: func(tk *testkit.TestKit, conn server.MockConn) { tk.MustExec("create table test.t1(id int)") tk.MustExec("insert test.t1 value(1), (2), (3)") @@ -1299,26 +1299,8 @@ func TestShowStateFail(t *testing.T) { require.NoError(t, conn.Dispatch(context.Background(), cmd)) cmd = getFetchBytes(1, 10) require.NoError(t, conn.Dispatch(context.Background(), cmd)) - }, - showErr: errno.ErrCannotMigrateSession, - cleanFunc: func(tk *testkit.TestKit) { - tk.MustExec("drop table test.t1") - }, - }, - { - // EOF is sent - setFunc: func(tk *testkit.TestKit, conn server.MockConn) { - tk.MustExec("create table test.t1(id int)") - tk.MustExec("insert test.t1 value(1), (2), (3)") - cmd := append([]byte{mysql.ComStmtPrepare}, []byte("select * from test.t1")...) - require.NoError(t, conn.Dispatch(context.Background(), cmd)) - cmd = getExecuteBytes(1, true, false) - require.NoError(t, conn.Dispatch(context.Background(), cmd)) - cmd = getFetchBytes(1, 10) - require.NoError(t, conn.Dispatch(context.Background(), cmd)) - // This COM_STMT_FETCH returns EOF. 
- cmd = getFetchBytes(1, 10) - require.NoError(t, conn.Dispatch(context.Background(), cmd)) + // following FETCH command should fail because the cursor has been closed + require.Error(t, conn.Dispatch(context.Background(), getFetchBytes(1, 10))) }, cleanFunc: func(tk *testkit.TestKit) { tk.MustExec("drop table test.t1") diff --git a/util/memory/action.go b/util/memory/action.go index c1a6e9b581c99..9b141a96dc6e3 100644 --- a/util/memory/action.go +++ b/util/memory/action.go @@ -44,6 +44,40 @@ type ActionOnExceed interface { IsFinished() bool } +var _ ActionOnExceed = &actionWithPriority{} + +type actionWithPriority struct { + ActionOnExceed + priority int64 +} + +// NewActionWithPriority wraps the action with a new priority +func NewActionWithPriority(action ActionOnExceed, priority int64) *actionWithPriority { + return &actionWithPriority{ + action, + priority, + } +} + +func (a *actionWithPriority) GetPriority() int64 { + return a.priority +} + +// ActionInvoker indicates the invoker of the Action. +type ActionInvoker byte + +const ( + // SingleQuery indicates the Action is invoked by a tidb_mem_quota_query. + SingleQuery ActionInvoker = iota + // Instance indicates the Action is invoked by a tidb_server_memory_limit. + Instance +) + +// ActionCareInvoker is the interface for the Actions which need to be aware of the invoker. +type ActionCareInvoker interface { + SetInvoker(invoker ActionInvoker) +} + // BaseOOMAction manages the fallback action for all Action. type BaseOOMAction struct { fallbackAction ActionOnExceed @@ -79,6 +113,9 @@ const ( DefPanicPriority = iota DefLogPriority DefSpillPriority + // DefCursorFetchSpillPriority is higher than normal disk spill, because it can release much more memory in the future. + // And the performance impact of it is less than other disk-spill actions, because it's write-only in execution stage. 
+ DefCursorFetchSpillPriority DefRateLimitPriority ) diff --git a/util/memory/tracker.go b/util/memory/tracker.go index a40c73dab6e51..e39df98d848fd 100644 --- a/util/memory/tracker.go +++ b/util/memory/tracker.go @@ -854,6 +854,8 @@ const ( LabelForSession int = -27 // LabelForMemDB represents the label of the MemDB LabelForMemDB int = -28 + // LabelForCursorFetch represents the label of the execution of cursor fetch + LabelForCursorFetch int = -29 ) // MetricsTypes is used to get label for metrics