Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plan: use single method to set parent and children. #4738

Merged
merged 2 commits into from
Oct 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions plan/aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ func (a *aggregationOptimizer) tryToPushDownAgg(aggFuncs []aggregation.Aggregati
}
}
agg := a.makeNewAgg(aggFuncs, gbyCols)
child.SetParents(agg)
agg.SetChildren(child)
setParentAndChildren(agg, child)
// If agg has no group-by item, it will return a default value, which may cause some bugs.
// So here we add a group-by item forcely.
if len(agg.GroupByItems) == 0 {
Expand Down Expand Up @@ -307,12 +306,11 @@ func (a *aggregationOptimizer) pushAggCrossUnion(agg *LogicalAggregation, unionS
for _, key := range unionChild.Schema().Keys {
if tmpSchema.ColumnsIndices(key) != nil {
proj := a.convertAggToProj(newAgg, a.ctx, a.allocator)
proj.SetChildren(unionChild)
setParentAndChildren(proj, unionChild)
return proj
}
}
newAgg.SetChildren(unionChild)
unionChild.SetParents(newAgg)
setParentAndChildren(newAgg, unionChild)
return newAgg
}

Expand Down Expand Up @@ -351,9 +349,7 @@ func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan {
} else {
lChild = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0)
}
join.SetChildren(lChild, rChild)
lChild.SetParents(join)
rChild.SetParents(join)
setParentAndChildren(join, lChild, rChild)
join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema()))
join.buildKeyInfo()
proj := a.tryToEliminateAggregation(agg)
Expand All @@ -376,8 +372,7 @@ func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan {
aggFunc.SetArgs(newArgs)
}
projChild := proj.children[0]
agg.SetChildren(projChild)
projChild.SetParents(agg)
setParentAndChildren(agg, projChild)
} else if union, ok1 := child.(*Union); ok1 {
var gbyCols []*expression.Column
for _, gbyExpr := range agg.GroupByItems {
Expand All @@ -388,20 +383,18 @@ func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan {
for _, child := range union.children {
newChild := a.pushAggCrossUnion(pushedAgg, union.schema, child.(LogicalPlan))
newChildren = append(newChildren, newChild)
newChild.SetParents(union)
}
union.SetChildren(newChildren...)
setParentAndChildren(union, newChildren...)
union.SetSchema(pushedAgg.schema)
}
}
}
newChildren := make([]Plan, 0, len(p.Children()))
for _, child := range p.Children() {
newChild := a.aggPushDown(child.(LogicalPlan))
newChild.SetParents(p)
newChildren = append(newChildren, newChild)
}
p.SetChildren(newChildren...)
setParentAndChildren(p, newChildren...)
return p
}

