From 6c46ad59c063fa6283fb23046300404767a82248 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 1 Oct 2018 14:37:46 +0800 Subject: [PATCH] Address comments --- .../sql/execution/benchmark/AggregateBenchmark.scala | 8 ++++---- .../spark/sql/execution/benchmark/SqlBasedBenchmark.scala | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 3ca2e6255041d..296ae104a94a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -47,7 +47,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { override def benchmark(): Unit = { runBenchmark("aggregate without grouping") { val N = 500L << 22 - runBenchmarkWithCodegen("agg w/o group", N) { + codegenBenchmark("agg w/o group", N) { spark.range(N).selectExpr("sum(id)").collect() } } @@ -55,11 +55,11 @@ object AggregateBenchmark extends SqlBasedBenchmark { runBenchmark("stat functions") { val N = 100L << 20 - runBenchmarkWithCodegen("stddev", N) { + codegenBenchmark("stddev", N) { spark.range(N).groupBy().agg("id" -> "stddev").collect() } - runBenchmarkWithCodegen("kurtosis", N) { + codegenBenchmark("kurtosis", N) { spark.range(N).groupBy().agg("id" -> "kurtosis").collect() } } @@ -313,7 +313,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { runBenchmark("cube") { val N = 5 << 20 - runBenchmarkWithCodegen("cube", N) { + codegenBenchmark("cube", N) { spark.range(N).selectExpr("id", "id % 1000 as k1", "id & 256 as k2") .cube("k1", "k2").sum("id").collect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala index 430f58d346a0d..e95e5a960246b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf */ trait SqlBasedBenchmark extends BenchmarkBase with SQLHelper { - val spark: SparkSession = getSparkSession + protected val spark: SparkSession = getSparkSession /** Subclass can override this function to build their own SparkSession */ def getSparkSession: SparkSession = { @@ -40,7 +40,7 @@ trait SqlBasedBenchmark extends BenchmarkBase with SQLHelper { } /** Runs function `f` with whole stage codegen on and off. */ - def runBenchmarkWithCodegen(name: String, cardinality: Long)(f: => Unit): Unit = { + final def codegenBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { val benchmark = new Benchmark(name, cardinality, output = output) benchmark.addCase(s"$name wholestage off", numIters = 2) { _ =>