diff --git a/pkg/sql/distsqlrun/hashjoiner_test.go b/pkg/sql/distsqlrun/hashjoiner_test.go index dc071578738d..522ec2e61580 100644 --- a/pkg/sql/distsqlrun/hashjoiner_test.go +++ b/pkg/sql/distsqlrun/hashjoiner_test.go @@ -860,33 +860,31 @@ func BenchmarkHashJoiner(b *testing.B) { ctx := context.Background() evalCtx := tree.MakeTestingEvalContext() defer evalCtx.Stop(ctx) - flowCtx := FlowCtx{ + flowCtx := &FlowCtx{ Settings: cluster.MakeTestingClusterSettings(), EvalCtx: evalCtx, } - spec := HashJoinerSpec{ + spec := &HashJoinerSpec{ LeftEqColumns: []uint32{0}, RightEqColumns: []uint32{0}, Type: JoinType_INNER, + // Implicit @1 = @2 constraint. } - post := PostProcessSpec{Projection: true, OutputColumns: []uint32{0}} + post := &PostProcessSpec{} - const numCols = 4 - for _, inputSize := range []int{0, 1 << 2, 1 << 4, 1 << 8, 1 << 12, 1 << 16} { - b.Run(fmt.Sprintf("InputSize=%d", inputSize), func(b *testing.B) { - types := make([]sqlbase.ColumnType, numCols) - for i := 0; i < numCols; i++ { - types[i] = intType - } - rows := makeIntRows(inputSize, numCols) - leftInput := NewRepeatableRowSource(types, rows) - rightInput := NewRepeatableRowSource(types, rows) + const numCols = 1 + for _, numRows := range []int{0, 1 << 2, 1 << 4, 1 << 8, 1 << 12, 1 << 16} { + b.Run(fmt.Sprintf("rows=%d", numRows), func(b *testing.B) { + rows := makeIntRows(numRows, numCols) + leftInput := NewRepeatableRowSource(oneIntCol, rows) + rightInput := NewRepeatableRowSource(oneIntCol, rows) + b.SetBytes(int64(8 * numRows * numCols)) b.ResetTimer() for i := 0; i < b.N; i++ { // TODO(asubiotto): Get rid of uncleared state between // hashJoiner Run()s to omit instantiation time from benchmarks. - h, err := newHashJoiner(&flowCtx, &spec, leftInput, rightInput, &post, &RowDisposer{}) + h, err := newHashJoiner(flowCtx, spec, leftInput, rightInput, post, &RowDisposer{}) if err != nil { b.Fatal(err) } diff --git a/pkg/sql/distsqlrun/mergejoiner_test.go b/pkg/sql/distsqlrun/mergejoiner_test.go index 52af7a625ad1..48cc87d7c088 100644 --- a/pkg/sql/distsqlrun/mergejoiner_test.go +++ b/pkg/sql/distsqlrun/mergejoiner_test.go @@ -15,6 +15,7 @@ package distsqlrun import ( + "fmt" "testing" "golang.org/x/net/context" @@ -553,3 +554,48 @@ func TestConsumerClosed(t *testing.T) { }) } } + +func BenchmarkMergeJoiner(b *testing.B) { + ctx := context.Background() + evalCtx := tree.MakeTestingEvalContext() + defer evalCtx.Stop(ctx) + flowCtx := &FlowCtx{ + Settings: cluster.MakeTestingClusterSettings(), + EvalCtx: evalCtx, + } + + spec := &MergeJoinerSpec{ + LeftOrdering: convertToSpecOrdering( + sqlbase.ColumnOrdering{ + {ColIdx: 0, Direction: encoding.Ascending}, + }), + RightOrdering: convertToSpecOrdering( + sqlbase.ColumnOrdering{ + {ColIdx: 0, Direction: encoding.Ascending}, + }), + Type: JoinType_INNER, + // Implicit @1 = @2 constraint. + } + post := &PostProcessSpec{} + disposer := &RowDisposer{} + + const numCols = 1 + for _, inputSize := range []int{0, 1 << 2, 1 << 4, 1 << 8, 1 << 12, 1 << 16} { + b.Run(fmt.Sprintf("InputSize=%d", inputSize), func(b *testing.B) { + rows := makeIntRows(inputSize, numCols) + leftInput := NewRepeatableRowSource(oneIntCol, rows) + rightInput := NewRepeatableRowSource(oneIntCol, rows) + b.SetBytes(int64(8 * inputSize * numCols)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + m, err := newMergeJoiner(flowCtx, spec, leftInput, rightInput, post, disposer) + if err != nil { + b.Fatal(err) + } + m.Run(ctx, nil) + leftInput.Reset() + rightInput.Reset() + } + }) + } +}