Commit 1071edb

Fix bugs when using ObjectHashAggregate and joins with ArrayType

Signed-off-by: Chendi Xue <[email protected]>
xuechendi committed Aug 12, 2021
1 parent bf1e3f9 commit 1071edb

Showing 3 changed files with 46 additions and 8 deletions.
@@ -1774,6 +1774,11 @@ final void setNulls(int rowId, int count) {
        writer.setNull(rowId + i);
      }
    }
+
+    @Override
+    final void setBytes(int rowId, int count, byte[] src, int srcIndex) {
+      writer.setSafe(rowId, src, srcIndex, count);
+    }
  }

  private static class DateWriter extends ArrowVectorWriter {
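
Note: the new setBytes override simply forwards to the Arrow writer's setSafe, which copies count bytes from src starting at srcIndex and grows the vector's buffers as needed. Below is a minimal, self-contained Scala sketch of that behavior against a plain Arrow VarBinaryVector; the object and variable names are illustrative, not the project's writer classes.

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.VarBinaryVector

object SetBytesSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)
    val vector = new VarBinaryVector("payload", allocator)
    vector.allocateNew()
    val src = "hello".getBytes("UTF-8")
    // Same semantics as writer.setSafe(rowId, src, srcIndex, count) above:
    // copy the byte range into row 0, resizing buffers if necessary.
    vector.setSafe(0, src, 0, src.length)
    vector.setValueCount(1)
    println(vector.getObject(0).length) // prints 5
    vector.close()
    allocator.close()
  }
}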
@@ -86,6 +86,7 @@ object RowToColumnConverter {
      case LongType | TimestampType => LongConverter
      case DoubleType => DoubleConverter
      case StringType => StringConverter
+      case BinaryType => BinaryConverter
      case CalendarIntervalType => CalendarConverter
      case at: ArrayType => new ArrayConverter(getConverterForType(at.elementType, nullable))
      case st: StructType => new StructConverter(st.fields.map(
@@ -150,6 +151,13 @@ object RowToColumnConverter {
    }
  }

+  private object BinaryConverter extends TypeConverter {
+    override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+      val data = row.getBinary(column)
+      cv.asInstanceOf[ArrowWritableColumnVector].appendString(data, 0, data.length)
+    }
+  }
+
  private object CalendarConverter extends TypeConverter {
    override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
      val c = row.getInterval(column)
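
Note: BinaryConverter mirrors the existing StringConverter but reads raw bytes with getBinary. A hedged, self-contained sketch of the same append path, using vanilla Spark's OnHeapColumnVector as a stand-in for the project's ArrowWritableColumnVector (appendByteArray plays the role of appendString here):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.BinaryType

object BinaryConverterSketch {
  def main(args: Array[String]): Unit = {
    // A one-column row holding a BinaryType value.
    val row = InternalRow(Array[Byte](1, 2, 3))
    val cv = new OnHeapColumnVector(16, BinaryType)
    val data = row.getBinary(0)
    // Plain-Spark equivalent of appendString(data, 0, data.length) in the diff.
    cv.appendByteArray(data, 0, data.length)
    println(cv.getBinary(0).mkString(",")) // prints 1,2,3
    cv.close()
  }
}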
@@ -30,7 +30,7 @@ import com.intel.oap.vectorized.{
  NativePartitioning
}
import org.apache.arrow.gandiva.expression.TreeBuilder
-import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
@@ -53,7 +53,7 @@ import org.apache.spark.sql.execution.metric.{
  SQLShuffleWriteMetricsReporter
}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{MutablePair, Utils}

@@ -85,13 +85,27 @@ case class ColumnarShuffleExchangeExec(

  override def nodeName: String = "ColumnarExchange"
  override def output: Seq[Attribute] = child.output
+  buildCheck()

  override def supportsColumnar: Boolean = true

  override def stringArgs =
    super.stringArgs ++ Iterator(s"[id=#$id]")
  //super.stringArgs ++ Iterator(output.map(o => s"${o}#${o.dataType.simpleString}"))

+  def buildCheck(): Unit = {
+    // check input datatype
+    for (attr <- child.output) {
+      try {
+        ColumnarShuffleExchangeExec.createArrowField(attr)
+      } catch {
+        case e: UnsupportedOperationException =>
+          throw new UnsupportedOperationException(
+            s"${attr.dataType} is not supported in ColumnarShuffleExchange")
+      }
+    }
+  }
+
  val serializer: Serializer = new ArrowColumnarBatchSerializer(
    longMetric("avgReadBatchNumRows"),
    longMetric("numOutputRows"))
@@ -275,6 +289,22 @@ object ColumnarShuffleExchangeExec extends Logging {
    }
  }

+  def createArrowField(name: String, dt: DataType): Field = dt match {
+    case at: ArrayType =>
+      throw new UnsupportedOperationException(s"${dt} is not supported in ColumnarShuffleExchange")
+    case mt: MapType =>
+      throw new UnsupportedOperationException(s"${dt} is not supported in ColumnarShuffleExchange")
+    case st: StructType =>
+      throw new UnsupportedOperationException(s"${dt} is not supported in ColumnarShuffleExchange")
+    /*new Field(name, FieldType.nullable(ArrowType.List.INSTANCE),
+      Lists.newArrayList(createArrowField(s"${name}_${dt}", at.elementType)))*/
+    case _ =>
+      Field.nullable(name, CodeGeneration.getResultType(dt))
+  }
+
+  def createArrowField(attr: Attribute): Field =
+    createArrowField(s"${attr.name}#${attr.exprId.id}", attr.dataType)
+
  def prepareShuffleDependency(
      rdd: RDD[ColumnarBatch],
      outputAttributes: Seq[Attribute],
@@ -288,12 +318,7 @@
      splitTime: SQLMetric,
      spillTime: SQLMetric,
      compressTime: SQLMetric): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
-
-    val arrowFields = outputAttributes.map(attr => {
-      Field
-        .nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
-    })
-
+    val arrowFields = outputAttributes.map(attr => createArrowField(attr))
    def serializeSchema(fields: Seq[Field]): Array[Byte] = {
      val schema = new Schema(fields.asJava)
      ConverterUtils.getSchemaBytesBuf(schema)
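
Note: with buildCheck() calling createArrowField at plan-construction time, a query over an unsupported complex type (ArrayType, MapType, StructType) now fails fast with a clear message instead of erroring deep inside the native shuffle. Below is a hedged sketch of the same mapping pattern against Arrow's pojo API; the specific ArrowType choices are illustrative assumptions, not the project's CodeGeneration.getResultType mapping.

import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.{ArrowType, Field}
import org.apache.spark.sql.types._

object ArrowFieldSketch {
  // Illustrative primitive mapping; complex types are rejected up front,
  // mirroring the ArrayType/MapType/StructType cases in the diff.
  def toArrowType(dt: DataType): ArrowType = dt match {
    case IntegerType => new ArrowType.Int(32, true)
    case LongType    => new ArrowType.Int(64, true)
    case DoubleType  => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
    case StringType  => ArrowType.Utf8.INSTANCE
    case BinaryType  => ArrowType.Binary.INSTANCE
    case other =>
      throw new UnsupportedOperationException(
        s"$other is not supported in ColumnarShuffleExchange")
  }

  def createArrowField(name: String, dt: DataType): Field =
    Field.nullable(name, toArrowType(dt))

  def main(args: Array[String]): Unit = {
    println(createArrowField("id#1", LongType))                 // supported
    try createArrowField("tags#2", ArrayType(StringType))       // rejected
    catch { case e: UnsupportedOperationException => println(e.getMessage) }
  }
}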