From a56120147b41151308822a77ca61dda3f72038cf Mon Sep 17 00:00:00 2001 From: Marcus Gartner Date: Mon, 28 Aug 2023 16:04:11 -0400 Subject: [PATCH] opt: fix plan gist decoding internal error This commit fixes some cases where `crdb_internal.decode_plan_gist` could raise internal index-out-of-bound errors when given incorrectly formed input. Fixes #109560 Release note: None --- .../exec/execbuilder/testdata/explain_gist | 7 ++++ pkg/sql/opt/exec/explain/explain_factory.go | 7 ++++ pkg/sql/opt/exec/explain/plan_gist_factory.go | 3 ++ pkg/sql/opt/exec/explain/result_columns.go | 36 +++++++++++++++++++ 4 files changed, 53 insertions(+) diff --git a/pkg/sql/opt/exec/execbuilder/testdata/explain_gist b/pkg/sql/opt/exec/execbuilder/testdata/explain_gist index 8da59387f74b..51a3cbc0e222 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/explain_gist +++ b/pkg/sql/opt/exec/execbuilder/testdata/explain_gist @@ -189,3 +189,10 @@ SELECT crdb_internal.decode_plan_gist('AgGwAgQAgQIAAgAEBQITsAICAxgGDA==') └── • scan table: ?@? spans: 1+ spans + +# Regression test for #109560. Incorrectly formed plan gist should not cause +# internal error. +query T nosort +SELECT crdb_internal.decode_external_plan_gist('Ag8f') +---- +• union all diff --git a/pkg/sql/opt/exec/explain/explain_factory.go b/pkg/sql/opt/exec/explain/explain_factory.go index 95f152702bca..880ec9e6b0a3 100644 --- a/pkg/sql/opt/exec/explain/explain_factory.go +++ b/pkg/sql/opt/exec/explain/explain_factory.go @@ -90,6 +90,13 @@ func (n *Node) Annotate(id exec.ExplainAnnotationID, value interface{}) { func newNode( op execOperator, args interface{}, ordering exec.OutputOrdering, children ...*Node, ) (*Node, error) { + nonNilChildren := make([]*Node, 0, len(children)) + for i := range children { + if children[i] != nil { + nonNilChildren = append(nonNilChildren, children[i]) + } + } + children = nonNilChildren inputNodeCols := make([]colinfo.ResultColumns, len(children)) for i := range children { inputNodeCols[i] = children[i].Columns() diff --git a/pkg/sql/opt/exec/explain/plan_gist_factory.go b/pkg/sql/opt/exec/explain/plan_gist_factory.go index b213e33f50dd..f1fa44a96b4c 100644 --- a/pkg/sql/opt/exec/explain/plan_gist_factory.go +++ b/pkg/sql/opt/exec/explain/plan_gist_factory.go @@ -248,6 +248,9 @@ func (f *PlanGistFactory) decodeOp() execOperator { func (f *PlanGistFactory) popChild() *Node { l := len(f.nodeStack) + if l == 0 { + return nil + } n := f.nodeStack[l-1] f.nodeStack = f.nodeStack[:l-1] diff --git a/pkg/sql/opt/exec/explain/result_columns.go b/pkg/sql/opt/exec/explain/result_columns.go index 9ef00112bd62..93fde441a6ea 100644 --- a/pkg/sql/opt/exec/explain/result_columns.go +++ b/pkg/sql/opt/exec/explain/result_columns.go @@ -42,13 +42,22 @@ func getResultColumns( case filterOp, invertedFilterOp, limitOp, max1RowOp, sortOp, topKOp, bufferOp, hashSetOpOp, streamingSetOpOp, unionAllOp, distinctOp, saveTableOp, recursiveCTEOp: // These ops inherit the columns from their first input. + if len(inputs) == 0 { + return nil, nil + } return inputs[0], nil case simpleProjectOp: + if len(inputs) == 0 { + return nil, nil + } a := args.(*simpleProjectArgs) return projectCols(inputs[0], a.Cols, nil /* colNames */), nil case serializingProjectOp: + if len(inputs) == 0 { + return nil, nil + } a := args.(*serializingProjectArgs) return projectCols(inputs[0], a.Cols, a.ColNames), nil @@ -67,19 +76,34 @@ func getResultColumns( return args.(*renderArgs).Columns, nil case projectSetOp: + if len(inputs) == 0 { + return nil, nil + } return appendColumns(inputs[0], args.(*projectSetArgs).ZipCols...), nil case applyJoinOp: + if len(inputs) == 0 { + return nil, nil + } a := args.(*applyJoinArgs) return joinColumns(a.JoinType, inputs[0], a.RightColumns), nil case hashJoinOp: + if len(inputs) < 2 { + return nil, nil + } return joinColumns(args.(*hashJoinArgs).JoinType, inputs[0], inputs[1]), nil case mergeJoinOp: + if len(inputs) < 2 { + return nil, nil + } return joinColumns(args.(*mergeJoinArgs).JoinType, inputs[0], inputs[1]), nil case lookupJoinOp: + if len(inputs) == 0 { + return nil, nil + } a := args.(*lookupJoinArgs) cols := joinColumns(a.JoinType, inputs[0], tableColumns(a.Table, a.LookupCols)) // The following matches the behavior of execFactory.ConstructLookupJoin. @@ -89,16 +113,25 @@ func getResultColumns( return cols, nil case ordinalityOp: + if len(inputs) == 0 { + return nil, nil + } return appendColumns(inputs[0], colinfo.ResultColumn{ Name: args.(*ordinalityArgs).ColName, Typ: types.Int, }), nil case groupByOp: + if len(inputs) == 0 { + return nil, nil + } a := args.(*groupByArgs) return groupByColumns(inputs[0], a.GroupCols, a.Aggregations), nil case scalarGroupByOp: + if len(inputs) == 0 { + return nil, nil + } a := args.(*scalarGroupByArgs) return groupByColumns(inputs[0], nil /* groupCols */, a.Aggregations), nil @@ -106,6 +139,9 @@ func getResultColumns( return args.(*windowArgs).Window.Cols, nil case invertedJoinOp: + if len(inputs) == 0 { + return nil, nil + } a := args.(*invertedJoinArgs) cols := joinColumns(a.JoinType, inputs[0], tableColumns(a.Table, a.LookupCols)) // The following matches the behavior of execFactory.ConstructInvertedJoin.