From d8c502804f8a508c95c57350e2de83b01d244c32 Mon Sep 17 00:00:00 2001 From: qw4990 Date: Mon, 11 Feb 2019 15:24:31 +0800 Subject: [PATCH 1/8] add requiredRows field into RecordBatch --- util/chunk/recordbatch.go | 33 ++++++++++++++++++++- util/chunk/recordbatch_test.go | 54 ++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 util/chunk/recordbatch_test.go diff --git a/util/chunk/recordbatch.go b/util/chunk/recordbatch.go index 7eb79f54f4333..fcb70be622298 100644 --- a/util/chunk/recordbatch.go +++ b/util/chunk/recordbatch.go @@ -13,12 +13,43 @@ package chunk +import ( + "github.com/cznic/mathutil" +) + +const UnspecifiedNumRows = mathutil.MaxInt + // RecordBatch is input parameter of Executor.Next` method. type RecordBatch struct { *Chunk + + // requiredRows indicates how many rows is considered full for parent executor. + // Child executor can return immediately if there are such number of rows, + // instead of fulling the whole chunk. + // This is not compulsory, so the number of returned rows can be larger than it in some cases. + requiredRows int } // NewRecordBatch is used to construct a RecordBatch. func NewRecordBatch(chk *Chunk) *RecordBatch { - return &RecordBatch{chk} + return &RecordBatch{chk, UnspecifiedNumRows} +} + +// SetRequiredRows sets the number of rows the parent executor want. +func (rb *RecordBatch) SetRequiredRows(numRows int) *RecordBatch { + if numRows <= 0 { + numRows = UnspecifiedNumRows + } + rb.requiredRows = numRows + return rb +} + +// RequiredRows returns how many rows the parent executor want. +func (rb *RecordBatch) RequiredRows() int { + return rb.requiredRows +} + +func (rb *RecordBatch) IsFull() bool { + numRows := rb.NumRows() + return numRows >= rb.Capacity() || numRows >= rb.requiredRows } diff --git a/util/chunk/recordbatch_test.go b/util/chunk/recordbatch_test.go new file mode 100644 index 0000000000000..b2821de7d29fb --- /dev/null +++ b/util/chunk/recordbatch_test.go @@ -0,0 +1,54 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package chunk + +import ( + "github.com/pingcap/check" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/types" +) + +func (s *testChunkSuite) TestRecordBatch(c *check.C) { + chk := New([]*types.FieldType{types.NewFieldType(mysql.TypeLong)}, 10, 10) + batch := NewRecordBatch(chk) + c.Assert(batch.RequiredRows(), check.Equals, UnspecifiedNumRows) + for i := 1; i < 10; i++ { + batch.SetRequiredRows(i) + c.Assert(batch.RequiredRows(), check.Equals, i) + } + batch.SetRequiredRows(0) + c.Assert(batch.RequiredRows(), check.Equals, UnspecifiedNumRows) + batch.SetRequiredRows(-1) + c.Assert(batch.RequiredRows(), check.Equals, UnspecifiedNumRows) + batch.SetRequiredRows(1).SetRequiredRows(2).SetRequiredRows(3) + c.Assert(batch.RequiredRows(), check.Equals, 3) + + batch.SetRequiredRows(5) + batch.AppendInt64(0, 1) + batch.AppendInt64(0, 1) + batch.AppendInt64(0, 1) + batch.AppendInt64(0, 1) + c.Assert(batch.NumRows(), check.Equals, 4) + c.Assert(batch.IsFull(), check.IsFalse) + + batch.AppendInt64(0, 1) + c.Assert(batch.NumRows(), check.Equals, 5) + c.Assert(batch.IsFull(), check.IsTrue) + + batch.AppendInt64(0, 1) + batch.AppendInt64(0, 1) + batch.AppendInt64(0, 1) + c.Assert(batch.NumRows(), check.Equals, 8) + c.Assert(batch.IsFull(), check.IsTrue) +} From 747ec825de075223066abbce1e9258b4b62cddf6 Mon Sep 17 00:00:00 2001 From: qw4990 Date: Mon, 11 Feb 2019 16:28:33 +0800 Subject: [PATCH 2/8] make SelectResult support batch size control --- distsql/distsql_test.go | 107 +++++++++++++++++++++++++++++++++++---- distsql/select_result.go | 17 +++---- distsql/stream.go | 16 +++--- 3 files changed, 112 insertions(+), 28 deletions(-) diff --git a/distsql/distsql_test.go b/distsql/distsql_test.go index e196015b243e6..483dbe1674c29 100644 --- a/distsql/distsql_test.go +++ b/distsql/distsql_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/cznic/mathutil" . "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/parser/charset" @@ -34,7 +35,7 @@ import ( "github.com/pingcap/tipb/go-tipb" ) -func (s *testSuite) TestSelectNormal(c *C) { +func (s *testSuite) createSelectNormal(batch, totalRows int, c *C) (*selectResult, []*types.FieldType) { request, err := (&RequestBuilder{}).SetKeyRanges(nil). SetDAGRequest(&tipb.DAGRequest{}). SetDesc(false). @@ -67,13 +68,24 @@ func (s *testSuite) TestSelectNormal(c *C) { c.Assert(result.sqlType, Equals, "general") c.Assert(result.rowLen, Equals, len(colTypes)) + resp, ok := result.resp.(*mockResponse) + c.Assert(ok, IsTrue) + resp.total = totalRows + resp.batch = batch + + return result, colTypes +} + +func (s *testSuite) TestSelectNormal(c *C) { + response, colTypes := s.createSelectNormal(1, 2, c) response.Fetch(context.TODO()) // Test Next. chk := chunk.New(colTypes, 32, 32) + batch := chunk.NewRecordBatch(chk) numAllRows := 0 for { - err = response.Next(context.TODO(), chk) + err := response.Next(context.TODO(), batch) c.Assert(err, IsNil) numAllRows += chk.NumRows() if chk.NumRows() == 0 { @@ -81,11 +93,17 @@ func (s *testSuite) TestSelectNormal(c *C) { } } c.Assert(numAllRows, Equals, 2) - err = response.Close() + err := response.Close() c.Assert(err, IsNil) } -func (s *testSuite) TestSelectStreaming(c *C) { +func (s *testSuite) TestSelectNormalBatchSize(c *C) { + response, colTypes := s.createSelectNormal(100, 1000000, c) + response.Fetch(context.TODO()) + s.testBatchSize(response, colTypes, c) +} + +func (s *testSuite) createSelectStreaming(batch, totalRows int, c *C) (*streamResult, []*types.FieldType) { request, err := (&RequestBuilder{}).SetKeyRanges(nil). SetDAGRequest(&tipb.DAGRequest{}). SetDesc(false). @@ -112,20 +130,30 @@ func (s *testSuite) TestSelectStreaming(c *C) { s.sctx.GetSessionVars().EnableStreaming = true - // Test Next. response, err := Select(context.TODO(), s.sctx, request, colTypes, statistics.NewQueryFeedback(0, nil, 0, false)) c.Assert(err, IsNil) result, ok := response.(*streamResult) c.Assert(ok, IsTrue) c.Assert(result.rowLen, Equals, len(colTypes)) + resp, ok := result.resp.(*mockResponse) + c.Assert(ok, IsTrue) + resp.total = totalRows + resp.batch = batch + + return result, colTypes +} + +func (s *testSuite) TestSelectStreaming(c *C) { + response, colTypes := s.createSelectStreaming(1, 2, c) response.Fetch(context.TODO()) // Test Next. chk := chunk.New(colTypes, 32, 32) + batch := chunk.NewRecordBatch(chk) numAllRows := 0 for { - err = response.Next(context.TODO(), chk) + err := response.Next(context.TODO(), batch) c.Assert(err, IsNil) numAllRows += chk.NumRows() if chk.NumRows() == 0 { @@ -133,10 +161,60 @@ func (s *testSuite) TestSelectStreaming(c *C) { } } c.Assert(numAllRows, Equals, 2) - err = response.Close() + err := response.Close() c.Assert(err, IsNil) } +func (s *testSuite) TestSelectStreamingBatchSize(c *C) { + response, colTypes := s.createSelectStreaming(100, 1000000, c) + response.Fetch(context.TODO()) + s.testBatchSize(response, colTypes, c) +} + +func (s *testSuite) testBatchSize(response SelectResult, colTypes []*types.FieldType, c *C) { + chk := chunk.New(colTypes, 32, 32) + batch := chunk.NewRecordBatch(chk) + + err := response.Next(context.TODO(), batch) + c.Assert(err, IsNil) + c.Assert(batch.NumRows(), Equals, 32) + + batch.SetRequiredRows(1) + err = response.Next(context.TODO(), batch) + c.Assert(err, IsNil) + c.Assert(batch.NumRows(), Equals, 1) + + batch.SetRequiredRows(2) + err = response.Next(context.TODO(), batch) + c.Assert(err, IsNil) + c.Assert(batch.NumRows(), Equals, 2) + + batch.SetRequiredRows(17) + err = response.Next(context.TODO(), batch) + c.Assert(err, IsNil) + c.Assert(batch.NumRows(), Equals, 17) + + batch.SetRequiredRows(170) + err = response.Next(context.TODO(), batch) + c.Assert(err, IsNil) + c.Assert(batch.NumRows(), Equals, 32) + + batch.SetRequiredRows(32) + err = response.Next(context.TODO(), batch) + c.Assert(err, IsNil) + c.Assert(batch.NumRows(), Equals, 32) + + batch.SetRequiredRows(0) + err = response.Next(context.TODO(), batch) + c.Assert(err, IsNil) + c.Assert(batch.NumRows(), Equals, 32) + + batch.SetRequiredRows(-1) + err = response.Next(context.TODO(), batch) + c.Assert(err, IsNil) + c.Assert(batch.NumRows(), Equals, 32) +} + func (s *testSuite) TestAnalyze(c *C) { request, err := (&RequestBuilder{}).SetKeyRanges(nil). SetAnalyzeRequest(&tipb.AnalyzeReq{}). @@ -166,6 +244,8 @@ func (s *testSuite) TestAnalyze(c *C) { // Used only for test. type mockResponse struct { count int + total int + batch int sync.Mutex } @@ -183,17 +263,24 @@ func (resp *mockResponse) Next(ctx context.Context) (kv.ResultSubset, error) { resp.Lock() defer resp.Unlock() - if resp.count == 2 { + if resp.count >= resp.total { return nil, nil } - defer func() { resp.count++ }() + numRows := mathutil.Min(resp.batch, resp.total-resp.count) + resp.count += numRows datum := types.NewIntDatum(1) bytes := make([]byte, 0, 100) bytes, _ = codec.EncodeValue(nil, bytes, datum, datum, datum, datum) + chunks := make([]tipb.Chunk, numRows) + for i := range chunks { + chkData := make([]byte, len(bytes)) + copy(chkData, bytes) + chunks[i] = tipb.Chunk{RowsData: chkData} + } respPB := &tipb.SelectResponse{ - Chunks: []tipb.Chunk{{RowsData: bytes}}, + Chunks: chunks, OutputCounts: []int64{1}, } respBytes, err := respPB.Marshal() diff --git a/distsql/select_result.go b/distsql/select_result.go index 5badfc624ec1c..92dd7289660ac 100644 --- a/distsql/select_result.go +++ b/distsql/select_result.go @@ -41,7 +41,7 @@ type SelectResult interface { // NextRaw gets the next raw result. NextRaw(context.Context) ([]byte, error) // Next reads the data into chunk. - Next(context.Context, *chunk.Chunk) error + Next(ctx context.Context, batch *chunk.RecordBatch) error // Close closes the iterator. Close() error } @@ -114,16 +114,16 @@ func (r *selectResult) NextRaw(ctx context.Context) ([]byte, error) { } // Next reads data to the chunk. -func (r *selectResult) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() - for chk.NumRows() < r.ctx.GetSessionVars().MaxChunkSize { +func (r *selectResult) Next(ctx context.Context, batch *chunk.RecordBatch) error { + batch.Reset() + for !batch.IsFull() { if r.selectResp == nil || r.respChkIdx == len(r.selectResp.Chunks) { err := r.getSelectResp() if err != nil || r.selectResp == nil { return errors.Trace(err) } } - err := r.readRowsData(chk) + err := r.readRowsData(batch) if err != nil { return errors.Trace(err) } @@ -167,11 +167,10 @@ func (r *selectResult) getSelectResp() error { } } -func (r *selectResult) readRowsData(chk *chunk.Chunk) (err error) { +func (r *selectResult) readRowsData(batch *chunk.RecordBatch) (err error) { rowsData := r.selectResp.Chunks[r.respChkIdx].RowsData - maxChunkSize := r.ctx.GetSessionVars().MaxChunkSize - decoder := codec.NewDecoder(chk, r.ctx.GetSessionVars().Location()) - for chk.NumRows() < maxChunkSize && len(rowsData) > 0 { + decoder := codec.NewDecoder(batch.Chunk, r.ctx.GetSessionVars().Location()) + for !batch.IsFull() && len(rowsData) > 0 { for i := 0; i < r.rowLen; i++ { rowsData, err = decoder.DecodeOne(rowsData, i, r.fieldTypes[i]) if err != nil { diff --git a/distsql/stream.go b/distsql/stream.go index dada7053f7a09..cff6be72a3e46 100644 --- a/distsql/stream.go +++ b/distsql/stream.go @@ -43,10 +43,9 @@ type streamResult struct { func (r *streamResult) Fetch(context.Context) {} -func (r *streamResult) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() - maxChunkSize := r.ctx.GetSessionVars().MaxChunkSize - for chk.NumRows() < maxChunkSize { +func (r *streamResult) Next(ctx context.Context, batch *chunk.RecordBatch) error { + batch.Reset() + for !batch.IsFull() { err := r.readDataIfNecessary(ctx) if err != nil { return errors.Trace(err) @@ -55,7 +54,7 @@ func (r *streamResult) Next(ctx context.Context, chk *chunk.Chunk) error { return nil } - err = r.flushToChunk(chk) + err = r.flushToChunk(batch) if err != nil { return errors.Trace(err) } @@ -113,11 +112,10 @@ func (r *streamResult) readDataIfNecessary(ctx context.Context) error { return nil } -func (r *streamResult) flushToChunk(chk *chunk.Chunk) (err error) { +func (r *streamResult) flushToChunk(batch *chunk.RecordBatch) (err error) { remainRowsData := r.curr.RowsData - maxChunkSize := r.ctx.GetSessionVars().MaxChunkSize - decoder := codec.NewDecoder(chk, r.ctx.GetSessionVars().Location()) - for chk.NumRows() < maxChunkSize && len(remainRowsData) > 0 { + decoder := codec.NewDecoder(batch.Chunk, r.ctx.GetSessionVars().Location()) + for !batch.IsFull() && len(remainRowsData) > 0 { for i := 0; i < r.rowLen; i++ { remainRowsData, err = decoder.DecodeOne(remainRowsData, i, r.fieldTypes[i]) if err != nil { From d623c06d0a92637514b267103bbd7b4edcb45526 Mon Sep 17 00:00:00 2001 From: qw4990 Date: Mon, 11 Feb 2019 16:44:19 +0800 Subject: [PATCH 3/8] add a new method NextBatch to avoid big change --- distsql/distsql_test.go | 26 ++++++++++++++------------ distsql/select_result.go | 12 ++++++++++-- distsql/stream.go | 8 +++++++- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/distsql/distsql_test.go b/distsql/distsql_test.go index 483dbe1674c29..22411ad7d9134 100644 --- a/distsql/distsql_test.go +++ b/distsql/distsql_test.go @@ -82,10 +82,9 @@ func (s *testSuite) TestSelectNormal(c *C) { // Test Next. chk := chunk.New(colTypes, 32, 32) - batch := chunk.NewRecordBatch(chk) numAllRows := 0 for { - err := response.Next(context.TODO(), batch) + err := response.Next(context.TODO(), chk) c.Assert(err, IsNil) numAllRows += chk.NumRows() if chk.NumRows() == 0 { @@ -150,10 +149,9 @@ func (s *testSuite) TestSelectStreaming(c *C) { // Test Next. chk := chunk.New(colTypes, 32, 32) - batch := chunk.NewRecordBatch(chk) numAllRows := 0 for { - err := response.Next(context.TODO(), batch) + err := response.Next(context.TODO(), chk) c.Assert(err, IsNil) numAllRows += chk.NumRows() if chk.NumRows() == 0 { @@ -175,42 +173,46 @@ func (s *testSuite) testBatchSize(response SelectResult, colTypes []*types.Field chk := chunk.New(colTypes, 32, 32) batch := chunk.NewRecordBatch(chk) - err := response.Next(context.TODO(), batch) + err := response.Next(context.TODO(), chk) + c.Assert(err, IsNil) + c.Assert(chk.NumRows(), Equals, 32) + + err = response.NextBatch(context.TODO(), batch) c.Assert(err, IsNil) c.Assert(batch.NumRows(), Equals, 32) batch.SetRequiredRows(1) - err = response.Next(context.TODO(), batch) + err = response.NextBatch(context.TODO(), batch) c.Assert(err, IsNil) c.Assert(batch.NumRows(), Equals, 1) batch.SetRequiredRows(2) - err = response.Next(context.TODO(), batch) + err = response.NextBatch(context.TODO(), batch) c.Assert(err, IsNil) c.Assert(batch.NumRows(), Equals, 2) batch.SetRequiredRows(17) - err = response.Next(context.TODO(), batch) + err = response.NextBatch(context.TODO(), batch) c.Assert(err, IsNil) c.Assert(batch.NumRows(), Equals, 17) batch.SetRequiredRows(170) - err = response.Next(context.TODO(), batch) + err = response.NextBatch(context.TODO(), batch) c.Assert(err, IsNil) c.Assert(batch.NumRows(), Equals, 32) batch.SetRequiredRows(32) - err = response.Next(context.TODO(), batch) + err = response.NextBatch(context.TODO(), batch) c.Assert(err, IsNil) c.Assert(batch.NumRows(), Equals, 32) batch.SetRequiredRows(0) - err = response.Next(context.TODO(), batch) + err = response.NextBatch(context.TODO(), batch) c.Assert(err, IsNil) c.Assert(batch.NumRows(), Equals, 32) batch.SetRequiredRows(-1) - err = response.Next(context.TODO(), batch) + err = response.NextBatch(context.TODO(), batch) c.Assert(err, IsNil) c.Assert(batch.NumRows(), Equals, 32) } diff --git a/distsql/select_result.go b/distsql/select_result.go index 92dd7289660ac..605eccb85073c 100644 --- a/distsql/select_result.go +++ b/distsql/select_result.go @@ -41,7 +41,10 @@ type SelectResult interface { // NextRaw gets the next raw result. NextRaw(context.Context) ([]byte, error) // Next reads the data into chunk. - Next(ctx context.Context, batch *chunk.RecordBatch) error + // TODO: replace all calls of Next to NextBatch and remove this Next method + Next(ctx context.Context, chk *chunk.Chunk) error + // NextBatch reads the data into batch. + NextBatch(ctx context.Context, batch *chunk.RecordBatch) error // Close closes the iterator. Close() error } @@ -114,7 +117,12 @@ func (r *selectResult) NextRaw(ctx context.Context) ([]byte, error) { } // Next reads data to the chunk. -func (r *selectResult) Next(ctx context.Context, batch *chunk.RecordBatch) error { +func (r *selectResult) Next(ctx context.Context, chk *chunk.Chunk) error { + return r.NextBatch(ctx, chunk.NewRecordBatch(chk)) +} + +// NextBatch reads the data into batch. +func (r *selectResult) NextBatch(ctx context.Context, batch *chunk.RecordBatch) error { batch.Reset() for !batch.IsFull() { if r.selectResp == nil || r.respChkIdx == len(r.selectResp.Chunks) { diff --git a/distsql/stream.go b/distsql/stream.go index cff6be72a3e46..c10e3e7ebee5f 100644 --- a/distsql/stream.go +++ b/distsql/stream.go @@ -43,7 +43,13 @@ type streamResult struct { func (r *streamResult) Fetch(context.Context) {} -func (r *streamResult) Next(ctx context.Context, batch *chunk.RecordBatch) error { +// Next reads data to the chunk. +func (r *streamResult) Next(ctx context.Context, chk *chunk.Chunk) error { + return r.NextBatch(ctx, chunk.NewRecordBatch(chk)) +} + +// NextBatch reads the data into batch. +func (r *streamResult) NextBatch(ctx context.Context, batch *chunk.RecordBatch) error { batch.Reset() for !batch.IsFull() { err := r.readDataIfNecessary(ctx) From 8fadaf765efe6591656b89b5af1903c89d6959e3 Mon Sep 17 00:00:00 2001 From: qw4990 Date: Mon, 11 Feb 2019 19:06:02 +0800 Subject: [PATCH 4/8] fix UT --- distsql/request_builder_test.go | 10 ++++++++-- util/chunk/recordbatch.go | 8 ++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/distsql/request_builder_test.go b/distsql/request_builder_test.go index 10d319b9c3e65..640127f163594 100644 --- a/distsql/request_builder_test.go +++ b/distsql/request_builder_test.go @@ -53,7 +53,10 @@ func (s *testSuite) SetUpSuite(c *C) { ctx := mock.NewContext() ctx.Store = &mock.Store{ Client: &mock.Client{ - MockResponse: &mockResponse{}, + MockResponse: &mockResponse{ + batch: 1, + total: 2, + }, }, } s.sctx = ctx @@ -67,7 +70,10 @@ func (s *testSuite) SetUpTest(c *C) { ctx := s.sctx.(*mock.Context) store := ctx.Store.(*mock.Store) store.Client = &mock.Client{ - MockResponse: &mockResponse{}, + MockResponse: &mockResponse{ + batch: 1, + total: 2, + }, } } diff --git a/util/chunk/recordbatch.go b/util/chunk/recordbatch.go index fcb70be622298..5dc7ae2dfd3b3 100644 --- a/util/chunk/recordbatch.go +++ b/util/chunk/recordbatch.go @@ -13,11 +13,7 @@ package chunk -import ( - "github.com/cznic/mathutil" -) - -const UnspecifiedNumRows = mathutil.MaxInt +const UnspecifiedNumRows = -1 // RecordBatch is input parameter of Executor.Next` method. type RecordBatch struct { @@ -51,5 +47,5 @@ func (rb *RecordBatch) RequiredRows() int { func (rb *RecordBatch) IsFull() bool { numRows := rb.NumRows() - return numRows >= rb.Capacity() || numRows >= rb.requiredRows + return numRows >= rb.Capacity() || (rb.requiredRows != UnspecifiedNumRows && numRows >= rb.requiredRows) } From 7f491985a83e49598857d35d5f225c39b2ebdfca Mon Sep 17 00:00:00 2001 From: qw4990 Date: Mon, 11 Feb 2019 19:28:34 +0800 Subject: [PATCH 5/8] fix UT --- distsql/select_result.go | 4 ++-- distsql/stream.go | 4 ++-- util/chunk/recordbatch.go | 6 ++++-- util/chunk/recordbatch_test.go | 9 +++++---- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/distsql/select_result.go b/distsql/select_result.go index 605eccb85073c..697c1aa7f4510 100644 --- a/distsql/select_result.go +++ b/distsql/select_result.go @@ -124,7 +124,7 @@ func (r *selectResult) Next(ctx context.Context, chk *chunk.Chunk) error { // NextBatch reads the data into batch. func (r *selectResult) NextBatch(ctx context.Context, batch *chunk.RecordBatch) error { batch.Reset() - for !batch.IsFull() { + for !batch.IsFull(r.ctx.GetSessionVars().MaxChunkSize) { if r.selectResp == nil || r.respChkIdx == len(r.selectResp.Chunks) { err := r.getSelectResp() if err != nil || r.selectResp == nil { @@ -178,7 +178,7 @@ func (r *selectResult) getSelectResp() error { func (r *selectResult) readRowsData(batch *chunk.RecordBatch) (err error) { rowsData := r.selectResp.Chunks[r.respChkIdx].RowsData decoder := codec.NewDecoder(batch.Chunk, r.ctx.GetSessionVars().Location()) - for !batch.IsFull() && len(rowsData) > 0 { + for !batch.IsFull(r.ctx.GetSessionVars().MaxChunkSize) && len(rowsData) > 0 { for i := 0; i < r.rowLen; i++ { rowsData, err = decoder.DecodeOne(rowsData, i, r.fieldTypes[i]) if err != nil { diff --git a/distsql/stream.go b/distsql/stream.go index c10e3e7ebee5f..c85ef0eeab3b2 100644 --- a/distsql/stream.go +++ b/distsql/stream.go @@ -51,7 +51,7 @@ func (r *streamResult) Next(ctx context.Context, chk *chunk.Chunk) error { // NextBatch reads the data into batch. func (r *streamResult) NextBatch(ctx context.Context, batch *chunk.RecordBatch) error { batch.Reset() - for !batch.IsFull() { + for !batch.IsFull(r.ctx.GetSessionVars().MaxChunkSize) { err := r.readDataIfNecessary(ctx) if err != nil { return errors.Trace(err) @@ -121,7 +121,7 @@ func (r *streamResult) readDataIfNecessary(ctx context.Context) error { func (r *streamResult) flushToChunk(batch *chunk.RecordBatch) (err error) { remainRowsData := r.curr.RowsData decoder := codec.NewDecoder(batch.Chunk, r.ctx.GetSessionVars().Location()) - for !batch.IsFull() && len(remainRowsData) > 0 { + for !batch.IsFull(r.ctx.GetSessionVars().MaxChunkSize) && len(remainRowsData) > 0 { for i := 0; i < r.rowLen; i++ { remainRowsData, err = decoder.DecodeOne(remainRowsData, i, r.fieldTypes[i]) if err != nil { diff --git a/util/chunk/recordbatch.go b/util/chunk/recordbatch.go index 5dc7ae2dfd3b3..e09ee11328366 100644 --- a/util/chunk/recordbatch.go +++ b/util/chunk/recordbatch.go @@ -13,6 +13,7 @@ package chunk +// UnspecifiedNumRows represents requiredRows is not specified. const UnspecifiedNumRows = -1 // RecordBatch is input parameter of Executor.Next` method. @@ -45,7 +46,8 @@ func (rb *RecordBatch) RequiredRows() int { return rb.requiredRows } -func (rb *RecordBatch) IsFull() bool { +// IsFull returns if this batch can be considered full. +func (rb *RecordBatch) IsFull(maxChunkSize int) bool { numRows := rb.NumRows() - return numRows >= rb.Capacity() || (rb.requiredRows != UnspecifiedNumRows && numRows >= rb.requiredRows) + return numRows >= maxChunkSize || (rb.requiredRows != UnspecifiedNumRows && numRows >= rb.requiredRows) } diff --git a/util/chunk/recordbatch_test.go b/util/chunk/recordbatch_test.go index b2821de7d29fb..5e83d2f80f6d6 100644 --- a/util/chunk/recordbatch_test.go +++ b/util/chunk/recordbatch_test.go @@ -20,7 +20,8 @@ import ( ) func (s *testChunkSuite) TestRecordBatch(c *check.C) { - chk := New([]*types.FieldType{types.NewFieldType(mysql.TypeLong)}, 10, 10) + maxChunkSize := 10 + chk := New([]*types.FieldType{types.NewFieldType(mysql.TypeLong)}, maxChunkSize, maxChunkSize) batch := NewRecordBatch(chk) c.Assert(batch.RequiredRows(), check.Equals, UnspecifiedNumRows) for i := 1; i < 10; i++ { @@ -40,15 +41,15 @@ func (s *testChunkSuite) TestRecordBatch(c *check.C) { batch.AppendInt64(0, 1) batch.AppendInt64(0, 1) c.Assert(batch.NumRows(), check.Equals, 4) - c.Assert(batch.IsFull(), check.IsFalse) + c.Assert(batch.IsFull(maxChunkSize), check.IsFalse) batch.AppendInt64(0, 1) c.Assert(batch.NumRows(), check.Equals, 5) - c.Assert(batch.IsFull(), check.IsTrue) + c.Assert(batch.IsFull(maxChunkSize), check.IsTrue) batch.AppendInt64(0, 1) batch.AppendInt64(0, 1) batch.AppendInt64(0, 1) c.Assert(batch.NumRows(), check.Equals, 8) - c.Assert(batch.IsFull(), check.IsTrue) + c.Assert(batch.IsFull(maxChunkSize), check.IsTrue) } From 75498c0573c77c37f6e3eb1840fc33aa5c25c6b4 Mon Sep 17 00:00:00 2001 From: qw4990 Date: Tue, 12 Feb 2019 14:58:00 +0800 Subject: [PATCH 6/8] address some comments --- distsql/stream.go | 4 ++-- util/chunk/recordbatch.go | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/distsql/stream.go b/distsql/stream.go index c85ef0eeab3b2..a8d9d1445a43f 100644 --- a/distsql/stream.go +++ b/distsql/stream.go @@ -60,7 +60,7 @@ func (r *streamResult) NextBatch(ctx context.Context, batch *chunk.RecordBatch) return nil } - err = r.flushToChunk(batch) + err = r.flushToBatch(batch) if err != nil { return errors.Trace(err) } @@ -118,7 +118,7 @@ func (r *streamResult) readDataIfNecessary(ctx context.Context) error { return nil } -func (r *streamResult) flushToChunk(batch *chunk.RecordBatch) (err error) { +func (r *streamResult) flushToBatch(batch *chunk.RecordBatch) (err error) { remainRowsData := r.curr.RowsData decoder := codec.NewDecoder(batch.Chunk, r.ctx.GetSessionVars().Location()) for !batch.IsFull(r.ctx.GetSessionVars().MaxChunkSize) && len(remainRowsData) > 0 { diff --git a/util/chunk/recordbatch.go b/util/chunk/recordbatch.go index e09ee11328366..fafedc7a6b137 100644 --- a/util/chunk/recordbatch.go +++ b/util/chunk/recordbatch.go @@ -20,10 +20,9 @@ const UnspecifiedNumRows = -1 type RecordBatch struct { *Chunk - // requiredRows indicates how many rows is considered full for parent executor. - // Child executor can return immediately if there are such number of rows, - // instead of fulling the whole chunk. - // This is not compulsory, so the number of returned rows can be larger than it in some cases. + // requiredRows indicates how many rows is required by the parent executor. + // Child executor should stop populating rows immediately if there are at + // least required rows in the Chunk. requiredRows int } From 4ccfa417064b730141b16665a211a468742c58a7 Mon Sep 17 00:00:00 2001 From: qw4990 Date: Tue, 12 Feb 2019 15:00:59 +0800 Subject: [PATCH 7/8] modify UnspecifiedNumRows --- util/chunk/recordbatch.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/chunk/recordbatch.go b/util/chunk/recordbatch.go index fafedc7a6b137..c539e4dd83322 100644 --- a/util/chunk/recordbatch.go +++ b/util/chunk/recordbatch.go @@ -14,7 +14,7 @@ package chunk // UnspecifiedNumRows represents requiredRows is not specified. -const UnspecifiedNumRows = -1 +const UnspecifiedNumRows = 0 // RecordBatch is input parameter of Executor.Next` method. type RecordBatch struct { From 973b69807e6c7b3f498bf677e0a51783e3b212fc Mon Sep 17 00:00:00 2001 From: qw4990 Date: Tue, 12 Feb 2019 16:19:18 +0800 Subject: [PATCH 8/8] remove maxChunkSize from IsFull --- distsql/select_result.go | 4 ++-- distsql/stream.go | 4 ++-- util/chunk/recordbatch.go | 12 +++++++++--- util/chunk/recordbatch_test.go | 9 ++++++--- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/distsql/select_result.go b/distsql/select_result.go index 697c1aa7f4510..5080cba753b1a 100644 --- a/distsql/select_result.go +++ b/distsql/select_result.go @@ -124,7 +124,7 @@ func (r *selectResult) Next(ctx context.Context, chk *chunk.Chunk) error { // NextBatch reads the data into batch. func (r *selectResult) NextBatch(ctx context.Context, batch *chunk.RecordBatch) error { batch.Reset() - for !batch.IsFull(r.ctx.GetSessionVars().MaxChunkSize) { + for !batch.IsFull() && batch.NumRows() < r.ctx.GetSessionVars().MaxChunkSize { if r.selectResp == nil || r.respChkIdx == len(r.selectResp.Chunks) { err := r.getSelectResp() if err != nil || r.selectResp == nil { @@ -178,7 +178,7 @@ func (r *selectResult) getSelectResp() error { func (r *selectResult) readRowsData(batch *chunk.RecordBatch) (err error) { rowsData := r.selectResp.Chunks[r.respChkIdx].RowsData decoder := codec.NewDecoder(batch.Chunk, r.ctx.GetSessionVars().Location()) - for !batch.IsFull(r.ctx.GetSessionVars().MaxChunkSize) && len(rowsData) > 0 { + for !batch.IsFull() && batch.NumRows() < r.ctx.GetSessionVars().MaxChunkSize && len(rowsData) > 0 { for i := 0; i < r.rowLen; i++ { rowsData, err = decoder.DecodeOne(rowsData, i, r.fieldTypes[i]) if err != nil { diff --git a/distsql/stream.go b/distsql/stream.go index a8d9d1445a43f..a8a87a7738229 100644 --- a/distsql/stream.go +++ b/distsql/stream.go @@ -51,7 +51,7 @@ func (r *streamResult) Next(ctx context.Context, chk *chunk.Chunk) error { // NextBatch reads the data into batch. func (r *streamResult) NextBatch(ctx context.Context, batch *chunk.RecordBatch) error { batch.Reset() - for !batch.IsFull(r.ctx.GetSessionVars().MaxChunkSize) { + for !batch.IsFull() && batch.NumRows() < r.ctx.GetSessionVars().MaxChunkSize { err := r.readDataIfNecessary(ctx) if err != nil { return errors.Trace(err) @@ -121,7 +121,7 @@ func (r *streamResult) readDataIfNecessary(ctx context.Context) error { func (r *streamResult) flushToBatch(batch *chunk.RecordBatch) (err error) { remainRowsData := r.curr.RowsData decoder := codec.NewDecoder(batch.Chunk, r.ctx.GetSessionVars().Location()) - for !batch.IsFull(r.ctx.GetSessionVars().MaxChunkSize) && len(remainRowsData) > 0 { + for !batch.IsFull() && batch.NumRows() < r.ctx.GetSessionVars().MaxChunkSize && len(remainRowsData) > 0 { for i := 0; i < r.rowLen; i++ { remainRowsData, err = decoder.DecodeOne(remainRowsData, i, r.fieldTypes[i]) if err != nil { diff --git a/util/chunk/recordbatch.go b/util/chunk/recordbatch.go index c539e4dd83322..4c1a3520af642 100644 --- a/util/chunk/recordbatch.go +++ b/util/chunk/recordbatch.go @@ -46,7 +46,13 @@ func (rb *RecordBatch) RequiredRows() int { } // IsFull returns if this batch can be considered full. -func (rb *RecordBatch) IsFull(maxChunkSize int) bool { - numRows := rb.NumRows() - return numRows >= maxChunkSize || (rb.requiredRows != UnspecifiedNumRows && numRows >= rb.requiredRows) +// IsFull only takes requiredRows into account, the caller of this method should +// also consider maxChunkSize, then it should behave like: +// if !batch.IsFull() && batch.NumRows() < maxChunkSize { ... } +func (rb *RecordBatch) IsFull() bool { + if rb.requiredRows == UnspecifiedNumRows { + return false + } + + return rb.NumRows() >= rb.requiredRows } diff --git a/util/chunk/recordbatch_test.go b/util/chunk/recordbatch_test.go index 5e83d2f80f6d6..b2274ef54190a 100644 --- a/util/chunk/recordbatch_test.go +++ b/util/chunk/recordbatch_test.go @@ -41,15 +41,18 @@ func (s *testChunkSuite) TestRecordBatch(c *check.C) { batch.AppendInt64(0, 1) batch.AppendInt64(0, 1) c.Assert(batch.NumRows(), check.Equals, 4) - c.Assert(batch.IsFull(maxChunkSize), check.IsFalse) + c.Assert(batch.IsFull(), check.IsFalse) batch.AppendInt64(0, 1) c.Assert(batch.NumRows(), check.Equals, 5) - c.Assert(batch.IsFull(maxChunkSize), check.IsTrue) + c.Assert(batch.IsFull(), check.IsTrue) batch.AppendInt64(0, 1) batch.AppendInt64(0, 1) batch.AppendInt64(0, 1) c.Assert(batch.NumRows(), check.Equals, 8) - c.Assert(batch.IsFull(maxChunkSize), check.IsTrue) + c.Assert(batch.IsFull(), check.IsTrue) + + batch.SetRequiredRows(UnspecifiedNumRows) + c.Assert(batch.IsFull(), check.IsFalse) }