Expand All @@ -421,8 +414,7 @@ func (a *aggregationOptimizer) tryToEliminateAggregation(agg *LogicalAggregation
if coveredByUniqueKey {
// GroupByCols has unique key, so this aggregation can be removed.
proj := a.convertAggToProj(agg, a.ctx, a.allocator)
proj.SetChildren(agg.children[0])
agg.children[0].SetParents(proj)
setParentAndChildren(proj, agg.children[0])
return proj
}
return nil
Expand Down
21 changes: 7 additions & 14 deletions plan/decorrelate.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,12 @@ func (s *decorrelateSolver) optimize(p LogicalPlan, _ context.Context, _ *idAllo
}
apply.attachOnConds(newConds)
innerPlan = sel.children[0].(LogicalPlan)
apply.SetChildren(outerPlan, innerPlan)
innerPlan.SetParents(apply)
setParentAndChildren(apply, outerPlan, innerPlan)
return s.optimize(p, nil, nil)
} else if m, ok := innerPlan.(*MaxOneRow); ok {
if m.children[0].Schema().MaxOneRow {
innerPlan = m.children[0].(LogicalPlan)
innerPlan.SetParents(apply)
apply.SetChildren(outerPlan, innerPlan)
setParentAndChildren(apply, outerPlan, innerPlan)
return s.optimize(p, nil, nil)
}
} else if proj, ok := innerPlan.(*Projection); ok {
Expand All @@ -123,8 +121,7 @@ func (s *decorrelateSolver) optimize(p LogicalPlan, _ context.Context, _ *idAllo
}
apply.columnSubstitute(proj.Schema(), proj.Exprs)
innerPlan = proj.children[0].(LogicalPlan)
apply.SetChildren(outerPlan, innerPlan)
innerPlan.SetParents(apply)
setParentAndChildren(apply, outerPlan, innerPlan)
if apply.JoinType != SemiJoin && apply.JoinType != LeftOuterSemiJoin {
proj.SetSchema(apply.Schema())
proj.Exprs = append(expression.Column2Exprs(outerPlan.Schema().Clone().Columns), proj.Exprs...)
Expand All @@ -134,17 +131,15 @@ func (s *decorrelateSolver) optimize(p LogicalPlan, _ context.Context, _ *idAllo
if err != nil {
return nil, errors.Trace(err)
}
proj.SetChildren(np)
np.SetParents(proj)
setParentAndChildren(proj, np)
return proj, nil
}
return s.optimize(p, nil, nil)
} else if agg, ok := innerPlan.(*LogicalAggregation); ok {
if apply.canPullUpAgg() && agg.canPullUp() {
innerPlan = agg.children[0].(LogicalPlan)
apply.JoinType = LeftOuterJoin
apply.SetChildren(outerPlan, innerPlan)
innerPlan.SetParents(apply)
setParentAndChildren(apply, outerPlan, innerPlan)
agg.SetSchema(apply.Schema())
agg.GroupByItems = expression.Column2Exprs(outerPlan.Schema().Keys[0])
newAggFuncs := make([]aggregation.Aggregation, 0, apply.Schema().Len())
Expand All @@ -160,8 +155,7 @@ func (s *decorrelateSolver) optimize(p LogicalPlan, _ context.Context, _ *idAllo
if err != nil {
return nil, errors.Trace(err)
}
agg.SetChildren(np)
np.SetParents(agg)
setParentAndChildren(agg, np)
agg.collectGroupByColumns()
return agg, nil
}
Expand All @@ -174,8 +168,7 @@ func (s *decorrelateSolver) optimize(p LogicalPlan, _ context.Context, _ *idAllo
return nil, errors.Trace(err)
}
newChildren = append(newChildren, np)
np.SetParents(p)
}
p.SetChildren(newChildren...)
setParentAndChildren(p, newChildren...)
return p, nil
}
5 changes: 2 additions & 3 deletions plan/eliminate_projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ func doPhysicalProjectionElimination(p PhysicalPlan) PhysicalPlan {
for _, child := range p.Children() {
newChild := doPhysicalProjectionElimination(child.(PhysicalPlan))
children = append(children, newChild)
newChild.SetParents(p)
}
p.SetChildren(children...)
setParentAndChildren(p, children...)

proj, isProj := p.(*Projection)
if !isProj || !canProjectionBeEliminatedStrict(proj) {
Expand Down Expand Up @@ -128,7 +127,7 @@ func (pe *projectionEliminater) eliminate(p LogicalPlan, replace map[string]*exp
for _, child := range p.Children() {
children = append(children, pe.eliminate(child.(LogicalPlan), replace, childFlag))
}
p.SetChildren(children...)
setParentAndChildren(p, children...)

switch p.(type) {
case *Sort, *TopN, *Limit, *Selection, *MaxOneRow, *Update, *SelectLock:
Expand Down
8 changes: 4 additions & 4 deletions plan/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression.
agg := LogicalAggregation{
AggFuncs: []aggregation.Aggregation{aggFunc},
}.init(er.b.allocator, er.ctx)
addChild(agg, np)
setParentAndChildren(agg, np)
aggCol0 := &expression.Column{
ColName: model.NewCIStr("agg_Col_0"),
FromID: agg.id,
Expand Down Expand Up @@ -431,7 +431,7 @@ func (er *expressionRewriter) buildQuantifierPlan(agg *LogicalAggregation, cond,
IsAggOrSubq: true,
RetType: cond.GetType(),
})
addChild(proj, er.p)
setParentAndChildren(proj, er.p)
er.p = proj
}

Expand All @@ -444,7 +444,7 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np
agg := LogicalAggregation{
AggFuncs: []aggregation.Aggregation{firstRowFunc, countFunc},
}.init(er.b.allocator, er.ctx)
addChild(agg, np)
setParentAndChildren(agg, np)
firstRowResultCol := &expression.Column{
ColName: model.NewCIStr("col_firstRow"),
FromID: agg.id,
Expand Down Expand Up @@ -472,7 +472,7 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np
agg := LogicalAggregation{
AggFuncs: []aggregation.Aggregation{firstRowFunc, countFunc},
}.init(er.b.allocator, er.ctx)
addChild(agg, np)
setParentAndChildren(agg, np)
firstRowResultCol := &expression.Column{
ColName: model.NewCIStr("col_firstRow"),
FromID: agg.id,
Expand Down
4 changes: 1 addition & 3 deletions plan/join_reorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,8 @@ func (e *joinReOrderSolver) newJoin(lChild, rChild LogicalPlan) *LogicalJoin {
JoinType: InnerJoin,
reordered: true,
}.init(e.allocator, e.ctx)
join.SetChildren(lChild, rChild)
join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema()))
lChild.SetParents(join)
rChild.SetParents(join)
setParentAndChildren(join, lChild, rChild)
return join
}

Expand Down
38 changes: 16 additions & 22 deletions plan/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (b *planBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.Aggrega
agg.AggFuncs = append(agg.AggFuncs, newFunc)
schema.Append(col.Clone().(*expression.Column))
}
addChild(agg, p)
setParentAndChildren(agg, p)
agg.GroupByItems = gbyItems
agg.SetSchema(schema)
agg.collectGroupByColumns()
Expand Down Expand Up @@ -230,8 +230,7 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan {

newSchema := expression.MergeSchema(leftPlan.Schema(), rightPlan.Schema())
joinPlan := LogicalJoin{}.init(b.allocator, b.ctx)
addChild(joinPlan, leftPlan)
addChild(joinPlan, rightPlan)
setParentAndChildren(joinPlan, leftPlan, rightPlan)
joinPlan.SetSchema(newSchema)

// Merge sub join's redundantSchema into this join plan. When handle query like
Expand Down Expand Up @@ -409,7 +408,7 @@ func (b *planBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMappe
}
selection.Conditions = expressions
selection.SetSchema(p.Schema().Clone())
addChild(selection, p)
setParentAndChildren(selection, p)
return selection
}

Expand Down Expand Up @@ -507,7 +506,7 @@ func (b *planBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,
}
}
proj.SetSchema(schema)
addChild(proj, p)
setParentAndChildren(proj, p)
return proj, oldLen
}

Expand All @@ -522,7 +521,7 @@ func (b *planBuilder) buildDistinct(child LogicalPlan, length int) LogicalPlan {
for _, col := range child.Schema().Columns {
agg.AggFuncs = append(agg.AggFuncs, aggregation.NewAggFunction(ast.AggFuncFirstRow, []expression.Expression{col}, false))
}
addChild(agg, child)
setParentAndChildren(agg, child)
agg.SetSchema(child.Schema().Clone())
return agg
}
Expand Down Expand Up @@ -562,8 +561,7 @@ func (b *planBuilder) buildUnion(union *ast.UnionStmt) LogicalPlan {
col.FromID = proj.ID()
}
proj.SetSchema(schema)
sel.SetParents(proj)
proj.SetChildren(sel)
setParentAndChildren(proj, sel)
sel = proj
u.children[i] = proj
}
Expand Down Expand Up @@ -635,7 +633,7 @@ func (b *planBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper
exprs = append(exprs, &ByItems{Expr: it, Desc: item.Desc})
}
sort.ByItems = exprs
addChild(sort, p)
setParentAndChildren(sort, p)
sort.SetSchema(p.Schema().Clone())
return sort
}
Expand Down Expand Up @@ -686,7 +684,7 @@ func (b *planBuilder) buildLimit(src LogicalPlan, limit *ast.Limit) LogicalPlan
Offset: offset,
Count: count,
}.init(b.allocator, b.ctx)
addChild(li, src)
setParentAndChildren(li, src)
li.SetSchema(src.Schema().Clone())
return li
}
Expand Down Expand Up @@ -1167,7 +1165,7 @@ func (b *planBuilder) buildSelect(sel *ast.SelectStmt) LogicalPlan {
}
if oldLen != p.Schema().Len() {
proj := Projection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldLen])}.init(b.allocator, b.ctx)
addChild(proj, p)
setParentAndChildren(proj, p)
schema := expression.NewSchema(p.Schema().Clone().Columns[:oldLen]...)
for _, col := range schema.Columns {
col.FromID = proj.ID()
Expand Down Expand Up @@ -1323,8 +1321,7 @@ func (b *planBuilder) projectVirtualColumns(ds *DataSource, columns []*table.Col
proj.Exprs = append(proj.Exprs, expr)
}
proj.SetSchema(ds.Schema().Clone())
proj.SetChildren(ds)
ds.SetParents(proj)
setParentAndChildren(proj, ds)
return proj
}

