diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index ee85f9eb59c0..2950ad247471 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -68,6 +68,7 @@ go_library( "distsql_plan_ctas.go", "distsql_plan_join.go", "distsql_plan_scrub_physical.go", + "distsql_plan_set_op.go", "distsql_plan_stats.go", "distsql_plan_window.go", "distsql_running.go", @@ -425,6 +426,7 @@ go_test( "descriptor_mutation_test.go", "distsql_physical_planner_test.go", "distsql_plan_backfill_test.go", + "distsql_plan_set_op_test.go", "distsql_running_test.go", "drop_helpers_test.go", "drop_test.go", diff --git a/pkg/sql/distsql_physical_planner.go b/pkg/sql/distsql_physical_planner.go index b24db1eaa35e..10cbad9bc13c 100644 --- a/pkg/sql/distsql_physical_planner.go +++ b/pkg/sql/distsql_physical_planner.go @@ -3268,7 +3268,7 @@ func (dsp *DistSQLPlanner) createPlanForSetOp( p.PlanToStreamColMap = planToStreamColMap // Merge the plans' result types and merge ordering. - resultTypes, err := physicalplan.MergeResultTypes(leftPlan.GetResultTypes(), rightPlan.GetResultTypes()) + resultTypes, err := mergeResultTypes(leftPlan.GetResultTypes(), rightPlan.GetResultTypes()) if err != nil { return nil, err } diff --git a/pkg/sql/distsql_plan_set_op.go b/pkg/sql/distsql_plan_set_op.go new file mode 100644 index 000000000000..eded5995e911 --- /dev/null +++ b/pkg/sql/distsql_plan_set_op.go @@ -0,0 +1,48 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package sql + +import ( + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/errors" +) + +// mergeResultTypes reconciles the ResultTypes between two plans. It enforces +// that each pair of ColumnTypes must either match or be null, in which case the +// non-null type is used. This logic is necessary for cases like +// SELECT NULL UNION SELECT 1. +func mergeResultTypes(left, right []*types.T) ([]*types.T, error) { + if len(left) != len(right) { + return nil, errors.Errorf("ResultTypes length mismatch: %d and %d", len(left), len(right)) + } + merged := make([]*types.T, len(left)) + for i := range left { + leftType, rightType := left[i], right[i] + if rightType.Family() == types.UnknownFamily { + merged[i] = leftType + } else if leftType.Family() == types.UnknownFamily { + merged[i] = rightType + } else if equivalentTypes(leftType, rightType) { + merged[i] = leftType + } else { + return nil, errors.Errorf( + "conflicting ColumnTypes: %s and %s", leftType.DebugString(), rightType.DebugString()) + } + } + return merged, nil +} + +// equivalentType checks whether a column type is equivalent to another for the +// purpose of UNION. Precision, Width, Oid, etc. do not affect the merging of +// values. +func equivalentTypes(c, other *types.T) bool { + return c.Equivalent(other) +} diff --git a/pkg/sql/distsql_plan_set_op_test.go b/pkg/sql/distsql_plan_set_op_test.go new file mode 100644 index 000000000000..9119c5c6dd48 --- /dev/null +++ b/pkg/sql/distsql_plan_set_op_test.go @@ -0,0 +1,62 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package sql + +import ( + "reflect" + "testing" + + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" +) + +func TestMergeResultTypes(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + empty := []*types.T{} + null := []*types.T{types.Unknown} + typeInt := []*types.T{types.Int} + + testData := []struct { + name string + left []*types.T + right []*types.T + expected *[]*types.T + err bool + }{ + {"both empty", empty, empty, &empty, false}, + {"left empty", empty, typeInt, nil, true}, + {"right empty", typeInt, empty, nil, true}, + {"both null", null, null, &null, false}, + {"left null", null, typeInt, &typeInt, false}, + {"right null", typeInt, null, &typeInt, false}, + {"both int", typeInt, typeInt, &typeInt, false}, + } + for _, td := range testData { + t.Run(td.name, func(t *testing.T) { + result, err := mergeResultTypes(td.left, td.right) + if td.err { + if err == nil { + t.Fatalf("expected error, got %+v", result) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if !reflect.DeepEqual(*td.expected, result) { + t.Fatalf("expected %+v, got %+v", *td.expected, result) + } + }) + } +} diff --git a/pkg/sql/physicalplan/physical_plan.go b/pkg/sql/physicalplan/physical_plan.go index 5c3f6eab4bdb..c3c993140d8a 100644 --- a/pkg/sql/physicalplan/physical_plan.go +++ b/pkg/sql/physicalplan/physical_plan.go @@ -964,38 +964,6 @@ func MergePlans( mergedPlan.Distribution = leftPlanDistribution.compose(rightPlanDistribution) } -// MergeResultTypes reconciles the ResultTypes between two plans. It enforces -// that each pair of ColumnTypes must either match or be null, in which case the -// non-null type is used. This logic is necessary for cases like -// SELECT NULL UNION SELECT 1. -func MergeResultTypes(left, right []*types.T) ([]*types.T, error) { - if len(left) != len(right) { - return nil, errors.Errorf("ResultTypes length mismatch: %d and %d", len(left), len(right)) - } - merged := make([]*types.T, len(left)) - for i := range left { - leftType, rightType := left[i], right[i] - if rightType.Family() == types.UnknownFamily { - merged[i] = leftType - } else if leftType.Family() == types.UnknownFamily { - merged[i] = rightType - } else if equivalentTypes(leftType, rightType) { - merged[i] = leftType - } else { - return nil, errors.Errorf( - "conflicting ColumnTypes: %s and %s", leftType.DebugString(), rightType.DebugString()) - } - } - return merged, nil -} - -// equivalentType checks whether a column type is equivalent to another for the -// purpose of UNION. Precision, Width, Oid, etc. do not affect the merging of -// values. -func equivalentTypes(c, other *types.T) bool { - return c.Equivalent(other) -} - // AddJoinStage adds join processors at each of the specified nodes, and wires // the left and right-side outputs to these processors. func (p *PhysicalPlan) AddJoinStage( diff --git a/pkg/sql/physicalplan/physical_plan_test.go b/pkg/sql/physicalplan/physical_plan_test.go index 873666418085..55f47ec90f8e 100644 --- a/pkg/sql/physicalplan/physical_plan_test.go +++ b/pkg/sql/physicalplan/physical_plan_test.go @@ -8,10 +8,6 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -// This file defines structures and basic functionality that is useful when -// building distsql plans. It does not contain the actual physical planning -// code. - package physicalplan import ( @@ -378,45 +374,3 @@ func TestProjectionAndRendering(t *testing.T) { } } } - -func TestMergeResultTypes(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - - empty := []*types.T{} - null := []*types.T{types.Unknown} - typeInt := []*types.T{types.Int} - - testData := []struct { - name string - left []*types.T - right []*types.T - expected *[]*types.T - err bool - }{ - {"both empty", empty, empty, &empty, false}, - {"left empty", empty, typeInt, nil, true}, - {"right empty", typeInt, empty, nil, true}, - {"both null", null, null, &null, false}, - {"left null", null, typeInt, &typeInt, false}, - {"right null", typeInt, null, &typeInt, false}, - {"both int", typeInt, typeInt, &typeInt, false}, - } - for _, td := range testData { - t.Run(td.name, func(t *testing.T) { - result, err := MergeResultTypes(td.left, td.right) - if td.err { - if err == nil { - t.Fatalf("expected error, got %+v", result) - } - return - } - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if !reflect.DeepEqual(*td.expected, result) { - t.Fatalf("expected %+v, got %+v", *td.expected, result) - } - }) - } -}