
[SPARK-16844][SQL] Generate code for sort based aggregation #14481

Closed. Wants to merge 2 commits.

Conversation

yucai
Contributor

@yucai yucai commented Aug 3, 2016

This PR is under internal review; we will ask for community review later.

@hvanhovell
Contributor

Ok to test

@SparkQA

SparkQA commented Aug 4, 2016

Test build #3202 has finished for PR 14481 at commit dc0f040.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@yucai
Contributor Author

yucai commented Aug 15, 2016

retest this please

@hvanhovell
Contributor

@yucai could you post some benchmark results? I would think that the overall runtime of the sort based aggregation path is dominated by the preceding exchange and sort operations, and that as a result this will not yield an enormous speed-up.

Could you also post the generated code for a simple case? That helps during the review.

@yucai
Contributor Author

yucai commented Aug 16, 2016

@hvanhovell thanks very much for the advice. Yes, I will post the benchmark results first.
This is still WIP; I will post the generated code, but please hold off on reviewing the details for now, since I am pretty sure some places can still be improved :).

@yucai
Contributor Author

yucai commented Aug 17, 2016

@hvanhovell
Benchmark Result

Summary
We benchmarked sort-aggregate code generation with real customer cases: it gives about a 6x speed-up when aggregating without keys, and about a 1.18x speed-up when aggregating with keys.

Workload
Example 1: aggregate with keys

SELECT
  mbuser_id, 
  max(substr(acc_nbr,1,15)) acc_nbr, 
  max(nbilling_tid) NBILLING_TID, 
  max(obilling_tid) OBILLING_TID, 
  max(BRAND_ID) BRAND_ID, 
  max(tm_tid) tm_tid, 
  max(area_id) AREA_ID, 
  max(VPMN_ID) VPMN_ID, 
  max(ACCT_ID) ACCT_ID, 
  ….
FROM gdi_mb GROUP BY MbUser_ID;

Example 2: aggregate without keys

SELECT
  max(substr(acc_nbr,1,15)) acc_nbr, 
  max(nbilling_tid) NBILLING_TID, 
  max(obilling_tid) OBILLING_TID, 
  max(BRAND_ID) BRAND_ID, 
  max(tm_tid) tm_tid, 
  max(area_id) AREA_ID, 
  max(VPMN_ID) VPMN_ID, 
  max(ACCT_ID) ACCT_ID, 
  ….
FROM gdi_mb;

Report (in seconds)
[screenshot: benchmark report]

In the above workload pattern, the sort actually takes little time; most of the time is spent in aggregation, which is the main reason the sort-aggregate code generation speeds things up.
[screenshot: profiling breakdown]
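The idea behind the speed-up: because the rows arrive already sorted by the grouping key, the aggregation needs only a single running buffer, emitting one result row each time the key changes. A minimal sketch of that loop in plain Scala (names and structure are illustrative only, not taken from the Spark codebase), computing max(v) grouped by k with the same string comparison semantics as the UTF8String compare in the generated code:

```scala
object SortAggSketch {
  // max(v) grouped by k, over rows pre-sorted by k.
  def sortAggMax(sorted: Seq[(String, String)]): List[(String, String)] = {
    val out = scala.collection.mutable.ArrayBuffer.empty[(String, String)]
    var currentKey: Option[String] = None
    var buf: Option[String] = None                 // aggregation buffer for max(v)
    for ((k, v) <- sorted) {
      if (!currentKey.contains(k)) {
        // key changed: emit the finished group, then re-initialize the buffer
        currentKey.foreach(key => out += ((key, buf.orNull)))
        currentKey = Some(k)
        buf = None
      }
      if (buf.forall(v > _)) buf = Some(v)         // update the running max
    }
    currentKey.foreach(key => out += ((key, buf.orNull)))  // emit the last group
    out.toList
  }
}
```

For instance, `SortAggSketch.sortAggMax(Seq(("a", "10"), ("b", "1"), ("b", "2")))` yields `List(("a", "10"), ("b", "2"))`; the generated code fuses this loop with the upstream sort in one whole-stage pipeline instead of materializing intermediate rows.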

@yucai
Contributor Author

yucai commented Aug 17, 2016

Generated code example; not ready for code review yet.

scala> Seq(("a", "3"), ("b", "20"), ("b", "2")).toDF("k", "v").agg(max("v")).debugCodegen()

Found 2 WholeStageCodegen subtrees.
== Subtree 1 / 2 ==
*SortAggregate(key=[], functions=[partial_max(v#6)], output=[max#18])
+- LocalTableScan [v#6]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private boolean sagg_initAgg;
/* 008 */   private boolean sagg_bufIsNull;
/* 009 */   private UTF8String sagg_bufValue;
/* 010 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_numOutputRows;
/* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_aggTime;
/* 012 */   private scala.collection.Iterator inputadapter_input;
/* 013 */   private UnsafeRow sagg_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter;
/* 016 */
/* 017 */   public GeneratedIterator(Object[] references) {
/* 018 */     this.references = references;
/* 019 */   }
/* 020 */
/* 021 */   public void init(int index, scala.collection.Iterator inputs[]) {
/* 022 */     partitionIndex = index;
/* 023 */     sagg_initAgg = false;
/* 024 */
/* 025 */     this.sagg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 026 */     this.sagg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 027 */     inputadapter_input = inputs[0];
/* 028 */     sagg_result = new UnsafeRow(1);
/* 029 */     this.sagg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result, 32);
/* 030 */     this.sagg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder, 1);
/* 031 */   }
/* 032 */
/* 033 */   private void sagg_doAggregateWithoutKey() throws java.io.IOException {
/* 034 */     // initialize aggregation buffer
/* 035 */     final UTF8String sagg_value = null;
/* 036 */     sagg_bufIsNull = true;
/* 037 */     sagg_bufValue = sagg_value;
/* 038 */
/* 039 */     while (inputadapter_input.hasNext()) {
/* 040 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 041 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 042 */       UTF8String inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getUTF8String(0));
/* 043 */
/* 044 */       // common sub-expressions
/* 045 */
/* 046 */       // evaluate aggregate function
/* 047 */       boolean sagg_isNull1 = sagg_bufIsNull;
/* 048 */       UTF8String sagg_value1 = sagg_bufValue;
/* 049 */
/* 050 */       if (!inputadapter_isNull && (sagg_isNull1 ||
/* 051 */           (inputadapter_value.compare(sagg_value1)) > 0)) {
/* 052 */         sagg_isNull1 = false;
/* 053 */         sagg_value1 = inputadapter_value;
/* 054 */       }
/* 055 */       // update aggregation buffer
/* 056 */       sagg_bufIsNull = sagg_isNull1;
/* 057 */       if (sagg_bufValue != sagg_value1)
/* 058 */       sagg_bufValue = sagg_value1 == null ? null : sagg_value1.clone();
/* 059 */       if (shouldStop()) return;
/* 060 */     }
/* 061 */
/* 062 */   }
/* 063 */
/* 064 */   protected void processNext() throws java.io.IOException {
/* 065 */     while (!sagg_initAgg) {
/* 066 */       sagg_initAgg = true;
/* 067 */       long sagg_beforeAgg = System.nanoTime();
/* 068 */       sagg_doAggregateWithoutKey();
/* 069 */       sagg_aggTime.add((System.nanoTime() - sagg_beforeAgg) / 1000000);
/* 070 */
/* 071 */       // output the result
/* 072 */
/* 073 */       sagg_numOutputRows.add(1);
/* 074 */       sagg_holder.reset();
/* 075 */
/* 076 */       sagg_rowWriter.zeroOutNullBytes();
/* 077 */
/* 078 */       if (sagg_bufIsNull) {
/* 079 */         sagg_rowWriter.setNullAt(0);
/* 080 */       } else {
/* 081 */         sagg_rowWriter.write(0, sagg_bufValue);
/* 082 */       }
/* 083 */       sagg_result.setTotalSize(sagg_holder.totalSize());
/* 084 */       append(sagg_result);
/* 085 */     }
/* 086 */   }
/* 087 */ }

== Subtree 2 / 2 ==
*SortAggregate(key=[], functions=[max(v#6)], output=[max(v)#14])
+- Exchange SinglePartition
   +- *SortAggregate(key=[], functions=[partial_max(v#6)], output=[max#18])
      +- LocalTableScan [v#6]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private boolean sagg_initAgg;
/* 008 */   private boolean sagg_bufIsNull;
/* 009 */   private UTF8String sagg_bufValue;
/* 010 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_numOutputRows;
/* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_aggTime;
/* 012 */   private scala.collection.Iterator inputadapter_input;
/* 013 */   private UnsafeRow sagg_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter;
/* 016 */
/* 017 */   public GeneratedIterator(Object[] references) {
/* 018 */     this.references = references;
/* 019 */   }
/* 020 */
/* 021 */   public void init(int index, scala.collection.Iterator inputs[]) {
/* 022 */     partitionIndex = index;
/* 023 */     sagg_initAgg = false;
/* 024 */
/* 025 */     this.sagg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 026 */     this.sagg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 027 */     inputadapter_input = inputs[0];
/* 028 */     sagg_result = new UnsafeRow(1);
/* 029 */     this.sagg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result, 32);
/* 030 */     this.sagg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder, 1);
/* 031 */   }
/* 032 */
/* 033 */   private void sagg_doAggregateWithoutKey() throws java.io.IOException {
/* 034 */     // initialize aggregation buffer
/* 035 */     final UTF8String sagg_value = null;
/* 036 */     sagg_bufIsNull = true;
/* 037 */     sagg_bufValue = sagg_value;
/* 038 */
/* 039 */     while (inputadapter_input.hasNext()) {
/* 040 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 041 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 042 */       UTF8String inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getUTF8String(0));
/* 043 */
/* 044 */       // common sub-expressions
/* 045 */
/* 046 */       // evaluate aggregate function
/* 047 */       boolean sagg_isNull3 = sagg_bufIsNull;
/* 048 */       UTF8String sagg_value3 = sagg_bufValue;
/* 049 */
/* 050 */       if (!inputadapter_isNull && (sagg_isNull3 ||
/* 051 */           (inputadapter_value.compare(sagg_value3)) > 0)) {
/* 052 */         sagg_isNull3 = false;
/* 053 */         sagg_value3 = inputadapter_value;
/* 054 */       }
/* 055 */       // update aggregation buffer
/* 056 */       sagg_bufIsNull = sagg_isNull3;
/* 057 */       if (sagg_bufValue != sagg_value3)
/* 058 */       sagg_bufValue = sagg_value3 == null ? null : sagg_value3.clone();
/* 059 */       if (shouldStop()) return;
/* 060 */     }
/* 061 */
/* 062 */   }
/* 063 */
/* 064 */   protected void processNext() throws java.io.IOException {
/* 065 */     while (!sagg_initAgg) {
/* 066 */       sagg_initAgg = true;
/* 067 */       long sagg_beforeAgg = System.nanoTime();
/* 068 */       sagg_doAggregateWithoutKey();
/* 069 */       sagg_aggTime.add((System.nanoTime() - sagg_beforeAgg) / 1000000);
/* 070 */
/* 071 */       // output the result
/* 072 */
/* 073 */       sagg_numOutputRows.add(1);
/* 074 */       sagg_holder.reset();
/* 075 */
/* 076 */       sagg_rowWriter.zeroOutNullBytes();
/* 077 */
/* 078 */       if (sagg_bufIsNull) {
/* 079 */         sagg_rowWriter.setNullAt(0);
/* 080 */       } else {
/* 081 */         sagg_rowWriter.write(0, sagg_bufValue);
/* 082 */       }
/* 083 */       sagg_result.setTotalSize(sagg_holder.totalSize());
/* 084 */       append(sagg_result);
/* 085 */     }
/* 086 */   }
/* 087 */ }

@yucai
Contributor Author

yucai commented Aug 17, 2016

Generated code example; not ready for code review yet.

scala> Seq(("a", "10"), ("b", "1"), ("b", "2"), ("c", "5"), ("c", "3")).
     |       toDF("k", "v").groupBy("k").agg(max("v")).debugCodegen()
Found 2 WholeStageCodegen subtrees.
== Subtree 1 / 2 ==
*SortAggregate(key=[k#24], functions=[partial_max(v#25)], output=[k#24,max#38])
+- *Sort [k#24 ASC], false, 0
   +- LocalTableScan [k#24, v#25]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private boolean sagg_bufIsNull;
/* 008 */   private UTF8String sagg_bufValue;
/* 009 */   private UnsafeRow sagg_currentGroupingKey;
/* 010 */   private boolean sagg_lastAgg;
/* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_numOutputRows;
/* 012 */   private boolean sort_needToSort;
/* 013 */   private org.apache.spark.sql.execution.SortExec sort_plan;
/* 014 */   private org.apache.spark.sql.execution.UnsafeExternalRowSorter sort_sorter;
/* 015 */   private org.apache.spark.executor.TaskMetrics sort_metrics;
/* 016 */   private scala.collection.Iterator<UnsafeRow> sort_sortedIter;
/* 017 */   private scala.collection.Iterator inputadapter_input;
/* 018 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_peakMemory;
/* 019 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_spillSize;
/* 020 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_sortTime;
/* 021 */   private UnsafeRow sagg_result;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter;
/* 024 */   private UnsafeRow sagg_result1;
/* 025 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder1;
/* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter1;
/* 027 */   private UnsafeRow wholestagecodegen_result;
/* 028 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder wholestagecodegen_holder;
/* 029 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter wholestagecodegen_rowWriter;
/* 030 */
/* 031 */   public GeneratedIterator(Object[] references) {
/* 032 */     this.references = references;
/* 033 */   }
/* 034 */
/* 035 */   public void init(int index, scala.collection.Iterator inputs[]) {
/* 036 */     partitionIndex = index;
/* 037 */
/* 038 */     sagg_currentGroupingKey = null;
/* 039 */     sagg_lastAgg = true;
/* 040 */     this.sagg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 041 */     sort_needToSort = true;
/* 042 */     this.sort_plan = (org.apache.spark.sql.execution.SortExec) references[1];
/* 043 */     sort_sorter = sort_plan.createSorter();
/* 044 */     sort_metrics = org.apache.spark.TaskContext.get().taskMetrics();
/* 045 */
/* 046 */     inputadapter_input = inputs[0];
/* 047 */     this.sort_peakMemory = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 048 */     this.sort_spillSize = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 049 */     this.sort_sortTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[4];
/* 050 */     sagg_result = new UnsafeRow(1);
/* 051 */     this.sagg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result, 32);
/* 052 */     this.sagg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder, 1);
/* 053 */     sagg_result1 = new UnsafeRow(2);
/* 054 */     this.sagg_holder1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result1, 64);
/* 055 */     this.sagg_rowWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder1, 2);
/* 056 */     wholestagecodegen_result = new UnsafeRow(2);
/* 057 */     this.wholestagecodegen_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(wholestagecodegen_result, 64);
/* 058 */     this.wholestagecodegen_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(wholestagecodegen_holder, 2);
/* 059 */   }
/* 060 */
/* 061 */   private void sort_addToSorter() throws java.io.IOException {
/* 062 */     while (inputadapter_input.hasNext()) {
/* 063 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 064 */       sort_sorter.insertRow((UnsafeRow)inputadapter_row);
/* 065 */       if (shouldStop()) return;
/* 066 */     }
/* 067 */
/* 068 */   }
/* 069 */
/* 070 */   protected void processNext() throws java.io.IOException {
/* 071 */     if (sort_needToSort) {
/* 072 */       long sort_spillSizeBefore = sort_metrics.memoryBytesSpilled();
/* 073 */       sort_addToSorter();
/* 074 */       sort_sortedIter = sort_sorter.sort();
/* 075 */       sort_sortTime.add(sort_sorter.getSortTimeNanos() / 1000000);
/* 076 */       sort_peakMemory.add(sort_sorter.getPeakMemoryUsage());
/* 077 */       sort_spillSize.add(sort_metrics.memoryBytesSpilled() - sort_spillSizeBefore);
/* 078 */       sort_metrics.incPeakExecutionMemory(sort_sorter.getPeakMemoryUsage());
/* 079 */       sort_needToSort = false;
/* 080 */     }
/* 081 */
/* 082 */     while (sort_sortedIter.hasNext()) {
/* 083 */       UnsafeRow sort_outputRow = (UnsafeRow)sort_sortedIter.next();
/* 084 */
/* 085 */       boolean sort_isNull = sort_outputRow.isNullAt(0);
/* 086 */       UTF8String sort_value = sort_isNull ? null : (sort_outputRow.getUTF8String(0));
/* 087 */       boolean sort_isNull1 = sort_outputRow.isNullAt(1);
/* 088 */       UTF8String sort_value1 = sort_isNull1 ? null : (sort_outputRow.getUTF8String(1));
/* 089 */
/* 090 */       // generate grouping key
/* 091 */       sagg_holder.reset();
/* 092 */
/* 093 */       sagg_rowWriter.zeroOutNullBytes();
/* 094 */
/* 095 */       if (sort_isNull) {
/* 096 */         sagg_rowWriter.setNullAt(0);
/* 097 */       } else {
/* 098 */         sagg_rowWriter.write(0, sort_value);
/* 099 */       }
/* 100 */       sagg_result.setTotalSize(sagg_holder.totalSize());
/* 101 */
/* 102 */       if (sagg_currentGroupingKey == null) {
/* 103 */         sagg_currentGroupingKey = sagg_result.copy();
/* 104 */         // init aggregation buffer vars
/* 105 */         final UTF8String sagg_value = null;
/* 106 */         sagg_bufIsNull = true;
/* 107 */         sagg_bufValue = sagg_value;
/* 108 */         // do aggregation
/* 109 */         // common sub-expressions
/* 110 */
/* 111 */         // evaluate aggregate function
/* 112 */         boolean sagg_isNull2 = sagg_bufIsNull;
/* 113 */         UTF8String sagg_value2 = sagg_bufValue;
/* 114 */
/* 115 */         if (!sort_isNull1 && (sagg_isNull2 ||
/* 116 */             (sort_value1.compare(sagg_value2)) > 0)) {
/* 117 */           sagg_isNull2 = false;
/* 118 */           sagg_value2 = sort_value1;
/* 119 */         }
/* 120 */         // update aggregation buffer
/* 121 */         sagg_bufIsNull = sagg_isNull2;
/* 122 */         if (sagg_bufValue != sagg_value2)
/* 123 */         sagg_bufValue = sagg_value2 == null ? null : sagg_value2.clone();
/* 124 */         continue;
/* 125 */       } else {
/* 126 */         if (sagg_currentGroupingKey.equals(sagg_result)) {
/* 127 */           // do aggregation
/* 128 */           // common sub-expressions
/* 129 */
/* 130 */           // evaluate aggregate function
/* 131 */           boolean sagg_isNull2 = sagg_bufIsNull;
/* 132 */           UTF8String sagg_value2 = sagg_bufValue;
/* 133 */
/* 134 */           if (!sort_isNull1 && (sagg_isNull2 ||
/* 135 */               (sort_value1.compare(sagg_value2)) > 0)) {
/* 136 */             sagg_isNull2 = false;
/* 137 */             sagg_value2 = sort_value1;
/* 138 */           }
/* 139 */           // update aggregation buffer
/* 140 */           sagg_bufIsNull = sagg_isNull2;
/* 141 */           if (sagg_bufValue != sagg_value2)
/* 142 */           sagg_bufValue = sagg_value2 == null ? null : sagg_value2.clone();
/* 143 */           continue;
/* 144 */         } else {
/* 145 */           do {
/* 146 */             sagg_numOutputRows.add(1);
/* 147 */
/* 148 */             sagg_holder1.reset();
/* 149 */
/* 150 */             sagg_rowWriter1.zeroOutNullBytes();
/* 151 */
/* 152 */             boolean sagg_isNull5 = sagg_currentGroupingKey.isNullAt(0);
/* 153 */             UTF8String sagg_value5 = sagg_isNull5 ? null : (sagg_currentGroupingKey.getUTF8String(0));
/* 154 */             if (sagg_isNull5) {
/* 155 */               sagg_rowWriter1.setNullAt(0);
/* 156 */             } else {
/* 157 */               sagg_rowWriter1.write(0, sagg_value5);
/* 158 */             }
/* 159 */
/* 160 */             if (sagg_bufIsNull) {
/* 161 */               sagg_rowWriter1.setNullAt(1);
/* 162 */             } else {
/* 163 */               sagg_rowWriter1.write(1, sagg_bufValue);
/* 164 */             }
/* 165 */             sagg_result1.setTotalSize(sagg_holder1.totalSize());
/* 166 */
/* 167 */             append(sagg_result1);
/* 168 */
/* 169 */           } while (false);
/* 170 */           // new grouping key
/* 171 */           sagg_currentGroupingKey = sagg_result.copy();
/* 172 */           final UTF8String sagg_value = null;
/* 173 */           sagg_bufIsNull = true;
/* 174 */           sagg_bufValue = sagg_value;
/* 175 */           // common sub-expressions
/* 176 */
/* 177 */           // evaluate aggregate function
/* 178 */           boolean sagg_isNull2 = sagg_bufIsNull;
/* 179 */           UTF8String sagg_value2 = sagg_bufValue;
/* 180 */
/* 181 */           if (!sort_isNull1 && (sagg_isNull2 ||
/* 182 */               (sort_value1.compare(sagg_value2)) > 0)) {
/* 183 */             sagg_isNull2 = false;
/* 184 */             sagg_value2 = sort_value1;
/* 185 */           }
/* 186 */           // update aggregation buffer
/* 187 */           sagg_bufIsNull = sagg_isNull2;
/* 188 */           if (sagg_bufValue != sagg_value2)
/* 189 */           sagg_bufValue = sagg_value2 == null ? null : sagg_value2.clone();
/* 190 */         }
/* 191 */       }
/* 192 */
/* 193 */       if (shouldStop()) return;
/* 194 */     }
/* 195 */
/* 196 */     while (sagg_lastAgg && sagg_currentGroupingKey != null) {
/* 197 */       sagg_lastAgg = false;
/* 198 */       sagg_numOutputRows.add(1);
/* 199 */
/* 200 */       wholestagecodegen_holder.reset();
/* 201 */
/* 202 */       wholestagecodegen_rowWriter.zeroOutNullBytes();
/* 203 */
/* 204 */       boolean wholestagecodegen_isNull = sagg_currentGroupingKey.isNullAt(0);
/* 205 */       UTF8String wholestagecodegen_value = wholestagecodegen_isNull ? null : (sagg_currentGroupingKey.getUTF8String(0));
/* 206 */       if (wholestagecodegen_isNull) {
/* 207 */         wholestagecodegen_rowWriter.setNullAt(0);
/* 208 */       } else {
/* 209 */         wholestagecodegen_rowWriter.write(0, wholestagecodegen_value);
/* 210 */       }
/* 211 */
/* 212 */       if (sagg_bufIsNull) {
/* 213 */         wholestagecodegen_rowWriter.setNullAt(1);
/* 214 */       } else {
/* 215 */         wholestagecodegen_rowWriter.write(1, sagg_bufValue);
/* 216 */       }
/* 217 */       wholestagecodegen_result.setTotalSize(wholestagecodegen_holder.totalSize());
/* 218 */
/* 219 */       append(wholestagecodegen_result);
/* 220 */
/* 221 */     }
/* 222 */   }
/* 223 */ }

== Subtree 2 / 2 ==
*SortAggregate(key=[k#24], functions=[max(v#25)], output=[k#24,max(v)#33])
+- *Sort [k#24 ASC], false, 0
   +- Exchange hashpartitioning(k#24, 200)
      +- *SortAggregate(key=[k#24], functions=[partial_max(v#25)], output=[k#24,max#38])
         +- *Sort [k#24 ASC], false, 0
            +- LocalTableScan [k#24, v#25]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private boolean sagg_bufIsNull;
/* 008 */   private UTF8String sagg_bufValue;
/* 009 */   private UnsafeRow sagg_currentGroupingKey;
/* 010 */   private boolean sagg_lastAgg;
/* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric sagg_numOutputRows;
/* 012 */   private boolean sort_needToSort;
/* 013 */   private org.apache.spark.sql.execution.SortExec sort_plan;
/* 014 */   private org.apache.spark.sql.execution.UnsafeExternalRowSorter sort_sorter;
/* 015 */   private org.apache.spark.executor.TaskMetrics sort_metrics;
/* 016 */   private scala.collection.Iterator<UnsafeRow> sort_sortedIter;
/* 017 */   private scala.collection.Iterator inputadapter_input;
/* 018 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_peakMemory;
/* 019 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_spillSize;
/* 020 */   private org.apache.spark.sql.execution.metric.SQLMetric sort_sortTime;
/* 021 */   private UnsafeRow sagg_result;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter;
/* 024 */   private UnsafeRow sagg_result1;
/* 025 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder sagg_holder1;
/* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter sagg_rowWriter1;
/* 027 */   private UnsafeRow wholestagecodegen_result;
/* 028 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder wholestagecodegen_holder;
/* 029 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter wholestagecodegen_rowWriter;
/* 030 */
/* 031 */   public GeneratedIterator(Object[] references) {
/* 032 */     this.references = references;
/* 033 */   }
/* 034 */
/* 035 */   public void init(int index, scala.collection.Iterator inputs[]) {
/* 036 */     partitionIndex = index;
/* 037 */
/* 038 */     sagg_currentGroupingKey = null;
/* 039 */     sagg_lastAgg = true;
/* 040 */     this.sagg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 041 */     sort_needToSort = true;
/* 042 */     this.sort_plan = (org.apache.spark.sql.execution.SortExec) references[1];
/* 043 */     sort_sorter = sort_plan.createSorter();
/* 044 */     sort_metrics = org.apache.spark.TaskContext.get().taskMetrics();
/* 045 */
/* 046 */     inputadapter_input = inputs[0];
/* 047 */     this.sort_peakMemory = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 048 */     this.sort_spillSize = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 049 */     this.sort_sortTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[4];
/* 050 */     sagg_result = new UnsafeRow(1);
/* 051 */     this.sagg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result, 32);
/* 052 */     this.sagg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder, 1);
/* 053 */     sagg_result1 = new UnsafeRow(2);
/* 054 */     this.sagg_holder1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(sagg_result1, 64);
/* 055 */     this.sagg_rowWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(sagg_holder1, 2);
/* 056 */     wholestagecodegen_result = new UnsafeRow(2);
/* 057 */     this.wholestagecodegen_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(wholestagecodegen_result, 64);
/* 058 */     this.wholestagecodegen_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(wholestagecodegen_holder, 2);
/* 059 */   }
/* 060 */
/* 061 */   private void sort_addToSorter() throws java.io.IOException {
/* 062 */     while (inputadapter_input.hasNext()) {
/* 063 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 064 */       sort_sorter.insertRow((UnsafeRow)inputadapter_row);
/* 065 */       if (shouldStop()) return;
/* 066 */     }
/* 067 */
/* 068 */   }
/* 069 */
/* 070 */   protected void processNext() throws java.io.IOException {
/* 071 */     if (sort_needToSort) {
/* 072 */       long sort_spillSizeBefore = sort_metrics.memoryBytesSpilled();
/* 073 */       sort_addToSorter();
/* 074 */       sort_sortedIter = sort_sorter.sort();
/* 075 */       sort_sortTime.add(sort_sorter.getSortTimeNanos() / 1000000);
/* 076 */       sort_peakMemory.add(sort_sorter.getPeakMemoryUsage());
/* 077 */       sort_spillSize.add(sort_metrics.memoryBytesSpilled() - sort_spillSizeBefore);
/* 078 */       sort_metrics.incPeakExecutionMemory(sort_sorter.getPeakMemoryUsage());
/* 079 */       sort_needToSort = false;
/* 080 */     }
/* 081 */
/* 082 */     while (sort_sortedIter.hasNext()) {
/* 083 */       UnsafeRow sort_outputRow = (UnsafeRow)sort_sortedIter.next();
/* 084 */
/* 085 */       boolean sort_isNull = sort_outputRow.isNullAt(0);
/* 086 */       UTF8String sort_value = sort_isNull ? null : (sort_outputRow.getUTF8String(0));
/* 087 */       boolean sort_isNull1 = sort_outputRow.isNullAt(1);
/* 088 */       UTF8String sort_value1 = sort_isNull1 ? null : (sort_outputRow.getUTF8String(1));
/* 089 */
/* 090 */       // generate grouping key
/* 091 */       sagg_holder.reset();
/* 092 */
/* 093 */       sagg_rowWriter.zeroOutNullBytes();
/* 094 */
/* 095 */       if (sort_isNull) {
/* 096 */         sagg_rowWriter.setNullAt(0);
/* 097 */       } else {
/* 098 */         sagg_rowWriter.write(0, sort_value);
/* 099 */       }
/* 100 */       sagg_result.setTotalSize(sagg_holder.totalSize());
/* 101 */
/* 102 */       if (sagg_currentGroupingKey == null) {
/* 103 */         sagg_currentGroupingKey = sagg_result.copy();
/* 104 */         // init aggregation buffer vars
/* 105 */         final UTF8String sagg_value = null;
/* 106 */         sagg_bufIsNull = true;
/* 107 */         sagg_bufValue = sagg_value;
/* 108 */         // do aggregation
/* 109 */         // common sub-expressions
/* 110 */
/* 111 */         // evaluate aggregate function
/* 112 */         boolean sagg_isNull2 = sagg_bufIsNull;
/* 113 */         UTF8String sagg_value2 = sagg_bufValue;
/* 114 */
/* 115 */         if (!sort_isNull1 && (sagg_isNull2 ||
/* 116 */             (sort_value1.compare(sagg_value2)) > 0)) {
/* 117 */           sagg_isNull2 = false;
/* 118 */           sagg_value2 = sort_value1;
/* 119 */         }
/* 120 */         // update aggregation buffer
/* 121 */         sagg_bufIsNull = sagg_isNull2;
/* 122 */         if (sagg_bufValue != sagg_value2)
/* 123 */         sagg_bufValue = sagg_value2 == null ? null : sagg_value2.clone();
/* 124 */         continue;
/* 125 */       } else {
/* 126 */         if (sagg_currentGroupingKey.equals(sagg_result)) {
/* 127 */           // do aggregation
/* 128 */           // common sub-expressions
/* 129 */
/* 130 */           // evaluate aggregate function
/* 131 */           boolean sagg_isNull2 = sagg_bufIsNull;
/* 132 */           UTF8String sagg_value2 = sagg_bufValue;
/* 133 */
/* 134 */           if (!sort_isNull1 && (sagg_isNull2 ||
/* 135 */               (sort_value1.compare(sagg_value2)) > 0)) {
/* 136 */             sagg_isNull2 = false;
/* 137 */             sagg_value2 = sort_value1;
/* 138 */           }
/* 139 */           // update aggregation buffer
/* 140 */           sagg_bufIsNull = sagg_isNull2;
/* 141 */           if (sagg_bufValue != sagg_value2)
/* 142 */           sagg_bufValue = sagg_value2 == null ? null : sagg_value2.clone();
/* 143 */           continue;
/* 144 */         } else {
/* 145 */           do {
/* 146 */             sagg_numOutputRows.add(1);
/* 147 */
/* 148 */             boolean sagg_isNull5 = sagg_currentGroupingKey.isNullAt(0);
/* 149 */             UTF8String sagg_value5 = sagg_isNull5 ? null : (sagg_currentGroupingKey.getUTF8String(0));
/* 150 */
/* 151 */             sagg_holder1.reset();
/* 152 */
/* 153 */             sagg_rowWriter1.zeroOutNullBytes();
/* 154 */
/* 155 */             if (sagg_isNull5) {
/* 156 */               sagg_rowWriter1.setNullAt(0);
/* 157 */             } else {
/* 158 */               sagg_rowWriter1.write(0, sagg_value5);
/* 159 */             }
/* 160 */
/* 161 */             if (sagg_bufIsNull) {
/* 162 */               sagg_rowWriter1.setNullAt(1);
/* 163 */             } else {
/* 164 */               sagg_rowWriter1.write(1, sagg_bufValue);
/* 165 */             }
/* 166 */             sagg_result1.setTotalSize(sagg_holder1.totalSize());
/* 167 */             append(sagg_result1);
/* 168 */
/* 169 */           } while (false);
/* 170 */           // new grouping key
/* 171 */           sagg_currentGroupingKey = sagg_result.copy();
/* 172 */           final UTF8String sagg_value = null;
/* 173 */           sagg_bufIsNull = true;
/* 174 */           sagg_bufValue = sagg_value;
/* 175 */           // common sub-expressions
/* 176 */
/* 177 */           // evaluate aggregate function
/* 178 */           boolean sagg_isNull2 = sagg_bufIsNull;
/* 179 */           UTF8String sagg_value2 = sagg_bufValue;
/* 180 */
/* 181 */           if (!sort_isNull1 && (sagg_isNull2 ||
/* 182 */               (sort_value1.compare(sagg_value2)) > 0)) {
/* 183 */             sagg_isNull2 = false;
/* 184 */             sagg_value2 = sort_value1;
/* 185 */           }
/* 186 */           // update aggregation buffer
/* 187 */           sagg_bufIsNull = sagg_isNull2;
/* 188 */           if (sagg_bufValue != sagg_value2)
/* 189 */           sagg_bufValue = sagg_value2 == null ? null : sagg_value2.clone();
/* 190 */         }
/* 191 */       }
/* 192 */
/* 193 */       if (shouldStop()) return;
/* 194 */     }
/* 195 */
/* 196 */     while (sagg_lastAgg && sagg_currentGroupingKey != null) {
/* 197 */       sagg_lastAgg = false;
/* 198 */       sagg_numOutputRows.add(1);
/* 199 */
/* 200 */       boolean wholestagecodegen_isNull = sagg_currentGroupingKey.isNullAt(0);
/* 201 */       UTF8String wholestagecodegen_value = wholestagecodegen_isNull ? null : (sagg_currentGroupingKey.getUTF8String(0));
/* 202 */
/* 203 */       wholestagecodegen_holder.reset();
/* 204 */
/* 205 */       wholestagecodegen_rowWriter.zeroOutNullBytes();
/* 206 */
/* 207 */       if (wholestagecodegen_isNull) {
/* 208 */         wholestagecodegen_rowWriter.setNullAt(0);
/* 209 */       } else {
/* 210 */         wholestagecodegen_rowWriter.write(0, wholestagecodegen_value);
/* 211 */       }
/* 212 */
/* 213 */       if (sagg_bufIsNull) {
/* 214 */         wholestagecodegen_rowWriter.setNullAt(1);
/* 215 */       } else {
/* 216 */         wholestagecodegen_rowWriter.write(1, sagg_bufValue);
/* 217 */       }
/* 218 */       wholestagecodegen_result.setTotalSize(wholestagecodegen_holder.totalSize());
/* 219 */       append(wholestagecodegen_result);
/* 220 */
/* 221 */     }
/* 222 */   }
/* 223 */ }
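To make the control flow of the generated code above easier to follow, here is a hypothetical, hand-written Java sketch of the same sort-based aggregation loop (max per grouping key over key-sorted input). Names like `SortAgg` and `maxPerKey` are illustrative only and do not appear in the patch; the real generated code operates on `UnsafeRow`s and an aggregation buffer rather than plain strings:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SortAgg {
    // Computes max(value) per key over rows already sorted by key,
    // mirroring the group-change logic in the generated code:
    // same key -> update buffer; new key -> emit result, reset buffer.
    static List<String> maxPerKey(List<String[]> sortedRows) {
        List<String> out = new ArrayList<>();
        String currentKey = null; // analogous to sagg_currentGroupingKey
        String buf = null;        // aggregation buffer; null means "no value yet"
        for (String[] row : sortedRows) {
            String key = row[0], value = row[1];
            if (currentKey == null) {
                currentKey = key;                  // first group seen
            } else if (!currentKey.equals(key)) {
                out.add(currentKey + "=" + buf);   // emit the finished group
                currentKey = key;
                buf = null;                        // reset buffer for new group
            }
            // evaluate the aggregate function (max), skipping null inputs
            if (value != null && (buf == null || value.compareTo(buf) > 0)) {
                buf = value;
            }
        }
        // flush the last group, as the trailing while(sagg_lastAgg ...) loop does
        if (currentKey != null) {
            out.add(currentKey + "=" + buf);
        }
        return out;
    }

    public static void main(String[] args) {
        List<String[]> rows = Arrays.asList(
            new String[]{"a", "x"}, new String[]{"a", "z"},
            new String[]{"b", "m"});
        System.out.println(maxPerKey(rows)); // [a=z, b=m]
    }
}
```

The key property this sketch shares with the generated code is that only one aggregation buffer is live at a time, which is what lets the sort-based path avoid the hash map used by hash aggregation.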

@yucai
Contributor Author

yucai commented Aug 17, 2016

@chenghao-intel Hao, could you kindly take a look?

@chenghao-intel
Contributor

@yucai can you please rebase the code?

@hvanhovell
Contributor

@yucai thanks for posting the benchmarks and the code. One high-level comment: start with a properly sorted dataset for the second benchmark. I would like to know how much time is actually spent in aggregation.

@yucai yucai force-pushed the sortagg branch 2 times, most recently from 9048ff0 to 72a0c8a Compare August 25, 2016 15:03
@yucai yucai force-pushed the sortagg branch 4 times, most recently from 461c737 to 958dc05 Compare September 22, 2016 09:13
@yucai yucai force-pushed the sortagg branch 2 times, most recently from 0a12860 to 2c22f81 Compare September 23, 2016 06:35
@maropu
Member

maropu commented Nov 25, 2016

@hvanhovell What's the status of this? If nobody takes this over, I will.

@yucai
Contributor Author

yucai commented Nov 25, 2016

@maropu, I have been doing some refactoring recently and will update it soon.

@maropu
Member

maropu commented Nov 25, 2016

@yucai okay, thanks!

@AmplabJenkins

Can one of the admins verify this patch?

@maropu
Member

maropu commented Jan 24, 2017

Any update?

@HyukjinKwon
Member

HyukjinKwon commented May 11, 2017

Gentle ping @yucai; I propose closing this if it remains inactive.

@asfgit asfgit closed this in 5d2750a May 18, 2017
@yucai yucai changed the title [WIP][SPARK-16844][SQL] Generate code for sort based aggregation [SPARK-16844][SQL] Generate code for sort based aggregation Dec 23, 2017