From a7014180f73a5b4d47e8c10081ada377ffa35533 Mon Sep 17 00:00:00 2001 From: Yang Keao Date: Thu, 29 Jun 2023 14:29:21 +0800 Subject: [PATCH] send row in FETCH command row by row Signed-off-by: Yang Keao --- server/conn.go | 71 +++++++++++++++------------------- server/conn_stmt_test.go | 82 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 40 deletions(-) diff --git a/server/conn.go b/server/conn.go index 8349e067d9918..ff25d3dfde121 100644 --- a/server/conn.go +++ b/server/conn.go @@ -2410,14 +2410,40 @@ func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool // serverStatus, a flag bit represents server information. // fetchSize, the desired number of rows to be fetched each time when client uses cursor. func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs cursorResultSet, serverStatus uint16, fetchSize int) error { - // construct the rows sent to the client according to fetchSize. - var curRows []chunk.Row + var ( + stmtDetail *execdetails.StmtExecDetails + err error + start time.Time + ) + data := cc.alloc.AllocWithLen(4, 1024) + stmtDetailRaw := ctx.Value(execdetails.StmtExecDetailKey) + if stmtDetailRaw != nil { + //nolint:forcetypeassert + stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) + } + if stmtDetail != nil { + start = time.Now() + } + iter := rs.GetRowContainerReader() // send the rows to the client according to fetchSize. for i := 0; i < fetchSize && iter.Current() != iter.End(); i++ { - curRows = append(curRows, iter.Current()) + row := iter.Current() + + data = data[0:4] + data, err = dumpBinaryRow(data, rs.Columns(), row, cc.rsEncoder) + if err != nil { + return err + } + if err = cc.writePacket(data); err != nil { + return err + } + iter.Next() } + if iter.Error() != nil { + return iter.Error() + } // tell the client COM_STMT_FETCH has finished by setting proper serverStatus, // and close ResultSet. @@ -2426,43 +2452,8 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs cursorRes serverStatus |= mysql.ServerStatusLastRowSend } - var ( - stmtDetail *execdetails.StmtExecDetails - err error - start time.Time - ) - - if len(curRows) != 0 { - data := cc.alloc.AllocWithLen(4, 1024) - stmtDetailRaw := ctx.Value(execdetails.StmtExecDetailKey) - if stmtDetailRaw != nil { - //nolint:forcetypeassert - stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) - } - if stmtDetail != nil { - start = time.Now() - } - - for _, row := range curRows { - data = data[0:4] - data, err = dumpBinaryRow(data, rs.Columns(), row, cc.rsEncoder) - if err != nil { - return err - } - if err = cc.writePacket(data); err != nil { - return err - } - } - - if stmtDetail != nil { - stmtDetail.WriteSQLRespDuration += time.Since(start) - } - if cl, ok := rs.(fetchNotifier); ok { - cl.OnFetchReturned() - } - if stmtDetail != nil { - start = time.Now() - } + if cl, ok := rs.(fetchNotifier); ok { + cl.OnFetchReturned() } err = cc.writeEOF(ctx, serverStatus) diff --git a/server/conn_stmt_test.go b/server/conn_stmt_test.go index 7c6c665a46b0f..a3f8b603ff6dc 100644 --- a/server/conn_stmt_test.go +++ b/server/conn_stmt_test.go @@ -17,11 +17,19 @@ package server import ( "bytes" "context" + "crypto/rand" "encoding/binary" "fmt" + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + "syscall" "testing" "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/testkit" "github.com/stretchr/testify/require" @@ -235,6 +243,12 @@ func TestMemoryTrackForPrepareBinaryProtocol(t *testing.T) { } func TestCursorFetchShouldSpill(t *testing.T) { + restore := config.RestoreFunc() + defer restore() + config.UpdateGlobal(func(conf *config.Config) { + conf.TempStoragePath = t.TempDir() + }) + store, dom := testkit.CreateMockStoreAndDomain(t) srv := CreateMockServer(t, store) srv.SetDomain(dom) @@ -270,3 +284,71 @@ func TestCursorFetchShouldSpill(t *testing.T) { mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, ))) } + +func TestCursorFetchErrorInFetch(t *testing.T) { + tmpStoragePath := t.TempDir() + restore := config.RestoreFunc() + defer restore() + config.UpdateGlobal(func(conf *config.Config) { + conf.TempStoragePath = tmpStoragePath + }) + + store, dom := testkit.CreateMockStoreAndDomain(t) + srv := CreateMockServer(t, store) + srv.SetDomain(dom) + defer srv.Close() + + appendUint32 := binary.LittleEndian.AppendUint32 + ctx := context.Background() + c := CreateMockConn(t, srv).(*mockConn) + + tk := testkit.NewTestKitWithSession(t, store, c.Context().Session) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id int, payload BLOB)") + payload := make([]byte, 512) + for i := 0; i < 2048; i++ { + rand.Read(payload) + tk.MustExec("insert into t values (?, ?)", i, payload) + } + + tk.MustExec("set global tidb_enable_tmp_storage_on_oom = ON") + tk.MustExec("set global tidb_mem_oom_action = 'CANCEL'") + defer tk.MustExec("set global tidb_mem_oom_action= DEFAULT") + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 1) + + // execute a normal statement, it'll spill to disk + stmt, _, _, err := c.Context().Prepare("select * from t") + require.NoError(t, err) + + tk.MustExec(fmt.Sprintf("set tidb_mem_quota_query=%d", 1)) + + require.NoError(t, c.Dispatch(ctx, append( + appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())), + mysql.CursorTypeReadOnly, 0x1, 0x0, 0x0, 0x0, + ))) + + // close these disk files to produce error + filepath.Walk("/proc/self/fd", func(path string, info fs.FileInfo, err error) error { + if err != nil { + return nil + } + target, err := os.Readlink(path) + if err != nil { + return nil + } + if strings.HasPrefix(target, tmpStoragePath) { + fd, err := strconv.Atoi(filepath.Base(path)) + require.NoError(t, err) + require.NoError(t, syscall.Close(fd)) + } + return nil + }) + + // it'll get "bad file descriptor", as it has been closed in the test. + require.Error(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 1024))) + // after getting a failed FETCH, the cursor should have been reseted + require.False(t, stmt.GetCursorActive()) + require.Len(t, tk.Session().GetSessionVars().MemTracker.GetChildrenForTest(), 0) + require.Len(t, tk.Session().GetSessionVars().DiskTracker.GetChildrenForTest(), 0) +}