Skip to content

Commit

Permalink
planner: fix bug of mpp wrongly set schema of exchanger (#22845)
Browse files Browse the repository at this point in the history
* planner: fix bug of mpp wrongly set schema of exchanger

* support outer join in mock engine

* add unit test

Co-authored-by: Ti Chi Robot <[email protected]>
  • Loading branch information
hanfei1991 and ti-chi-bot authored Mar 5, 2021
1 parent a1712a0 commit 505955a
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 15 deletions.
3 changes: 3 additions & 0 deletions executor/tiflash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ func (s *tiflashTestSuite) TestMppExecution(c *C) {
tk.MustExec("set @@session.tidb_broadcast_join_threshold_size=1")
tk.MustQuery("select count(*) from t1 , t where t1.a = t.a").Check(testkit.Rows("3"))
tk.MustQuery("select count(*) from t1 , t, t2 where t1.a = t.a and t2.a = t.a").Check(testkit.Rows("3"))

tk.MustExec("insert into t1 values(4,0)")
tk.MustQuery("select count(*), t2.b from t1 left join t2 on t1.a = t2.a group by t2.b").Check(testkit.Rows("3 0", "1 <nil>"))
}

func (s *tiflashTestSuite) TestPartitionTable(c *C) {
Expand Down
13 changes: 10 additions & 3 deletions planner/core/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,22 @@ func buildLogicalJoinSchema(joinType JoinType, join LogicalPlan) *expression.Sch

// BuildPhysicalJoinSchema builds the schema of PhysicalJoin from its children's schema.
func BuildPhysicalJoinSchema(joinType JoinType, join PhysicalPlan) *expression.Schema {
leftSchema := join.Children()[0].Schema()
switch joinType {
case SemiJoin, AntiSemiJoin:
return join.Children()[0].Schema().Clone()
return leftSchema.Clone()
case LeftOuterSemiJoin, AntiLeftOuterSemiJoin:
newSchema := join.Children()[0].Schema().Clone()
newSchema := leftSchema.Clone()
newSchema.Append(join.Schema().Columns[join.Schema().Len()-1])
return newSchema
}
return expression.MergeSchema(join.Children()[0].Schema(), join.Children()[1].Schema())
newSchema := expression.MergeSchema(leftSchema, join.Children()[1].Schema())
if joinType == LeftOuterJoin {
resetNotNullFlag(newSchema, leftSchema.Len(), newSchema.Len())
} else if joinType == RightOuterJoin {
resetNotNullFlag(newSchema, 0, leftSchema.Len())
}
return newSchema
}

// GetStatsInfo gets the statistics info from a physical plan tree.
Expand Down
22 changes: 22 additions & 0 deletions store/mockstore/unistore/cophandler/mpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/coprocessor"
"github.com/pingcap/kvproto/pkg/mpp"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/store/mockstore/unistore/client"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tipb/go-tipb"
"github.com/uber-go/atomic"
Expand Down Expand Up @@ -152,6 +154,25 @@ func (b *mppExecBuilder) buildMPPJoin(pb *tipb.Join, children []*tipb.Executor)
if err != nil {
return nil, errors.Trace(err)
}
if pb.JoinType == tipb.JoinType_TypeLeftOuterJoin {
for _, tp := range rightCh.getFieldTypes() {
tp.Flag &= ^mysql.NotNullFlag
}
defaultInner := chunk.MutRowFromTypes(rightCh.getFieldTypes())
for i := range rightCh.getFieldTypes() {
defaultInner.SetDatum(i, types.NewDatum(nil))
}
e.defaultInner = defaultInner.ToRow()
} else if pb.JoinType == tipb.JoinType_TypeRightOuterJoin {
for _, tp := range leftCh.getFieldTypes() {
tp.Flag &= ^mysql.NotNullFlag
}
defaultInner := chunk.MutRowFromTypes(leftCh.getFieldTypes())
for i := range leftCh.getFieldTypes() {
defaultInner.SetDatum(i, types.NewDatum(nil))
}
e.defaultInner = defaultInner.ToRow()
}
// Because the field type is immutable, this kind of appending is safe.
e.fieldTypes = append(leftCh.getFieldTypes(), rightCh.getFieldTypes()...)
if pb.InnerIdx == 1 {
Expand Down Expand Up @@ -253,6 +274,7 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) {
for _, gby := range agg.GroupBy {
ft := expression.PbTypeToFieldType(gby.FieldType)
e.fieldTypes = append(e.fieldTypes, ft)
e.groupByTypes = append(e.groupByTypes, ft)
gbyExpr, err := expression.PBToExpr(gby, chExec.getFieldTypes(), b.sc)
if err != nil {
return nil, errors.Trace(err)
Expand Down
59 changes: 47 additions & 12 deletions store/mockstore/unistore/cophandler/mpp_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/mpp"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/kv"
Expand Down Expand Up @@ -164,8 +165,12 @@ func (e *exchSenderExec) next() (*chunk.Chunk, error) {
for i := 0; i < rows; i++ {
row := chk.GetRow(i)
d := row.GetDatum(e.hashKeyOffset, e.fieldTypes[e.hashKeyOffset])
hashKey := int(d.GetInt64() % int64(len(e.tunnels)))
targetChunks[hashKey].AppendRow(row)
if d.IsNull() {
targetChunks[0].AppendRow(row)
} else {
hashKey := int(d.GetInt64() % int64(len(e.tunnels)))
targetChunks[hashKey].AppendRow(row)
}
}
for i, tunnel := range e.tunnels {
if targetChunks[i].NumRows() > 0 {
Expand Down Expand Up @@ -337,6 +342,8 @@ type joinExec struct {

idx int
reservedRows []chunk.Row

defaultInner chunk.Row
inited bool
}

Expand Down Expand Up @@ -397,6 +404,16 @@ func (e *joinExec) fetchRows() (bool, error) {
}
e.reservedRows = append(e.reservedRows, newRow.ToRow())
}
} else if e.Join.JoinType == tipb.JoinType_TypeLeftOuterJoin {
newRow := chunk.MutRowFromTypes(e.fieldTypes)
newRow.ShallowCopyPartialRow(0, row)
newRow.ShallowCopyPartialRow(row.Len(), e.defaultInner)
e.reservedRows = append(e.reservedRows, newRow.ToRow())
} else if e.Join.JoinType == tipb.JoinType_TypeRightOuterJoin {
newRow := chunk.MutRowFromTypes(e.fieldTypes)
newRow.ShallowCopyPartialRow(0, e.defaultInner)
newRow.ShallowCopyPartialRow(e.defaultInner.Len(), row)
e.reservedRows = append(e.reservedRows, newRow.ToRow())
}
}
return false, nil
Expand Down Expand Up @@ -446,31 +463,36 @@ type aggExec struct {
groupKeys [][]byte
aggCtxsMap map[string][]*aggregation.AggEvaluateContext

groupByRows []chunk.Row
groupByTypes []*types.FieldType

processed bool
}

// open opens the aggregation executor by delegating to its first child
// (e.children[0]) and returns whatever error that child's open reports.
func (e *aggExec) open() error {
return e.children[0].open()
}

func (e *aggExec) getGroupKey(row chunk.Row) ([]byte, error) {
func (e *aggExec) getGroupKey(row chunk.Row) (*chunk.MutRow, []byte, error) {
length := len(e.groupByExprs)
if length == 0 {
return nil, nil
return nil, nil, nil
}
key := make([]byte, 0, 32)
for _, item := range e.groupByExprs {
gbyRow := chunk.MutRowFromTypes(e.groupByTypes)
for i, item := range e.groupByExprs {
v, err := item.Eval(row)
if err != nil {
return nil, errors.Trace(err)
return nil, nil, errors.Trace(err)
}
gbyRow.SetDatum(i, v)
b, err := codec.EncodeValue(e.sc, nil, v)
if err != nil {
return nil, errors.Trace(err)
return nil, nil, errors.Trace(err)
}
key = append(key, b...)
}
return key, nil
return &gbyRow, key, nil
}

func (e *aggExec) getContexts(groupKey []byte) []*aggregation.AggEvaluateContext {
Expand All @@ -497,13 +519,16 @@ func (e *aggExec) processAllRows() (*chunk.Chunk, error) {
rows := chk.NumRows()
for i := 0; i < rows; i++ {
row := chk.GetRow(i)
gk, err := e.getGroupKey(row)
gbyRow, gk, err := e.getGroupKey(row)
if err != nil {
return nil, errors.Trace(err)
}
if _, ok := e.groups[string(gk)]; !ok {
e.groups[string(gk)] = struct{}{}
e.groupKeys = append(e.groupKeys, gk)
if gbyRow != nil {
e.groupByRows = append(e.groupByRows, gbyRow.ToRow())
}
}

aggCtxs := e.getContexts(gk)
Expand All @@ -518,12 +543,22 @@ func (e *aggExec) processAllRows() (*chunk.Chunk, error) {

chk := chunk.NewChunkWithCapacity(e.fieldTypes, 0)

for _, gk := range e.groupKeys {
for i, gk := range e.groupKeys {
newRow := chunk.MutRowFromTypes(e.fieldTypes)
aggCtxs := e.getContexts(gk)
for i, agg := range e.aggExprs {
partialResults := agg.GetPartialResult(aggCtxs[i])
newRow.SetDatum(i, partialResults[0])
result := agg.GetResult(aggCtxs[i])
if e.fieldTypes[i].Tp == mysql.TypeLonglong && result.Kind() == types.KindMysqlDecimal {
var err error
result, err = result.ConvertTo(e.sc, e.fieldTypes[i])
if err != nil {
return nil, errors.Trace(err)
}
}
newRow.SetDatum(i, result)
}
if len(e.groupByRows) > 0 {
newRow.ShallowCopyPartialRow(len(e.aggExprs), e.groupByRows[i])
}
chk.AppendRow(newRow.ToRow())
}
Expand Down

0 comments on commit 505955a

Please sign in to comment.