Expand All @@ -1338,8 +1335,7 @@ func (b *planBuilder) buildApplyWithJoinType(outerPlan, innerPlan LogicalPlan, t
if tp == LeftOuterJoin {
ap.DefaultValues = make([]types.Datum, innerPlan.Schema().Len())
}
addChild(ap, outerPlan)
addChild(ap, innerPlan)
setParentAndChildren(ap, outerPlan, innerPlan)
ap.SetSchema(expression.MergeSchema(outerPlan.Schema(), innerPlan.Schema()))
for i := outerPlan.Schema().Len(); i < ap.Schema().Len(); i++ {
ap.schema.Columns[i].IsAggOrSubq = true
Expand Down Expand Up @@ -1383,7 +1379,7 @@ out:
}
}
exists := Exists{}.init(b.allocator, b.ctx)
addChild(exists, p)
setParentAndChildren(exists, p)
newCol := &expression.Column{
FromID: exists.id,
RetType: types.NewFieldType(mysql.TypeTiny),
Expand All @@ -1394,7 +1390,7 @@ out:

func (b *planBuilder) buildMaxOneRow(p LogicalPlan) LogicalPlan {
maxOneRow := MaxOneRow{}.init(b.allocator, b.ctx)
addChild(maxOneRow, p)
setParentAndChildren(maxOneRow, p)
maxOneRow.SetSchema(p.Schema().Clone())
return maxOneRow
}
Expand All @@ -1404,9 +1400,7 @@ func (b *planBuilder) buildSemiJoin(outerPlan, innerPlan LogicalPlan, onConditio
for i, expr := range onCondition {
onCondition[i] = expr.Decorrelate(outerPlan.Schema())
}
joinPlan.SetChildren(outerPlan, innerPlan)
outerPlan.SetParents(joinPlan)
innerPlan.SetParents(joinPlan)
setParentAndChildren(joinPlan, outerPlan, innerPlan)
joinPlan.attachOnConds(onCondition)
if asScalar {
newSchema := outerPlan.Schema().Clone()
Expand Down Expand Up @@ -1473,7 +1467,7 @@ func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) LogicalPlan {
OrderedList: orderedList,
IgnoreErr: update.IgnoreErr,
}.init(b.allocator, b.ctx)
addChild(updt, p)
setParentAndChildren(updt, p)
updt.SetSchema(p.Schema())
return updt
}
Expand Down Expand Up @@ -1624,7 +1618,7 @@ func (b *planBuilder) buildDelete(delete *ast.DeleteStmt) LogicalPlan {
Tables: tables,
IsMultiTable: delete.IsMultiTable,
}.init(b.allocator, b.ctx)
addChild(del, p)
setParentAndChildren(del, p)
del.SetSchema(expression.NewSchema())

// Collect visitInfo.
Expand Down
12 changes: 7 additions & 5 deletions plan/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,13 +401,15 @@ type Delete struct {
IsMultiTable bool
}

// AddChild for parent.
func addChild(parent Plan, child Plan) {
if child == nil || parent == nil {
// setParentAndChildren sets parent and children relationship.
func setParentAndChildren(parent Plan, children ...Plan) {
if children == nil || parent == nil {
return
}
child.AddParent(parent)
parent.AddChild(child)
for _, child := range children {
child.SetParents(parent)
}
parent.SetChildren(children...)
}

// InsertPlan means inserting plan between two plans.
Expand Down
Loading