
[SPARK-16844][SQL] Support codegen for sort-based aggregate #17164

Closed
wants to merge 3 commits

Conversation

@maropu (Member) commented Mar 4, 2017

What changes were proposed in this pull request?

This PR supports codegen for SortAggregate.
It is a rework of #14481.

Closes #14481

How was this patch tested?

Checked the tests in DataFrameAggregateSuite.

@SparkQA commented Mar 4, 2017

Test build #73907 has finished for PR 17164 at commit f2ccc65.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • abstract class AggregateExec extends UnaryExecNode
  • trait CodegenAggregateSupport extends CodegenSupport

@maropu (Member, Author) commented Mar 4, 2017

Benchmark results:
https://github.com/apache/spark/pull/17164/files#diff-b7bf86a20a79d572f81093300568db6eR44

/*
range/limit/sum:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
----------------------------------------------------------------------------------------------
range/limit/sum wholestage off               617 /  617         13.6          73.5       1.0X
range/limit/sum wholestage on                 70 /   92        120.2           8.3       8.8X
*/

/*
aggregate non-sorted data:              Best/Avg Time(ms)   Rate(M/s)  Per Row(ns)   Relative
----------------------------------------------------------------------------------------------
non-sorted data wholestage off                2540 / 2735         3.3        302.8       1.0X
non-sorted data wholestage on                 1226 / 1528         6.8        146.1       2.1X
*/

/*
aggregate cached and sorted data:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
----------------------------------------------------------------------------------------------
cached and sorted data wholestage off       1455 / 1586          5.8         173.4       1.0X
cached and sorted data wholestage on         663 /  767         12.7          79.0       2.2X
*/
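
For reference, a minimal sketch of how numbers like these can be reproduced, assuming the org.apache.spark.util.Benchmark utility used by Spark's sql/core benchmarks at the time (the row count and key cardinality below are illustrative, not the PR's exact benchmark):

import org.apache.spark.util.Benchmark

val N = 8 << 20  // illustrative row count
val benchmark = new Benchmark("aggregate non-sorted data", N)
Seq(false, true).foreach { enabled =>
  benchmark.addCase(s"non-sorted data wholestage ${if (enabled) "on" else "off"}") { _ =>
    spark.conf.set("spark.sql.codegen.wholeStage", enabled.toString)
    // group-by sum over non-sorted input, as in the second table above
    spark.range(N).selectExpr("id % 1024 AS key", "rand() AS value")
      .groupBy("key").sum("value").collect()
  }
}
benchmark.run()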

@maropu changed the title from [SPARK-16844][SQL] Support codegen for sort-based aggregate to [SPARK-16844][SQL][WIP] Support codegen for sort-based aggregate on Mar 4, 2017
@maropu (Member, Author) commented Mar 4, 2017

import org.apache.spark.sql.execution.debug._
spark.conf.set("spark.sql.aggregate.preferSortAggregate", "true")
val df = spark.range(10).selectExpr("id % 2 AS key", "rand() AS value")
df.groupBy($"key").sum("value").debugCodegen

Found 2 WholeStageCodegen subtrees.
== Subtree 1 / 2 ==
*SortAggregate(key=[key#3L], functions=[sum(value#4)], output=[key#3L, sum(value)#12])
+- *Sort [key#3L ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(key#3L, 200)
      +- *SortAggregate(key=[key#3L], functions=[partial_sum(value#4)], output=[key#3L, sum#17])
         +- *Sort [key#3L ASC NULLS FIRST], false, 0
            +- *Project [(id#0L % 2) AS key#3L, rand(-2342342825239413884) AS value#4]
               +- *Range (0, 10, step=1, splits=Some(4))

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_numOutputRows;
/* 009 */   private boolean sort_needToSort;
/* 010 */   private org.apache.spark.sql.execution.SortExec sort_plan;
/* 011 */   private org.apache.spark.sql.execution.UnsafeExternalRowSorter sort_sorter;
/* 012 */   private org.apache.spark.executor.TaskMetrics sort_metrics;
/* 013 */   private scala.collection.Iterator<UnsafeRow> sort_sortedIter;
/* 014 */   private scala.collection.Iterator inputadapter_input;
/* 015 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_peakMemory;
/* 016 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_spillSize;
/* 017 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_sortTime;
/* 018 */   private UnsafeRow sagg_currentGroupingKey;
/* 019 */   private boolean sagg_bufIsNull;
/* 020 */   private double sagg_bufValue;
/* 021 */   private UnsafeRow sagg_result;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter;
/* 024 */   private UnsafeRow sagg_result1;
/* 025 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder1;
/* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter1;
/* 027 */   private org.apache.spark.sql.execution.aggregate.SortAggregateExec sagg_sortAggregate;
/* 028 */   private UnsafeRow sagg_result2;
/* 029 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder2;
/* 030 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter2;
/* 031 */   private org.apache.spark.sql.execution.metric.SQLMetric wholestagecodegen_numOutputRows;
/* 032 */
/* 033 */   public GeneratedIterator(Object[] references) {
/* 034 */     this.references = references;
/* 035 */   }
/* 036 */
/* 037 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 038 */     partitionIndex = index;
/* 039 */     this.inputs = inputs;
/* 040 */     wholestagecodegen_init_0();
/* 041 */     wholestagecodegen_init_1();
/* 042 */ 
/* 043 */   }
/* 044 */ 
/* 045 */   private void wholestagecodegen_init_0() {
/* 046 */     this.sagg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 047 */     sort_needToSort = true;
/* 048 */     this.sort_plan = (org.apache.spark.sql.execution.SortExec) references[1];
/* 049 */     sort_sorter = sort_plan.createSorter();
/* 050 */     sort_metrics = org.apache.spark.TaskContext.get().taskMetrics();
/* 051 */ 
/* 052 */     inputadapter_input = inputs[0];
/* 053 */     this.sort_peakMemory = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 054 */     this.sort_spillSize = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 055 */     this.sort_sortTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[4];
/* 056 */     sagg_currentGroupingKey = null;
/* 057 */ 
/* 058 */     sagg_result = new UnsafeRow(1);
/* 059 */     this.sagg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result, 0);
/* 060 */     this.sagg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder, 1);
/* 061 */     sagg_result1 = new UnsafeRow(1);
/* 062 */     this.sagg_holder1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result1, 0);
/* 063 */     this.sagg_rowWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder1, 1);
/* 064 */ 
/* 065 */   }
/* 066 */ 
/* 067 */   private void wholestagecodegen_init_1() {
/* 068 */     this.sagg_sortAggregate = (org.apache.spark.sql.execution.aggregate.SortAggregateExec) references[5];
/* 069 */     sagg_result2 = new UnsafeRow(2);
/* 070 */     this.sagg_holder2 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result2, 0);
/* 071 */     this.sagg_rowWriter2 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder2, 2);
/* 072 */     this.wholestagecodegen_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[6];
/* 073 */ 
/* 074 */   }
/* 075 */ 
/* 076 */   private void sort_addToSorter() throws java.io.IOException {
/* 077 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 078 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 079 */       sort_sorter.insertRow((UnsafeRow)inputadapter_row);
/* 080 */       if (shouldStop()) return;
/* 081 */     }
/* 082 */ 
/* 083 */   }
/* 084 */ 
/* 085 */   protected void processNext() throws java.io.IOException {
/* 086 */     if (sort_needToSort) {
/* 087 */       long sort_spillSizeBefore = sort_metrics.memoryBytesSpilled();
/* 088 */       sort_addToSorter();
/* 089 */       sort_sortedIter = sort_sorter.sort();
/* 090 */       sort_sortTime.add(sort_sorter.getSortTimeNanos() / 1000000);
/* 091 */       sort_peakMemory.add(sort_sorter.getPeakMemoryUsage());
/* 092 */       sort_spillSize.add(sort_metrics.memoryBytesSpilled() - sort_spillSizeBefore);
/* 093 */       sort_metrics.incPeakExecutionMemory(sort_sorter.getPeakMemoryUsage());
/* 094 */       sort_needToSort = false;
/* 095 */     }
/* 096 */ 
/* 097 */     while (sort_sortedIter.hasNext()) {
/* 098 */       UnsafeRow sort_outputRow = (UnsafeRow)sort_sortedIter.next();
/* 099 */ 
/* 100 */       boolean sort_isNull = sort_outputRow.isNullAt(0);
/* 101 */       long sort_value = sort_isNull ? -1L : (sort_outputRow.getLong(0));
/* 102 */ 
/* 103 */       // generate grouping keys
/* 104 */       sagg_rowWriter.zeroOutNullBytes();
/* 105 */ 
/* 106 */       if (sort_isNull) {
/* 107 */         sagg_rowWriter.setNullAt(0);
/* 108 */       } else {
/* 109 */         sagg_rowWriter.write(0, sort_value);
/* 110 */       }
/* 111 */ 
/* 112 */       if (sagg_currentGroupingKey == null) {
/* 113 */         sagg_currentGroupingKey = sagg_result.copy();
/* 114 */         // init aggregation buffer vars
/* 115 */         final double sagg_value = -1.0;
/* 116 */         sagg_bufIsNull = true;
/* 117 */         sagg_bufValue = sagg_value;
/* 118 */         // do aggregation
/* 119 */ 
/* 120 */         // do aggregate
/* 121 */         // common sub-expressions
/* 122 */ 
/* 123 */         // evaluate aggregate function
/* 124 */         boolean sagg_isNull2 = true;
/* 125 */         double sagg_value2 = -1.0;
/* 126 */ 
/* 127 */         boolean sagg_isNull3 = sagg_bufIsNull;
/* 128 */         double sagg_value3 = sagg_bufValue;
/* 129 */         if (sagg_isNull3) {
/* 130 */           boolean sagg_isNull5 = false;
/* 131 */           double sagg_value5 = -1.0;
/* 132 */           if (!false) {
/* 133 */             sagg_value5 = (double) 0;
/* 134 */           }
/* 135 */           if (!sagg_isNull5) {
/* 136 */             sagg_isNull3 = false;
/* 137 */             sagg_value3 = sagg_value5;
/* 138 */           }
/* 139 */         }
/* 140 */
/* 141 */         boolean sort_isNull1 = sort_outputRow.isNullAt(1);
/* 142 */         double sort_value1 = sort_isNull1 ? -1.0 : (sort_outputRow.getDouble(1));
/* 143 */         if (!sort_isNull1) {
/* 144 */           sagg_isNull2 = false; // resultCode could change nullability.
/* 145 */           sagg_value2 = sagg_value3 + sort_value1;
/* 146 */
/* 147 */         }
/* 148 */         boolean sagg_isNull1 = sagg_isNull2;
/* 149 */         double sagg_value1 = sagg_value2;
/* 150 */         if (sagg_isNull1) {
/* 151 */           if (!sagg_bufIsNull) {
/* 152 */             sagg_isNull1 = false;
/* 153 */             sagg_value1 = sagg_bufValue;
/* 154 */           }
/* 155 */         }
/* 156 */         // update aggregation buffer
/* 157 */         sagg_bufIsNull = sagg_isNull1;
/* 158 */         sagg_bufValue = sagg_value1;
/* 159 */ 
/* 160 */         // continue;
/* 161 */       } else {
/* 162 */         if (sagg_currentGroupingKey.equals(sagg_result)) {
/* 163 */           // do aggregate
/* 164 */           // common sub-expressions
/* 165 */
/* 166 */           // evaluate aggregate function
/* 167 */           boolean sagg_isNull2 = true;
/* 168 */           double sagg_value2 = -1.0;
/* 169 */
/* 170 */           boolean sagg_isNull3 = sagg_bufIsNull;
/* 171 */           double sagg_value3 = sagg_bufValue;
/* 172 */           if (sagg_isNull3) {
/* 173 */             boolean sagg_isNull5 = false;
/* 174 */             double sagg_value5 = -1.0;
/* 175 */             if (!false) {
/* 176 */               sagg_value5 = (double) 0;
/* 177 */             }
/* 178 */             if (!sagg_isNull5) {
/* 179 */               sagg_isNull3 = false;
/* 180 */               sagg_value3 = sagg_value5;
/* 181 */             }
/* 182 */           }
/* 183 */ 
/* 184 */           boolean sort_isNull1 = sort_outputRow.isNullAt(1);
/* 185 */           double sort_value1 = sort_isNull1 ? -1.0 : (sort_outputRow.getDouble(1));
/* 186 */           if (!sort_isNull1) {
/* 187 */             sagg_isNull2 = false; // resultCode could change nullability.
/* 188 */             sagg_value2 = sagg_value3 + sort_value1;
/* 189 */ 
/* 190 */           }
/* 191 */           boolean sagg_isNull1 = sagg_isNull2;
/* 192 */           double sagg_value1 = sagg_value2;
/* 193 */           if (sagg_isNull1) {
/* 194 */             if (!sagg_bufIsNull) {
/* 195 */               sagg_isNull1 = false;
/* 196 */               sagg_value1 = sagg_bufValue;
/* 197 */             }
/* 198 */           }
/* 199 */           // update aggregation buffer
/* 200 */           sagg_bufIsNull = sagg_isNull1;
/* 201 */           sagg_bufValue = sagg_value1;
/* 202 */
/* 203 */           // continue;
/* 204 */         } else {
/* 205 */           wholestagecodegen_numOutputRows.add(1);
/* 206 */
/* 207 */           sagg_rowWriter1.zeroOutNullBytes();
/* 208 */
/* 209 */           if (sagg_bufIsNull) {
/* 210 */             sagg_rowWriter1.setNullAt(0);
/* 211 */           } else {
/* 212 */             sagg_rowWriter1.write(0, sagg_bufValue);
/* 213 */           }
/* 214 */
/* 215 */           boolean sagg_isNull11 = sagg_currentGroupingKey.isNullAt(0);
/* 216 */           long sagg_value11 = sagg_isNull11 ? -1L : (sagg_currentGroupingKey.getLong(0));
/* 217 */           boolean sagg_isNull12 = sagg_result1.isNullAt(0);
/* 218 */           double sagg_value12 = sagg_isNull12 ? -1.0 : (sagg_result1.getDouble(0));
/* 219 */ 
/* 220 */           sagg_rowWriter2.zeroOutNullBytes();
/* 221 */ 
/* 222 */           if (sagg_isNull11) {
/* 223 */             sagg_rowWriter2.setNullAt(0);
/* 224 */           } else {
/* 225 */             sagg_rowWriter2.write(0, sagg_value11);
/* 226 */           }
/* 227 */ 
/* 228 */           if (sagg_isNull12) {
/* 229 */             sagg_rowWriter2.setNullAt(1);
/* 230 */           } else {
/* 231 */             sagg_rowWriter2.write(1, sagg_value12);
/* 232 */           }
/* 233 */           append(sagg_result2);
/* 234 */
/* 235 */           // init buffer vars for a next partition
/* 236 */           sagg_currentGroupingKey = sagg_result.copy();
/* 237 */           final double sagg_value = -1.0;
/* 238 */           sagg_bufIsNull = true;
/* 239 */           sagg_bufValue = sagg_value;
/* 240 */
/* 241 */           // do aggregate
/* 242 */           // common sub-expressions
/* 243 */
/* 244 */           // evaluate aggregate function
/* 245 */           boolean sagg_isNull2 = true;
/* 246 */           double sagg_value2 = -1.0;
/* 247 */ 
/* 248 */           boolean sagg_isNull3 = sagg_bufIsNull;
/* 249 */           double sagg_value3 = sagg_bufValue;
/* 250 */           if (sagg_isNull3) {
/* 251 */             boolean sagg_isNull5 = false;
/* 252 */             double sagg_value5 = -1.0;
/* 253 */             if (!false) {
/* 254 */               sagg_value5 = (double) 0;
/* 255 */             }
/* 256 */             if (!sagg_isNull5) {
/* 257 */               sagg_isNull3 = false;
/* 258 */               sagg_value3 = sagg_value5;
/* 259 */             }
/* 260 */           }
/* 261 */
/* 262 */           boolean sort_isNull1 = sort_outputRow.isNullAt(1);
/* 263 */           double sort_value1 = sort_isNull1 ? -1.0 : (sort_outputRow.getDouble(1));
/* 264 */           if (!sort_isNull1) {
/* 265 */             sagg_isNull2 = false; // resultCode could change nullability.
/* 266 */             sagg_value2 = sagg_value3 + sort_value1;
/* 267 */
/* 268 */           }
/* 269 */           boolean sagg_isNull1 = sagg_isNull2;
/* 270 */           double sagg_value1 = sagg_value2;
/* 271 */           if (sagg_isNull1) {
/* 272 */             if (!sagg_bufIsNull) {
/* 273 */               sagg_isNull1 = false;
/* 274 */               sagg_value1 = sagg_bufValue;
/* 275 */             }
/* 276 */           }
/* 277 */           // update aggregation buffer
/* 278 */           sagg_bufIsNull = sagg_isNull1;
/* 279 */           sagg_bufValue = sagg_value1;
/* 280 */ 
/* 281 */         }
/* 282 */       }
/* 283 */ 
/* 284 */       if (shouldStop()) return;
/* 285 */     }
/* 286 */ 
/* 287 */     if (sagg_currentGroupingKey != null) {
/* 288 */       // for the last aggregation
/* 289 */       sagg_numOutputRows.add(1);
/* 290 */ 
/* 291 */       sagg_rowWriter1.zeroOutNullBytes();
/* 292 */ 
/* 293 */       if (sagg_bufIsNull) {
/* 294 */         sagg_rowWriter1.setNullAt(0);
/* 295 */       } else {
/* 296 */         sagg_rowWriter1.write(0, sagg_bufValue);
/* 297 */       }
/* 298 */
/* 299 */       boolean sagg_isNull11 = sagg_currentGroupingKey.isNullAt(0);
/* 300 */       long sagg_value11 = sagg_isNull11 ? -1L : (sagg_currentGroupingKey.getLong(0));
/* 301 */       boolean sagg_isNull12 = sagg_result1.isNullAt(0);
/* 302 */       double sagg_value12 = sagg_isNull12 ? -1.0 : (sagg_result1.getDouble(0));
/* 303 */
/* 304 */       sagg_rowWriter2.zeroOutNullBytes();
/* 305 */
/* 306 */       if (sagg_isNull11) {
/* 307 */         sagg_rowWriter2.setNullAt(0);
/* 308 */       } else {
/* 309 */         sagg_rowWriter2.write(0, sagg_value11);
/* 310 */       }
/* 311 */
/* 312 */       if (sagg_isNull12) {
/* 313 */         sagg_rowWriter2.setNullAt(1);
/* 314 */       } else {
/* 315 */         sagg_rowWriter2.write(1, sagg_value12);
/* 316 */       }
/* 317 */       append(sagg_result2);
/* 318 */ 
/* 319 */       sagg_currentGroupingKey = null;
/* 320 */     }
/* 321 */   }
/* 322 */ }

== Subtree 2 / 2 ==
*SortAggregate(key=[key#3L], functions=[partial_sum(value#4)], output=[key#3L, sum#17])
+- *Sort [key#3L ASC NULLS FIRST], false, 0
   +- *Project [(id#0L % 2) AS key#3L, rand(-2342342825239413884) AS value#4]
      +- *Range (0, 10, step=1, splits=Some(4))

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_numOutputRows;
/* 009 */   private boolean sort_needToSort;
/* 010 */   private org.apache.spark.sql.execution.SortExec sort_plan;
/* 011 */   private org.apache.spark.sql.execution.UnsafeExternalRowSorter sort_sorter;
/* 012 */   private org.apache.spark.executor.TaskMetrics sort_metrics;
/* 013 */   private scala.collection.Iterator<UnsafeRow> sort_sortedIter;
/* 014 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
/* 015 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
/* 016 */   private boolean range_initRange;
/* 017 */   private long range_number;
/* 018 */   private TaskContext range_taskContext;
/* 019 */   private InputMetrics range_inputMetrics;
/* 020 */   private long range_batchEnd;
/* 021 */   private long range_numElementsTodo;
/* 022 */   private scala.collection.Iterator range_input;
/* 023 */   private UnsafeRow range_result;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 025 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
/* 026 */   private org.apache.spark.util.random.XORShiftRandom project_rng;
/* 027 */   private UnsafeRow project_result;
/* 028 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
/* 029 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;
/* 030 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_peakMemory;
/* 031 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_spillSize;
/* 032 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_sortTime;
/* 033 */   private UnsafeRow sagg_currentGroupingKey;
/* 034 */   private boolean sagg_bufIsNull;
/* 035 */   private double sagg_bufValue;
/* 036 */   private UnsafeRow sagg_result;
/* 037 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder;
/* 038 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter;
/* 039 */   private UnsafeRow sagg_result1;
/* 040 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder1;
/* 041 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter1;
/* 042 */   private org.apache.spark.sql.execution.aggregate.SortAggregateExec sagg_sortAggregate;
/* 043 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowJoiner sagg_unsafeRowJoiner;
/* 044 */   private org.apache.spark.sql.execution.metric.SQLMetric wholestagecodegen_numOutputRows;
/* 045 */ 
/* 046 */   public GeneratedIterator(Object[] references) {
/* 047 */     this.references = references;
/* 048 */   }
/* 049 */ 
/* 050 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 051 */     partitionIndex = index;
/* 052 */     this.inputs = inputs;
/* 053 */     wholestagecodegen_init_0();
/* 054 */     wholestagecodegen_init_1();
/* 055 */     wholestagecodegen_init_2();
/* 056 */     project_rng = new org.apache.spark.util.random.XORShiftRandom(-2342342825239413884L + partitionIndex);
/* 057 */   }
/* 058 */
/* 059 */   private void wholestagecodegen_init_0() {
/* 060 */     this.sagg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 061 */     sort_needToSort = true;
/* 062 */     this.sort_plan = (org.apache.spark.sql.execution.SortExec) references[1];
/* 063 */     sort_sorter = sort_plan.createSorter();
/* 064 */     sort_metrics = org.apache.spark.TaskContext.get().taskMetrics();
/* 065 */
/* 066 */     this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 067 */     this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 068 */     range_initRange = false;
/* 069 */     range_number = 0L;
/* 070 */     range_taskContext = TaskContext.get();
/* 071 */     range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
/* 072 */     range_batchEnd = 0;
/* 073 */     range_numElementsTodo = 0L;
/* 074 */     range_input = inputs[0];
/* 075 */     range_result = new UnsafeRow(1);
/* 076 */     this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 077 */     this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
/* 078 */
/* 079 */     project_result = new UnsafeRow(2);
/* 080 */     this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 0);
/* 081 */
/* 082 */   }
/* 083 */
/* 084 */   private void wholestagecodegen_init_2() {
/* 085 */     this.wholestagecodegen_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[8];
/* 086 */
/* 087 */   }
/* 088 */
/* 089 */   private void wholestagecodegen_init_1() {
/* 090 */     this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 2);
/* 091 */     this.sort_peakMemory = (org.apache.spark.sql.execution.metric.SQLMetric) references[4];
/* 092 */     this.sort_spillSize = (org.apache.spark.sql.execution.metric.SQLMetric) references[5];
/* 093 */     this.sort_sortTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[6];
/* 094 */     sagg_currentGroupingKey = null;
/* 095 */
/* 096 */     sagg_result = new UnsafeRow(1);
/* 097 */     this.sagg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result, 0);
/* 098 */     this.sagg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder, 1);
/* 099 */     sagg_result1 = new UnsafeRow(1);
/* 100 */     this.sagg_holder1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result1, 0);
/* 101 */     this.sagg_rowWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder1, 1);
/* 102 */     this.sagg_sortAggregate = (org.apache.spark.sql.execution.aggregate.SortAggregateExec) references[7];
/* 103 */     sagg_unsafeRowJoiner = sagg_sortAggregate.createUnsafeJoiner();
/* 104 */
/* 105 */   }
/* 106 */
/* 107 */   private void initRange(int idx) {
/* 108 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 109 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(4L);
/* 110 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(10L);
/* 111 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 112 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 113 */     long partitionEnd;
/* 114 */
/* 115 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 116 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 117 */       range_number = Long.MAX_VALUE;
/* 118 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 119 */       range_number = Long.MIN_VALUE;
/* 120 */     } else {
/* 121 */       range_number = st.longValue();
/* 122 */     }
/* 123 */     range_batchEnd = range_number;
/* 124 */
/* 125 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 126 */     .multiply(step).add(start);
/* 127 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 128 */       partitionEnd = Long.MAX_VALUE;
/* 129 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 130 */       partitionEnd = Long.MIN_VALUE;
/* 131 */     } else {
/* 132 */       partitionEnd = end.longValue();
/* 133 */     }
/* 134 */
/* 135 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 136 */       java.math.BigInteger.valueOf(range_number));
/* 137 */     range_numElementsTodo  = startToEnd.divide(step).longValue();
/* 138 */     if (range_numElementsTodo < 0) {
/* 139 */       range_numElementsTodo = 0;
/* 140 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 141 */       range_numElementsTodo++;
/* 142 */     }
/* 143 */   }
/* 144 */
/* 145 */   private void sort_addToSorter() throws java.io.IOException {
/* 146 */     // initialize Range
/* 147 */     if (!range_initRange) {
/* 148 */       range_initRange = true;
/* 149 */       initRange(partitionIndex);
/* 150 */     }
/* 151 */
/* 152 */     while (true) {
/* 153 */       while (range_number != range_batchEnd) {
/* 154 */         long range_value = range_number;
/* 155 */         range_number += 1L;
/* 156 */
/* 157 */         final double project_value3 = project_rng.nextDouble();
/* 158 */
/* 159 */         boolean project_isNull = false;
/* 160 */         long project_value = -1L;
/* 161 */         if (2L == 0) {
/* 162 */           project_isNull = true;
/* 163 */         } else {
/* 164 */           project_value = (long)(range_value % 2L);
/* 165 */         }
/* 166 */         project_rowWriter.zeroOutNullBytes();
/* 167 */ 
/* 168 */         if (project_isNull) {
/* 169 */           project_rowWriter.setNullAt(0);
/* 170 */         } else {
/* 171 */           project_rowWriter.write(0, project_value);
/* 172 */         }
/* 173 */
/* 174 */         project_rowWriter.write(1, project_value3);
/* 175 */         sort_sorter.insertRow((UnsafeRow)project_result);
/* 176 */
/* 177 */         if (shouldStop()) return;
/* 178 */       }
/* 179 */
/* 180 */       if (range_taskContext.isInterrupted()) {
/* 181 */         throw new TaskKilledException();
/* 182 */       }
/* 183 */
/* 184 */       long range_nextBatchTodo;
/* 185 */       if (range_numElementsTodo > 1000L) {
/* 186 */         range_nextBatchTodo = 1000L;
/* 187 */         range_numElementsTodo -= 1000L;
/* 188 */       } else {
/* 189 */         range_nextBatchTodo = range_numElementsTodo;
/* 190 */         range_numElementsTodo = 0;
/* 191 */         if (range_nextBatchTodo == 0) break;
/* 192 */       }
/* 193 */       range_numOutputRows.add(range_nextBatchTodo);
/* 194 */       range_inputMetrics.incRecordsRead(range_nextBatchTodo);
/* 195 */
/* 196 */       range_batchEnd += range_nextBatchTodo * 1L;
/* 197 */     }
/* 198 */
/* 199 */   }
/* 200 */
/* 201 */   protected void processNext() throws java.io.IOException {
/* 202 */     if (sort_needToSort) {
/* 203 */       long sort_spillSizeBefore = sort_metrics.memoryBytesSpilled();
/* 204 */       sort_addToSorter();
/* 205 */       sort_sortedIter = sort_sorter.sort();
/* 206 */       sort_sortTime.add(sort_sorter.getSortTimeNanos() / 1000000);
/* 207 */       sort_peakMemory.add(sort_sorter.getPeakMemoryUsage());
/* 208 */       sort_spillSize.add(sort_metrics.memoryBytesSpilled() - sort_spillSizeBefore);
/* 209 */       sort_metrics.incPeakExecutionMemory(sort_sorter.getPeakMemoryUsage());
/* 210 */       sort_needToSort = false;
/* 211 */     }
/* 212 */
/* 213 */     while (sort_sortedIter.hasNext()) {
/* 214 */       UnsafeRow sort_outputRow = (UnsafeRow)sort_sortedIter.next();
/* 215 */
/* 216 */       boolean sort_isNull = sort_outputRow.isNullAt(0);
/* 217 */       long sort_value = sort_isNull ? -1L : (sort_outputRow.getLong(0));
/* 218 */       double sort_value1 = sort_outputRow.getDouble(1);
/* 219 */
/* 220 */       // generate grouping keys
/* 221 */       sagg_rowWriter.zeroOutNullBytes();
/* 222 */
/* 223 */       if (sort_isNull) {
/* 224 */         sagg_rowWriter.setNullAt(0);
/* 225 */       } else {
/* 226 */         sagg_rowWriter.write(0, sort_value);
/* 227 */       }
/* 228 */
/* 229 */       if (sagg_currentGroupingKey == null) {
/* 230 */         sagg_currentGroupingKey = sagg_result.copy();
/* 231 */         // init aggregation buffer vars
/* 232 */         final double sagg_value = -1.0;
/* 233 */         sagg_bufIsNull = true;
/* 234 */         sagg_bufValue = sagg_value;
/* 235 */         // do aggregation
/* 236 */
/* 237 */         // do aggregate
/* 238 */         // common sub-expressions
/* 239 */
/* 240 */         // evaluate aggregate function
/* 241 */         boolean sagg_isNull1 = false;
/* 242 */
/* 243 */         boolean sagg_isNull2 = sagg_bufIsNull;
/* 244 */         double sagg_value2 = sagg_bufValue;
/* 245 */         if (sagg_isNull2) {
/* 246 */           boolean sagg_isNull4 = false;
/* 247 */           double sagg_value4 = -1.0;
/* 248 */           if (!false) {
/* 249 */             sagg_value4 = (double) 0;
/* 250 */           }
/* 251 */           if (!sagg_isNull4) {
/* 252 */             sagg_isNull2 = false;
/* 253 */             sagg_value2 = sagg_value4;
/* 254 */           }
/* 255 */         }
/* 256 */
/* 257 */         boolean sagg_isNull6 = false;
/* 258 */         double sagg_value6 = -1.0;
/* 259 */         if (!false) {
/* 260 */           sagg_value6 = sort_value1;
/* 261 */         }
/* 262 */         double sagg_value1 = -1.0;
/* 263 */         sagg_value1 = sagg_value2 + sagg_value6;
/* 264 */         // update aggregation buffer
/* 265 */         sagg_bufIsNull = false;
/* 266 */         sagg_bufValue = sagg_value1;
/* 267 */
/* 268 */         // continue;
/* 269 */       } else {
/* 270 */         if (sagg_currentGroupingKey.equals(sagg_result)) {
/* 271 */           // do aggregate
/* 272 */           // common sub-expressions
/* 273 */
/* 274 */           // evaluate aggregate function
/* 275 */           boolean sagg_isNull1 = false;
/* 276 */
/* 277 */           boolean sagg_isNull2 = sagg_bufIsNull;
/* 278 */           double sagg_value2 = sagg_bufValue;
/* 279 */           if (sagg_isNull2) {
/* 280 */             boolean sagg_isNull4 = false;
/* 281 */             double sagg_value4 = -1.0;
/* 282 */             if (!false) {
/* 283 */               sagg_value4 = (double) 0;
/* 284 */             }
/* 285 */             if (!sagg_isNull4) {
/* 286 */               sagg_isNull2 = false;
/* 287 */               sagg_value2 = sagg_value4;
/* 288 */             }
/* 289 */           }
/* 290 */
/* 291 */           boolean sagg_isNull6 = false;
/* 292 */           double sagg_value6 = -1.0;
/* 293 */           if (!false) {
/* 294 */             sagg_value6 = sort_value1;
/* 295 */           }
/* 296 */           double sagg_value1 = -1.0;
/* 297 */           sagg_value1 = sagg_value2 + sagg_value6;
/* 298 */           // update aggregation buffer
/* 299 */           sagg_bufIsNull = false;
/* 300 */           sagg_bufValue = sagg_value1;
/* 301 */
/* 302 */           // continue;
/* 303 */         } else {
/* 304 */           wholestagecodegen_numOutputRows.add(1);
/* 305 */
/* 306 */           sagg_rowWriter1.zeroOutNullBytes();
/* 307 */
/* 308 */           if (sagg_bufIsNull) {
/* 309 */             sagg_rowWriter1.setNullAt(0);
/* 310 */           } else {
/* 311 */             sagg_rowWriter1.write(0, sagg_bufValue);
/* 312 */           }
/* 313 */
/* 314 */           UnsafeRow sagg_resultRow = sagg_unsafeRowJoiner.join(sagg_currentGroupingKey, sagg_result1);
/* 315 */
/* 316 */           append(sagg_resultRow);
/* 317 */
/* 318 */           // init buffer vars for a next partition
/* 319 */           sagg_currentGroupingKey = sagg_result.copy();
/* 320 */           final double sagg_value = -1.0;
/* 321 */           sagg_bufIsNull = true;
/* 322 */           sagg_bufValue = sagg_value;
/* 323 */
/* 324 */           // do aggregate
/* 325 */           // common sub-expressions
/* 326 */
/* 327 */           // evaluate aggregate function
/* 328 */           boolean sagg_isNull1 = false;
/* 329 */
/* 330 */           boolean sagg_isNull2 = sagg_bufIsNull;
/* 331 */           double sagg_value2 = sagg_bufValue;
/* 332 */           if (sagg_isNull2) {
/* 333 */             boolean sagg_isNull4 = false;
/* 334 */             double sagg_value4 = -1.0;
/* 335 */             if (!false) {
/* 336 */               sagg_value4 = (double) 0;
/* 337 */             }
/* 338 */             if (!sagg_isNull4) {
/* 339 */               sagg_isNull2 = false;
/* 340 */               sagg_value2 = sagg_value4;
/* 341 */             }
/* 342 */           }
/* 343 */
/* 344 */           boolean sagg_isNull6 = false;
/* 345 */           double sagg_value6 = -1.0;
/* 346 */           if (!false) {
/* 347 */             sagg_value6 = sort_value1;
/* 348 */           }
/* 349 */           double sagg_value1 = -1.0;
/* 350 */           sagg_value1 = sagg_value2 + sagg_value6;
/* 351 */           // update aggregation buffer
/* 352 */           sagg_bufIsNull = false;
/* 353 */           sagg_bufValue = sagg_value1;
/* 354 */ 
/* 355 */         }
/* 356 */       }
/* 357 */
/* 358 */       if (shouldStop()) return;
/* 359 */     }
/* 360 */
/* 361 */     if (sagg_currentGroupingKey != null) {
/* 362 */       // for the last aggregation
/* 363 */       sagg_numOutputRows.add(1);
/* 364 */
/* 365 */       sagg_rowWriter1.zeroOutNullBytes();
/* 366 */
/* 367 */       if (sagg_bufIsNull) {
/* 368 */         sagg_rowWriter1.setNullAt(0);
/* 369 */       } else {
/* 370 */         sagg_rowWriter1.write(0, sagg_bufValue);
/* 371 */       }
/* 372 */
/* 373 */       UnsafeRow sagg_resultRow = sagg_unsafeRowJoiner.join(sagg_currentGroupingKey, sagg_result1);
/* 374 */
/* 375 */       append(sagg_resultRow);
/* 376 */
/* 377 */       sagg_currentGroupingKey = null;
/* 378 */     }
/* 379 */   }
/* 380 */ }
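
The generated code above is an unrolled form of the classic sort-based aggregation loop: keep one grouping key (sagg_currentGroupingKey) and one aggregation buffer (sagg_bufIsNull / sagg_bufValue), update the buffer while the key stays the same, emit a result row whenever the key changes, and flush the last group after the input is exhausted. In plain Scala the control flow is roughly the following sketch (sum over (key, value) pairs already sorted by key; an illustration, not the PR's actual code):

def sortAggregate(sortedRows: Iterator[(Long, Double)]): Iterator[(Long, Double)] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[(Long, Double)]
  var currentKey: Option[Long] = None   // plays the role of sagg_currentGroupingKey
  var buf = 0.0                         // plays the role of sagg_bufValue
  for ((key, value) <- sortedRows) {
    currentKey match {
      case None =>                      // first row: init the aggregation buffer
        currentKey = Some(key); buf = value
      case Some(k) if k == key =>       // same group: update the buffer
        buf += value
      case Some(k) =>                   // key changed: emit the finished group
        out += ((k, buf))
        currentKey = Some(key); buf = value
    }
  }
  currentKey.foreach(k => out += ((k, buf)))  // for the last aggregation
  out.iterator
}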

@maropu (Member, Author) commented Mar 4, 2017

import org.apache.spark.sql.execution.debug._
spark.conf.set("spark.sql.aggregate.preferSortAggregate", "true")
val df = spark.range(10).selectExpr("id % 2 AS key", "rand() AS value")
df.groupBy().count.debugCodegen

Found 2 WholeStageCodegen subtrees.
== Subtree 1 / 2 ==
*SortAggregate(key=[], functions=[partial_count(1)], output=[count#51L])
+- *Project
   +- *Range (0, 10, step=1, splits=Some(4))

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private boolean sagg_initAgg;
/* 009 */   private boolean sagg_bufIsNull;
/* 010 */   private long sagg_bufValue;
/* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
/* 012 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
/* 013 */   private boolean range_initRange;
/* 014 */   private long range_number;
/* 015 */   private TaskContext range_taskContext;
/* 016 */   private InputMetrics range_inputMetrics;
/* 017 */   private long range_batchEnd;
/* 018 */   private long range_numElementsTodo;
/* 019 */   private scala.collection.Iterator range_input;
/* 020 */   private UnsafeRow range_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
/* 023 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_numOutputRows;
/* 024 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_aggTime;
/* 025 */   private UnsafeRow sagg_result;
/* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder;
/* 027 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter;
/* 028 */
/* 029 */   public GeneratedIterator(Object[] references) {
/* 030 */     this.references = references;
/* 031 */   }
/* 032 */
/* 033 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 034 */     partitionIndex = index;
/* 035 */     this.inputs = inputs;
/* 036 */     sagg_initAgg = false;
/* 037 */
/* 038 */     this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 039 */     this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 040 */     range_initRange = false;
/* 041 */     range_number = 0L;
/* 042 */     range_taskContext = TaskContext.get();
/* 043 */     range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
/* 044 */     range_batchEnd = 0;
/* 045 */     range_numElementsTodo = 0L;
/* 046 */     range_input = inputs[0];
/* 047 */     range_result = new UnsafeRow(1);
/* 048 */     this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 049 */     this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
/* 050 */     this.sagg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 051 */     this.sagg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 052 */     sagg_result = new UnsafeRow(1);
/* 053 */     this.sagg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result, 0);
/* 054 */     this.sagg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder, 1);
/* 055 */
/* 056 */   }
/* 057 */
/* 058 */   private void sagg_doAggregateWithoutKey() throws java.io.IOException {
/* 059 */     // initialize aggregation buffer
/* 060 */     sagg_bufIsNull = false;
/* 061 */     sagg_bufValue = 0L;
/* 062 */
/* 063 */     // initialize Range
/* 064 */     if (!range_initRange) {
/* 065 */       range_initRange = true;
/* 066 */       initRange(partitionIndex);
/* 067 */     }
/* 068 */
/* 069 */     while (true) {
/* 070 */       while (range_number != range_batchEnd) {
/* 071 */         long range_value = range_number;
/* 072 */         range_number += 1L;
/* 073 */
/* 074 */         // do aggregate
/* 075 */         // common sub-expressions
/* 076 */
/* 077 */         // evaluate aggregate function
/* 078 */         boolean sagg_isNull1 = false;
/* 079 */
/* 080 */         long sagg_value1 = -1L;
/* 081 */         sagg_value1 = sagg_bufValue + 1L;
/* 082 */         // update aggregation buffer
/* 083 */         sagg_bufIsNull = false;
/* 084 */         sagg_bufValue = sagg_value1;
/* 085 */
/* 086 */         if (shouldStop()) return;
/* 087 */       }
/* 088 */
/* 089 */       if (range_taskContext.isInterrupted()) {
/* 090 */         throw new TaskKilledException();
/* 091 */       }
/* 092 */
/* 093 */       long range_nextBatchTodo;
/* 094 */       if (range_numElementsTodo > 1000L) {
/* 095 */         range_nextBatchTodo = 1000L;
/* 096 */         range_numElementsTodo -= 1000L;
/* 097 */       } else {
/* 098 */         range_nextBatchTodo = range_numElementsTodo;
/* 099 */         range_numElementsTodo = 0;
/* 100 */         if (range_nextBatchTodo == 0) break;
/* 101 */       }
/* 102 */       range_numOutputRows.add(range_nextBatchTodo);
/* 103 */       range_inputMetrics.incRecordsRead(range_nextBatchTodo);
/* 104 */
/* 105 */       range_batchEnd += range_nextBatchTodo * 1L;
/* 106 */     }
/* 107 */
/* 108 */   }
/* 109 */
/* 110 */   private void initRange(int idx) {
/* 111 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 112 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(4L);
/* 113 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(10L);
/* 114 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 115 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 116 */     long partitionEnd;
/* 117 */
/* 118 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 119 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 120 */       range_number = Long.MAX_VALUE;
/* 121 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 122 */       range_number = Long.MIN_VALUE;
/* 123 */     } else {
/* 124 */       range_number = st.longValue();
/* 125 */     }
/* 126 */     range_batchEnd = range_number;
/* 127 */
/* 128 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 129 */     .multiply(step).add(start);
/* 130 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 131 */       partitionEnd = Long.MAX_VALUE;
/* 132 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 133 */       partitionEnd = Long.MIN_VALUE;
/* 134 */     } else {
/* 135 */       partitionEnd = end.longValue();
/* 136 */     }
/* 137 */
/* 138 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 139 */       java.math.BigInteger.valueOf(range_number));
/* 140 */     range_numElementsTodo  = startToEnd.divide(step).longValue();
/* 141 */     if (range_numElementsTodo < 0) {
/* 142 */       range_numElementsTodo = 0;
/* 143 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 144 */       range_numElementsTodo++;
/* 145 */     }
/* 146 */   }
/* 147 */
/* 148 */   protected void processNext() throws java.io.IOException {
/* 149 */     while (!sagg_initAgg) {
/* 150 */       sagg_initAgg = true;
/* 151 */       long sagg_beforeAgg = System.nanoTime();
/* 152 */       sagg_doAggregateWithoutKey();
/* 153 */       sagg_aggTime.add((System.nanoTime() - sagg_beforeAgg) / 1000000);
/* 154 */
/* 155 */       // output the result
/* 156 */
/* 157 */       sagg_numOutputRows.add(1);
/* 158 */       sagg_rowWriter.zeroOutNullBytes();
/* 159 */
/* 160 */       if (sagg_bufIsNull) {
/* 161 */         sagg_rowWriter.setNullAt(0);
/* 162 */       } else {
/* 163 */         sagg_rowWriter.write(0, sagg_bufValue);
/* 164 */       }
/* 165 */       append(sagg_result);
/* 166 */     }
/* 167 */   }
/* 168 */ }

== Subtree 2 / 2 ==
*SortAggregate(key=[], functions=[count(1)], output=[count#47L])
+- Exchange SinglePartition
   +- *SortAggregate(key=[], functions=[partial_count(1)], output=[count#51L])
      +- *Project
         +- *Range (0, 10, step=1, splits=Some(4))

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private boolean sagg_initAgg;
/* 009 */   private boolean sagg_bufIsNull;
/* 010 */   private long sagg_bufValue;
/* 011 */   private scala.collection.Iterator inputadapter_input;
/* 012 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_numOutputRows;
/* 013 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_aggTime;
/* 014 */   private UnsafeRow sagg_result;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter;
/* 017 */
/* 018 */   public GeneratedIterator(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */     sagg_initAgg = false;
/* 026 */
/* 027 */     inputadapter_input = inputs[0];
/* 028 */     this.sagg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 029 */     this.sagg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 030 */     sagg_result = new UnsafeRow(1);
/* 031 */     this.sagg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result, 0);
/* 032 */     this.sagg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder, 1);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void sagg_doAggregateWithoutKey() throws java.io.IOException {
/* 037 */     // initialize aggregation buffer
/* 038 */     sagg_bufIsNull = false;
/* 039 */     sagg_bufValue = 0L;
/* 040 */
/* 041 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 042 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 043 */       // do aggregate
/* 044 */       // common sub-expressions
/* 045 */
/* 046 */       // evaluate aggregate function
/* 047 */       boolean sagg_isNull3 = false;
/* 048 */
/* 049 */       long inputadapter_value = inputadapter_row.getLong(0);
/* 050 */       long sagg_value3 = -1L;
/* 051 */       sagg_value3 = sagg_bufValue + inputadapter_value;
/* 052 */       // update aggregation buffer
/* 053 */       sagg_bufIsNull = false;
/* 054 */       sagg_bufValue = sagg_value3;
/* 055 */       if (shouldStop()) return;
/* 056 */     }
/* 057 */
/* 058 */   }
/* 059 */
/* 060 */   protected void processNext() throws java.io.IOException {
/* 061 */     while (!sagg_initAgg) {
/* 062 */       sagg_initAgg = true;
/* 063 */       long sagg_beforeAgg = System.nanoTime();
/* 064 */       sagg_doAggregateWithoutKey();
/* 065 */       sagg_aggTime.add((System.nanoTime() - sagg_beforeAgg) / 1000000);
/* 066 */
/* 067 */       // output the result
/* 068 */
/* 069 */       sagg_numOutputRows.add(1);
/* 070 */       sagg_rowWriter.zeroOutNullBytes();
/* 071 */
/* 072 */       if (sagg_bufIsNull) {
/* 073 */         sagg_rowWriter.setNullAt(0);
/* 074 */       } else {
/* 075 */         sagg_rowWriter.write(0, sagg_bufValue);
/* 076 */       }
/* 077 */       append(sagg_result);
/* 078 */     }
/* 079 */   }
/* 080 */ }
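
Without grouping keys the same pattern collapses to a single buffer and a single output row per partition, which is what sagg_doAggregateWithoutKey above does: initialize the buffer once, update it for every input row, and emit one row at the end. Roughly, as a sketch (again illustrative, not the PR's code):

def countWithoutKey(rows: Iterator[Any]): Long = {
  var buf = 0L               // the single aggregation buffer (sagg_bufValue)
  while (rows.hasNext) {
    rows.next()
    buf += 1L                // partial_count(1): add one per input row
  }
  buf                        // exactly one output value per partition
}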

@maropu (Member, Author) commented Mar 4, 2017

@hvanhovell I reworked #14481, though I'm not sure whether this codegen is still worth pursuing. Could you give me your thoughts first? Thanks!

@SparkQA commented Mar 4, 2017

Test build #73908 has finished for PR 17164 at commit 9a26a0a.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • abstract class AggregateExec extends UnaryExecNode
  • trait CodegenAggregateSupport extends CodegenSupport

@hvanhovell (Contributor) commented

@maropu I think this is pretty exciting. It is very useful in situations where we have a lot of groups; in that case I will happily take a 2x performance improvement any day. And the result is still pretty decent if you consider that this aggregate is dominated by sorting.

@maropu (Member, Author) commented Mar 4, 2017

@hvanhovell okay! I'll brush up the code, and when I'm finished I'll let you know so you can review it. Thanks.

@maropu force-pushed the SPARK-16844 branch 3 times, most recently from a45048a to b29c22d, on March 5, 2017 09:53
@SparkQA commented Mar 5, 2017

Test build #73924 has finished for PR 17164 at commit f63e663.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • abstract class AggregateExec extends UnaryExecNode
  • trait CodegenAggregateSupport extends CodegenSupport

@SparkQA commented Mar 5, 2017

Test build #73925 has finished for PR 17164 at commit a45048a.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • trait AggregateCodegenHelper
  • abstract class AggregateExec extends UnaryExecNode

@SparkQA commented Mar 5, 2017

Test build #73926 has finished for PR 17164 at commit b29c22d.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • trait AggregateCodegenHelper
  • abstract class AggregateExec extends UnaryExecNode

@SparkQA commented Mar 5, 2017

Test build #73935 has finished for PR 17164 at commit 5b138f7.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • trait AggregateCodegenHelper
  • abstract class AggregateExec extends UnaryExecNode

@maropu force-pushed the SPARK-16844 branch 2 times, most recently from 8413dd7 to 29c713b, on March 5, 2017 23:58
@SparkQA commented Mar 6, 2017

Test build #73948 has finished for PR 17164 at commit 8413dd7.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • trait AggregateCodegenHelper
  • abstract class AggregateExec extends UnaryExecNode

@SparkQA commented Mar 6, 2017

Test build #73949 has finished for PR 17164 at commit 29c713b.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • trait AggregateCodegenHelper
  • abstract class AggregateExec extends UnaryExecNode

@SparkQA commented Mar 6, 2017

Test build #73952 has finished for PR 17164 at commit 6af8064.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • trait AggregateCodegenHelper
  • abstract class AggregateExec extends UnaryExecNode

@SparkQA commented Mar 6, 2017

Test build #73974 has started for PR 17164 at commit d5cc0f0.

@maropu (Member, Author) commented Mar 6, 2017

Jenkins, retest this please.

@SparkQA commented Mar 6, 2017

Test build #73986 has finished for PR 17164 at commit d5cc0f0.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • trait AggregateCodegenHelper
  • abstract class AggregateExec extends UnaryExecNode

@maropu (Member, Author) commented Mar 6, 2017

Jenkins, retest this please.

@SparkQA commented Mar 6, 2017

Test build #73993 has finished for PR 17164 at commit 2a018cb.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Mar 7, 2017

Test build #74066 has finished for PR 17164 at commit fc01d07.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • trait AggregateCodegenHelper
  • abstract class AggregateExec extends UnaryExecNode

@maropu (Member, Author) commented Mar 7, 2017

This PR added a new SQL option, spark.sql.aggregate.preferSortAggregate, to preferentially select SortAggregate; that makes this change easy to test in DataFrameAggregateSuite.scala. In some cases (e.g., when the input data is already sorted in a cache), sort aggregate is faster than the hash one (see: https://issues.apache.org/jira/browse/SPARK-18591). But the current Spark does not adaptively select sort aggregate in these cases, so I think this option is useful for letting users control the aggregate strategy. What do you think? cc: @hvanhovell If yes, I'd like to make another PR that adds this option before this PR is reviewed. master...maropu:SPARK-16844-3

@maropu changed the title from [SPARK-16844][SQL][WIP] Support codegen for sort-based aggregate to [SPARK-16844][SQL] Support codegen for sort-based aggregate on Mar 7, 2017
@maropu (Member, Author) commented Mar 10, 2017

@hvanhovell ping

val aggTime = metricTerm(ctx, "aggTime")
val beforeAgg = ctx.freshName("beforeAgg")
s"""
| while (!$initAgg) {
Member:

Do we need to use while? Can we use if instead of while?

Member Author:

IIUC we can't, because a continue may exist in ${consume(ctx, resultVars).trim}.

Member:

I see, thanks

| $currentGroupingKeyTerm = null;
|
| if (shouldStop()) return;
| } while (false);
Member:

Do we need this do { } while (false);?

Member Author:

ditto.
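
For context on both questions in this thread: the code spliced in by consume(ctx, resultVars) may itself contain a continue;, so the template has to wrap the generated body in a construct that gives that continue; a legal jump target which runs at most once. A sketch of the idea (Scala string interpolation producing Java, with illustrative names; not the PR's exact template):

def wrapBody(initAgg: String, consumed: String): String =
  s"""
     |while (!$initAgg) {
     |  $initAgg = true;
     |  // ... do the aggregation ...
     |  $consumed   // may contain "continue;", which jumps to the (now false)
     |}             // loop condition, i.e. it skips the rest of this block once
   """.stripMargin

A plain if block would leave such a continue; without a target (or bind it to the wrong enclosing loop), and the do { ... } while (false); wrapper serves the same purpose.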

@maropu (Member, Author) commented Mar 14, 2017

@hvanhovell ping

@SparkQA commented Mar 15, 2017

Test build #74569 has finished for PR 17164 at commit 5baa928.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • trait AggregateCodegenHelper
  • abstract class AggregateExec extends UnaryExecNode

@maropu (Member, Author) commented Mar 18, 2017

@hvanhovell ping

@hvanhovell (Contributor) commented

@maropu I do think this is useful. However, we really need to refactor the planner if we want to get the most value from this.

@maropu (Member, Author) commented Mar 27, 2017

okay, would it be better to close this?

@hvanhovell (Contributor) commented

@maropu I am not sure. I like to keep interesting PRs open. Would you be interested in doing some work on the planner?

@maropu (Member, Author) commented Mar 27, 2017

okay, I'll keep this open. Yeah, sure, I'm interested. If there are sub-tasks for that, I'd be glad if you ping me. Thanks!

@gatorsmile (Member) commented

@maropu Maybe we can close this PR for now?

@maropu (Member, Author) commented Jun 13, 2017

ok

@maropu closed this Jun 13, 2017