diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 208ec92987ac8..41bb4f012f2e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.util.hashing.MurmurHash3 + import org.apache.spark.sql.catalyst.expressions.GenericRow @@ -32,7 +34,7 @@ object Row { * } * }}} */ - def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) + def unapplySeq(row: Row): Some[Seq[Any]] = Some(row.toSeq) /** * This method can be used to construct a [[Row]] with the given values. @@ -43,6 +45,16 @@ object Row { * This method can be used to construct a [[Row]] from a [[Seq]] of values. */ def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) + + def fromTuple(tuple: Product): Row = fromSeq(tuple.productIterator.toSeq) + + /** + * Merge multiple rows into a single row, one after another. + */ + def merge(rows: Row*): Row = { + // TODO: Improve the performance of this if used in performance critical part. + new GenericRow(rows.flatMap(_.toSeq).toArray) + } } @@ -103,7 +115,13 @@ object Row { * * @group row */ -trait Row extends Seq[Any] with Serializable { +trait Row extends Serializable { + /** Number of elements in the Row. */ + def size: Int = length + + /** Number of elements in the Row. */ + def length: Int + /** * Returns the value at position i. If the value is null, null is returned. The following * is a mapping between Spark SQL types and return types: @@ -291,12 +309,61 @@ trait Row extends Seq[Any] with Serializable { /** Returns true if there are any NULL values in this row. */ def anyNull: Boolean = { - val l = length + val len = length var i = 0 - while (i < l) { + while (i < len) { if (isNullAt(i)) { return true } i += 1 } false } + + override def equals(that: Any): Boolean = that match { + case null => false + case that: Row => + if (this.length != that.length) { + return false + } + var i = 0 + val len = this.length + while (i < len) { + if (apply(i) != that.apply(i)) { + return false + } + i += 1 + } + true + case _ => false + } + + override def hashCode: Int = { + // Using Scala's Seq hash code implementation. + var n = 0 + var h = MurmurHash3.seqSeed + val len = length + while (n < len) { + h = MurmurHash3.mix(h, apply(n).##) + n += 1 + } + MurmurHash3.finalizeHash(h, n) + } + + /* ---------------------- utility methods for Scala ---------------------- */ + + /** + * Returns a Scala Seq representing the row. Elements are placed in the same order in the Seq. + */ + def toSeq: Seq[Any] + + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings.
+ */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d280db83b26f7..191d16fb10b5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -84,8 +84,9 @@ trait ScalaReflection { } def convertRowToScala(r: Row, schema: StructType): Row = { + // TODO: This is very slow!!! new GenericRow( - r.zip(schema.fields.map(_.dataType)) + r.toSeq.zip(schema.fields.map(_.dataType)) .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 26c855878d202..417659eed5957 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -272,9 +272,6 @@ package object dsl { def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) - def sfilter(dynamicUdf: (DynamicRow) => Boolean) = - Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan) - def sample( fraction: Double, withReplacement: Boolean = true, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 1a2133bbbcec7..ece5ee73618cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -407,7 +407,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val casts = from.fields.zip(to.fields).map { case (fromField, toField) => cast(fromField.dataType, toField.dataType) } - buildCast[Row](_, row => Row(row.zip(casts).map { + // TODO: This is very slow! 
+ buildCast[Row](_, row => Row(row.toSeq.zip(casts).map { case (v, cast) => if (v == null) null else cast(v) }: _*)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index e7e81a21fdf03..db5d897ee569f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -105,45 +105,45 @@ class JoinedRow extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -154,8 +154,16 @@ class JoinedRow extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString 
never throws NullPointerException. + if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -197,45 +205,45 @@ class JoinedRow2 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -246,8 +254,16 @@ class JoinedRow2 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. 
+ if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -283,45 +299,45 @@ class JoinedRow3 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -332,8 +348,16 @@ class JoinedRow3 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. 
+ if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -369,45 +393,45 @@ class JoinedRow4 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -418,8 +442,16 @@ class JoinedRow4 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. 
+ if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } @@ -455,45 +487,45 @@ class JoinedRow5 extends Row { this } - def iterator = row1.iterator ++ row2.iterator + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - def length = row1.length + row2.length + override def length = row1.length + row2.length - def apply(i: Int) = - if (i < row1.size) row1(i) else row2(i - row1.size) + override def apply(i: Int) = + if (i < row1.length) row1(i) else row2(i - row1.length) - def isNullAt(i: Int) = - if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + override def isNullAt(i: Int) = + if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - def getInt(i: Int): Int = - if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + override def getInt(i: Int): Int = + if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - def getLong(i: Int): Long = - if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + override def getLong(i: Int): Long = + if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - def getDouble(i: Int): Double = - if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + override def getDouble(i: Int): Double = + if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - def getBoolean(i: Int): Boolean = - if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + override def getBoolean(i: Int): Boolean = + if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - def getShort(i: Int): Short = - if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + override def getShort(i: Int): Short = + if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - def getByte(i: Int): Byte = - if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + override def getByte(i: Int): Byte = + if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - def getFloat(i: Int): Float = - if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + override def getFloat(i: Int): Float = + if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - def getString(i: Int): String = - if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getString(i: Int): String = + if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) override def getAs[T](i: Int): T = - if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - def copy() = { - val totalSize = row1.size + row2.size + override def copy() = { + val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { @@ -504,7 +536,15 @@ class JoinedRow5 extends Row { } override def toString() = { - val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) - s"[${row.mkString(",")}]" + // Make sure toString never throws NullPointerException. 
+ if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.mkString("[", ",", "]") + } else if (row2 eq null) { + row1.mkString("[", ",", "]") + } else { + mkString("[", ",", "]") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 37d9f0ed5c79e..7434165f654f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -209,6 +209,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def length: Int = values.length + override def toSeq: Seq[Any] = values.map(_.boxed).toSeq + override def setNullAt(i: Int): Unit = { values(i).isNull = true } @@ -231,8 +233,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR if (value == null) setNullAt(ordinal) else values(ordinal).update(value) } - override def iterator: Iterator[Any] = values.map(_.boxed).iterator - override def setString(ordinal: Int, value: String) = update(ordinal, value) override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala deleted file mode 100644 index e2f5c7332d9ab..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.language.dynamics - -import org.apache.spark.sql.types.DataType - -/** - * The data type representing [[DynamicRow]] values. - */ -case object DynamicType extends DataType { - - /** - * The default size of a value of the DynamicType is 4096 bytes. - */ - override def defaultSize: Int = 4096 -} - -/** - * Wrap a [[Row]] as a [[DynamicRow]]. - */ -case class WrapDynamic(children: Seq[Attribute]) extends Expression { - type EvaluatedType = DynamicRow - - def nullable = false - - def dataType = DynamicType - - override def eval(input: Row): DynamicRow = input match { - // Avoid copy for generic rows. - case g: GenericRow => new DynamicRow(children, g.values) - case otherRowType => new DynamicRow(children, otherRowType.toArray) - } -} - -/** - * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language. 
- * Since the type of the column is not known at compile time, all attributes are converted to - * strings before being passed to the function. - */ -class DynamicRow(val schema: Seq[Attribute], values: Array[Any]) - extends GenericRow(values) with Dynamic { - - def selectDynamic(attributeName: String): String = { - val ordinal = schema.indexWhere(_.name == attributeName) - values(ordinal).toString - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index cc97cb4f50b69..69397a73a8880 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -77,14 +77,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { """.children : Seq[Tree] } - val iteratorFunction = { - val allColumns = (0 until expressions.size).map { i => - val iLit = ru.Literal(Constant(i)) - q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" - } - q"override def iterator = Iterator[Any](..$allColumns)" - } - val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" val applyFunction = { val cases = (0 until expressions.size).map { i => @@ -191,20 +183,26 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } """ + val allColumns = (0 until expressions.size).map { i => + val iLit = ru.Literal(Constant(i)) + q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + } + val copyFunction = - q""" - override def copy() = new $genericRowType(this.toArray) - """ + q"override def copy() = new $genericRowType(Array[Any](..$allColumns))" + + val toSeqFunction = + q"override def toSeq: Seq[Any] = Seq(..$allColumns)" val classBody = nullFunctions ++ ( lengthDef +: - iteratorFunction +: applyFunction +: updateFunction +: equalsFunction +: hashCodeFunction +: copyFunction +: + toSeqFunction +: (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) val code = q""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index c22b8426841da..8df150e2f855f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -44,7 +44,7 @@ trait MutableRow extends Row { */ object EmptyRow extends Row { override def apply(i: Int): Any = throw new UnsupportedOperationException - override def iterator = Iterator.empty + override def toSeq = Seq.empty override def length = 0 override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException override def getInt(i: Int): Int = throw new UnsupportedOperationException @@ -70,7 +70,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { def this(size: Int) = this(new Array[Any](size)) - override def iterator = values.iterator + override def toSeq = values.toSeq override def length = values.length @@ -119,7 +119,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } // Custom hashCode function that matches the efficient code generated version. 
- override def hashCode(): Int = { + override def hashCode: Int = { var result: Int = 37 var i = 0 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 6df5db4c80f34..5138942a55daa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -244,7 +244,7 @@ class ScalaReflectionSuite extends FunSuite { test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) + val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) val dataType = schemaFor[PrimitiveData].dataType assert(convertToCatalyst(data, dataType) === convertedData) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index ae4d8ba90c5bd..d1e21dffeb8c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -330,25 +330,6 @@ class SchemaRDD( sqlContext, Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) - /** - * :: Experimental :: - * Filters tuples using a function over a `Dynamic` version of a given Row. DynamicRows use - * scala's Dynamic trait to emulate an ORM of in a dynamically typed language. Since the type of - * the column is not known at compile time, all attributes are converted to strings before - * being passed to the function. - * - * {{{ - * schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith") - * }}} - * - * @group Query - */ - @Experimental - def where(dynamicUdf: (DynamicRow) => Boolean) = - new SchemaRDD( - sqlContext, - Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)) - /** * :: Experimental :: * Returns a sampled version of the underlying dataset. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 065fae3c83df1..11d5943fb427f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -21,7 +21,6 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -128,8 +127,7 @@ private[sql] case class InMemoryRelation( rowCount += 1 } - val stats = Row.fromSeq( - columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) + val stats = Row.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -271,9 +269,10 @@ private[sql] case class InMemoryColumnarTableScan( // Extract rows via column accessors new Iterator[Row] { + private[this] val rowLen = nextRow.length override def next() = { var i = 0 - while (i < nextRow.length) { + while (i < rowLen) { columnAccessors(i).extractTo(nextRow, i) i += 1 } @@ -297,7 +296,7 @@ private[sql] case class InMemoryColumnarTableScan( cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { def statsString = relation.partitionStatistics.schema - .zip(cachedBatch.stats) + .zip(cachedBatch.stats.toSeq) .map { case (a, s) => s"${a.name}: $s" } .mkString(", ") logInfo(s"Skipping partition based on stats $statsString") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 64673248394c6..68a5b1de7691b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -127,7 +127,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { while (from.hasRemaining) { columnType.extract(from, value, 0) - if (value.head == currentValue.head) { + if (value(0) == currentValue(0)) { currentRun += 1 } else { // Writes current run diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 46245cd5a1869..4d7e338e8ed13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -144,7 +144,7 @@ package object debug { case (null, _) => case (row: Row, StructType(fields)) => - row.zip(fields.map(_.dataType)).foreach { case(d,t) => typeCheck(d,t) } + row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (s: Seq[_], ArrayType(elemType, _)) => s.foreach(typeCheck(_, elemType)) case (m: Map[_, _], MapType(keyType, valueType, _)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 7ed64aad10d4e..b85021acc9d4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala 
@@ -116,9 +116,9 @@ object EvaluatePython { def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (row: Seq[Any], struct: StructType) => + case (row: Row, struct: StructType) => val fields = struct.fields.map(field => field.dataType) - row.zip(fields).map { + row.toSeq.zip(fields).map { case (obj, dataType) => toJava(obj, dataType) }.toArray @@ -143,7 +143,8 @@ object EvaluatePython { * Convert Row into Java Array (for pickled into Python) */ def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = { - row.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray + // TODO: this is slow! + row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray } // Converts value to the type specified by the data type. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index db70a7eac72b9..9171939f7e8f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -458,16 +458,16 @@ private[sql] object JsonRDD extends Logging { gen.writeEndArray() case (MapType(kv,vv, _), v: Map[_,_]) => - gen.writeStartObject + gen.writeStartObject() v.foreach { p => gen.writeFieldName(p._1.toString) valWriter(vv,p._2) } - gen.writeEndObject + gen.writeEndObject() - case (StructType(ty), v: Seq[_]) => + case (StructType(ty), v: Row) => gen.writeStartObject() - ty.zip(v).foreach { + ty.zip(v.toSeq).foreach { case (_, null) => case (field, v) => gen.writeFieldName(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index b4aed04199129..9d9150246c8d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -66,7 +66,7 @@ private[sql] object CatalystConverter { // TODO: consider using Array[T] for arrays to avoid boxing of primitive types type ArrayScalaType[T] = Seq[T] - type StructScalaType[T] = Seq[T] + type StructScalaType[T] = Row type MapScalaType[K, V] = Map[K, V] protected[parquet] def createConverter( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 2bcfe28456997..afbfe214f1ce4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -45,28 +45,28 @@ class DslQuerySuite extends QueryTest { test("agg") { checkAnswer( testData2.groupBy('a)('a, sum('b)), - Seq((1,3),(2,3),(3,3)) + Seq(Row(1,3), Row(2,3), Row(3,3)) ) checkAnswer( testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), - 9 + Row(9) ) checkAnswer( testData2.aggregate(sum('b)), - 9 + Row(9) ) } test("convert $\"attribute name\" into unresolved attribute") { checkAnswer( testData.where($"key" === 1).select($"value"), - Seq(Seq("1"))) + Row("1")) } test("convert Scala Symbol 'attrname into unresolved attribute") { checkAnswer( testData.where('key === 1).select('value), - Seq(Seq("1"))) + Row("1")) } test("select *") { @@ -78,61 +78,61 @@ class DslQuerySuite extends QueryTest { test("simple select") { checkAnswer( testData.where('key === 1).select('value), - Seq(Seq("1"))) + Row("1")) } test("select with functions") { checkAnswer( testData.select(sum('value), avg('value), 
count(1)), - Seq(Seq(5050.0, 50.5, 100))) + Row(5050.0, 50.5, 100)) checkAnswer( testData2.select('a + 'b, 'a < 'b), Seq( - Seq(2, false), - Seq(3, true), - Seq(3, false), - Seq(4, false), - Seq(4, false), - Seq(5, false))) + Row(2, false), + Row(3, true), + Row(3, false), + Row(4, false), + Row(4, false), + Row(5, false))) checkAnswer( testData2.select(sumDistinct('a)), - Seq(Seq(6))) + Row(6)) } test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), - Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) checkAnswer( testData2.orderBy('a.asc, 'b.desc), - Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) checkAnswer( testData2.orderBy('a.desc, 'b.desc), - Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) checkAnswer( testData2.orderBy('a.desc, 'b.asc), - Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) checkAnswer( arrayData.orderBy('data.getItem(0).asc), - arrayData.collect().sortBy(_.data(0)).toSeq) + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(0).desc), - arrayData.collect().sortBy(_.data(0)).reverse.toSeq) + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( - mapData.orderBy('data.getItem(1).asc), - mapData.collect().sortBy(_.data(1)).toSeq) + arrayData.orderBy('data.getItem(1).asc), + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( - mapData.orderBy('data.getItem(1).desc), - mapData.collect().sortBy(_.data(1)).reverse.toSeq) + arrayData.orderBy('data.getItem(1).desc), + arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } test("partition wide sorting") { @@ -147,19 +147,19 @@ class DslQuerySuite extends QueryTest { // (3, 2) checkAnswer( testData2.sortBy('a.asc, 'b.asc), - Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) checkAnswer( testData2.sortBy('a.asc, 'b.desc), - Seq((1,2), (1,1), (2,1), (2,2), (3,2), (3,1))) + Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1))) checkAnswer( testData2.sortBy('a.desc, 'b.desc), - Seq((2,1), (1,2), (1,1), (3,2), (3,1), (2,2))) + Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2))) checkAnswer( testData2.sortBy('a.desc, 'b.asc), - Seq((2,1), (1,1), (1,2), (3,1), (3,2), (2,2))) + Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2))) } test("limit") { @@ -169,11 +169,11 @@ class DslQuerySuite extends QueryTest { checkAnswer( arrayData.limit(1), - arrayData.take(1).toSeq) + arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) checkAnswer( mapData.limit(1), - mapData.take(1).toSeq) + mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) } test("SPARK-3395 limit distinct") { @@ -184,8 +184,8 @@ class DslQuerySuite extends QueryTest { .registerTempTable("onerow") checkAnswer( sql("select * from onerow inner join testData2 on onerow.a = testData2.a"), - (1, 1, 1, 1) :: - (1, 1, 1, 2) :: Nil) + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: Nil) } test("SPARK-3858 generator qualifiers are discarded") { @@ -193,55 +193,55 @@ class DslQuerySuite extends QueryTest { arrayData.as('ad) .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) .select("ex.data".attr), - Seq(1, 2, 3, 2, 3, 
4).map(Seq(_))) + Seq(1, 2, 3, 2, 3, 4).map(Row(_))) } test("average") { checkAnswer( testData2.aggregate(avg('a)), - 2.0) + Row(2.0)) checkAnswer( testData2.aggregate(avg('a), sumDistinct('a)), // non-partial - (2.0, 6.0) :: Nil) + Row(2.0, 6.0) :: Nil) checkAnswer( decimalData.aggregate(avg('a)), - new java.math.BigDecimal(2.0)) + Row(new java.math.BigDecimal(2.0))) checkAnswer( decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial - (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) checkAnswer( decimalData.aggregate(avg('a cast DecimalType(10, 2))), - new java.math.BigDecimal(2.0)) + Row(new java.math.BigDecimal(2.0))) checkAnswer( decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial - (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } test("null average") { checkAnswer( testData3.aggregate(avg('b)), - 2.0) + Row(2.0)) checkAnswer( testData3.aggregate(avg('b), countDistinct('b)), - (2.0, 1) :: Nil) + Row(2.0, 1)) checkAnswer( testData3.aggregate(avg('b), sumDistinct('b)), // non-partial - (2.0, 2.0) :: Nil) + Row(2.0, 2.0)) } test("zero average") { checkAnswer( emptyTableData.aggregate(avg('a)), - null) + Row(null)) checkAnswer( emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial - (null, null) :: Nil) + Row(null, null)) } test("count") { @@ -249,28 +249,28 @@ class DslQuerySuite extends QueryTest { checkAnswer( testData2.aggregate(count('a), sumDistinct('a)), // non-partial - (6, 6.0) :: Nil) + Row(6, 6.0)) } test("null count") { checkAnswer( testData3.groupBy('a)('a, count('b)), - Seq((1,0), (2, 1)) + Seq(Row(1,0), Row(2, 1)) ) checkAnswer( testData3.groupBy('a)('a, count('a + 'b)), - Seq((1,0), (2, 1)) + Seq(Row(1,0), Row(2, 1)) ) checkAnswer( testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), - (2, 1, 2, 2, 1) :: Nil + Row(2, 1, 2, 2, 1) ) checkAnswer( testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial - (1, 1, 2) :: Nil + Row(1, 1, 2) ) } @@ -279,28 +279,28 @@ class DslQuerySuite extends QueryTest { checkAnswer( emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial - (0, null) :: Nil) + Row(0, null)) } test("zero sum") { checkAnswer( emptyTableData.aggregate(sum('a)), - null) + Row(null)) } test("zero sum distinct") { checkAnswer( emptyTableData.aggregate(sumDistinct('a)), - null) + Row(null)) } test("except") { checkAnswer( lowerCaseData.except(upperCaseData), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.except(lowerCaseData), Nil) checkAnswer(upperCaseData.except(upperCaseData), Nil) } @@ -308,10 +308,10 @@ class DslQuerySuite extends QueryTest { test("intersect") { checkAnswer( lowerCaseData.intersect(lowerCaseData), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) } @@ -321,75 +321,75 @@ class DslQuerySuite extends QueryTest { checkAnswer( // SELECT *, foo(key, value) FROM testData testData.select(Star(None), foo.call('key, 'value)).limit(3), - (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil + Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil ) } 
test("sqrt") { checkAnswer( testData.select(sqrt('key)).orderBy('key asc), - (1 to 100).map(n => Seq(math.sqrt(n))) + (1 to 100).map(n => Row(math.sqrt(n))) ) checkAnswer( testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc), - (1 to 100).map(n => Seq(math.sqrt(n), n)) + (1 to 100).map(n => Row(math.sqrt(n), n)) ) checkAnswer( testData.select(sqrt(Literal(null))), - (1 to 100).map(_ => Seq(null)) + (1 to 100).map(_ => Row(null)) ) } test("abs") { checkAnswer( testData.select(abs('key)).orderBy('key asc), - (1 to 100).map(n => Seq(n)) + (1 to 100).map(n => Row(n)) ) checkAnswer( negativeData.select(abs('key)).orderBy('key desc), - (1 to 100).map(n => Seq(n)) + (1 to 100).map(n => Row(n)) ) checkAnswer( testData.select(abs(Literal(null))), - (1 to 100).map(_ => Seq(null)) + (1 to 100).map(_ => Row(null)) ) } test("upper") { checkAnswer( lowerCaseData.select(upper('l)), - ('a' to 'd').map(c => Seq(c.toString.toUpperCase())) + ('a' to 'd').map(c => Row(c.toString.toUpperCase())) ) checkAnswer( testData.select(upper('value), 'key), - (1 to 100).map(n => Seq(n.toString, n)) + (1 to 100).map(n => Row(n.toString, n)) ) checkAnswer( testData.select(upper(Literal(null))), - (1 to 100).map(n => Seq(null)) + (1 to 100).map(n => Row(null)) ) } test("lower") { checkAnswer( upperCaseData.select(lower('L)), - ('A' to 'F').map(c => Seq(c.toString.toLowerCase())) + ('A' to 'F').map(c => Row(c.toString.toLowerCase())) ) checkAnswer( testData.select(lower('value), 'key), - (1 to 100).map(n => Seq(n.toString, n)) + (1 to 100).map(n => Row(n.toString, n)) ) checkAnswer( testData.select(lower(Literal(null))), - (1 to 100).map(n => Seq(null)) + (1 to 100).map(n => Row(null)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e5ab16f9dd661..cd36da7751e83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -117,10 +117,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( upperCaseData.join(lowerCaseData, Inner).where('n === 'N), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d") )) } @@ -128,10 +128,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d") )) } @@ -140,10 +140,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.where('a === 1).as('y) checkAnswer( x.join(y).where("x.a".attr === "y.a".attr), - (1,1,1,1) :: - (1,1,1,2) :: - (1,2,1,1) :: - (1,2,1,2) :: Nil + Row(1,1,1,1) :: + Row(1,1,1,2) :: + Row(1,2,1,1) :: + Row(1,2,1,2) :: Nil ) } @@ -163,54 +163,54 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), testData.flatMap( - row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } test("cartisian product join") { checkAnswer( testData3.join(testData3), - (1, null, 1, null) :: - (1, null, 2, 2) :: - (2, 2, 1, null) :: - (2, 2, 2, 2) :: Nil) + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 
2) :: Nil) } test("left outer join") { checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), - (1, "A", null, null) :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), - (1, "A", null, null) :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. @@ -221,12 +221,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY l.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) checkAnswer( sql( @@ -235,42 +235,42 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY r.a """.stripMargin), - (null, 6) :: Nil) + Row(null, 6) :: Nil) } test("right outer join") { checkAnswer( lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), - (null, null, 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(null, null, 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), - (null, null, 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(null, null, 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( 
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. @@ -281,7 +281,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY l.a """.stripMargin), - (null, 6) :: Nil) + Row(null, 6)) checkAnswer( sql( @@ -290,12 +290,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) } test("full outer join") { @@ -307,32 +307,32 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", null, null) :: - (null, null, 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", null, null) :: + Row(null, null, 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) checkAnswer( left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", null, null) :: - (null, null, 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", null, null) :: + Row(null, null, 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. 
checkAnswer( @@ -342,7 +342,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY l.a """.stripMargin), - (null, 10) :: Nil) + Row(null, 10)) checkAnswer( sql( @@ -351,13 +351,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: - (null, 4) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) checkAnswer( sql( @@ -366,13 +366,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY l.N """.stripMargin), - (1, 1) :: - (2, 1) :: - (3, 1) :: - (4, 1) :: - (5, 1) :: - (6, 1) :: - (null, 4) :: Nil) + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) checkAnswer( sql( @@ -381,7 +381,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY r.a """.stripMargin), - (null, 10) :: Nil) + Row(null, 10)) } test("broadcasted left semi join operator selection") { @@ -412,12 +412,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("left semi join") { val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(rdd, - (1, 1) :: - (1, 2) :: - (2, 1) :: - (2, 2) :: - (3, 1) :: - (3, 2) :: Nil) + Row(1, 1) :: + Row(1, 2) :: + Row(2, 1) :: + Row(2, 2) :: + Row(3, 1) :: + Row(3, 2) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 68ddecc7f610d..42a21c148df53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -47,26 +47,17 @@ class QueryTest extends PlanTest { * @param rdd the [[SchemaRDD]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = { - val convertedAnswer = expectedAnswer match { - case s: Seq[_] if s.isEmpty => s - case s: Seq[_] if s.head.isInstanceOf[Product] && - !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq) - case s: Seq[_] => s - case singleItem => Seq(Seq(singleItem)) - } - + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - def prepareAnswer(answer: Seq[Any]): Seq[Any] = { + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). 
- val converted = answer.map { - case s: Seq[_] => s.map { + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) case o => o - } - case o => o + }) } if (!isSorted) converted.sortBy(_.toString) else converted } @@ -82,7 +73,7 @@ class QueryTest extends PlanTest { """.stripMargin) } - if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { fail(s""" |Results do not match for query: |${rdd.logicalPlan} @@ -92,15 +83,19 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString), s"== Spark Answer - ${sparkAnswer.size} ==" +: prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } - def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = { + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + checkAnswer(rdd, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { test(sqlString) { checkAnswer(sqlContext.sql(sqlString), expectedAnswer) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 54fabc5c915fb..03b44ca1d6695 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -46,7 +46,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), - Seq(1, 1, 2 ,2 ,3 ,3).map(Seq(_)) + Seq(1, 1, 2 ,2 ,3 ,3).map(Row(_)) ) } @@ -70,13 +70,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3176 Added Parser of SQL ABS()") { checkAnswer( sql("SELECT ABS(-1.3)"), - 1.3) + Row(1.3)) checkAnswer( sql("SELECT ABS(0.0)"), - 0.0) + Row(0.0)) checkAnswer( sql("SELECT ABS(2.5)"), - 2.5) + Row(2.5)) } test("aggregation with codegen") { @@ -89,13 +89,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), - 4) + Row(4)) } test("SPARK-2041 column name equals tablename") { checkAnswer( sql("SELECT tableName FROM tableName"), - "test") + Row("test")) } test("SQRT") { @@ -115,40 +115,40 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-2407 Added Parser of SQL SUBSTR()") { checkAnswer( sql("SELECT substr(tableName, 1, 2) FROM tableName"), - "te") + Row("te")) checkAnswer( sql("SELECT substr(tableName, 3) FROM tableName"), - "st") + Row("st")) checkAnswer( sql("SELECT substring(tableName, 1, 2) FROM tableName"), - "te") + Row("te")) checkAnswer( sql("SELECT substring(tableName, 3) FROM tableName"), - "st") + Row("st")) } test("SPARK-3173 Timestamp support in the parser") { checkAnswer(sql( "SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"), - 
Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) checkAnswer(sql( "SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) checkAnswer(sql( """SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003' AND time>'1970-01-01 00:00:00.001'"""), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"), - Seq(Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), - Seq(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + Seq(Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), + Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time='123'"), @@ -158,13 +158,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("index into array") { checkAnswer( sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), - arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq) + arrayData.map(d => Row(d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect()) } test("left semi greater than predicate") { checkAnswer( sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), - Seq((3,1), (3,2)) + Seq(Row(3,1), Row(3,2)) ) } @@ -173,7 +173,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql( "SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"), arrayData.map(d => - (d.nestedData, + Row(d.nestedData, d.nestedData(0)(0), d.nestedData(0)(0) + d.nestedData(0)(1))).collect().toSeq) } @@ -181,13 +181,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("agg") { checkAnswer( sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq((1,3),(2,3),(3,3))) + Seq(Row(1,3), Row(2,3), Row(3,3))) } test("aggregates with nulls") { checkAnswer( sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), - (1, 3, 2, 6, 3) :: Nil + Row(1, 3, 2, 6, 3) ) } @@ -200,29 +200,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("simple select") { checkAnswer( sql("SELECT value FROM testData WHERE key = 1"), - Seq(Seq("1"))) + Row("1")) } def sortTest() = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), - Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), - Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), - Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), - Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a ASC"), - (1 to 5).map(Row(_)).toSeq) + (1 to 5).map(Row(_))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a DESC"), @@ -230,19 +230,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( 
sql("SELECT * FROM arrayData ORDER BY data[0] ASC"), - arrayData.collect().sortBy(_.data(0)).toSeq) + arrayData.collect().sortBy(_.data(0)).map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM arrayData ORDER BY data[0] DESC"), - arrayData.collect().sortBy(_.data(0)).reverse.toSeq) + arrayData.collect().sortBy(_.data(0)).reverse.map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM mapData ORDER BY data[1] ASC"), - mapData.collect().sortBy(_.data(1)).toSeq) + mapData.collect().sortBy(_.data(1)).map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM mapData ORDER BY data[1] DESC"), - mapData.collect().sortBy(_.data(1)).reverse.toSeq) + mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } test("sorting") { @@ -266,94 +266,94 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM arrayData LIMIT 1"), - arrayData.collect().take(1).toSeq) + arrayData.collect().take(1).map(Row.fromTuple).toSeq) checkAnswer( sql("SELECT * FROM mapData LIMIT 1"), - mapData.collect().take(1).toSeq) + mapData.collect().take(1).map(Row.fromTuple).toSeq) } test("from follow multiple brackets") { checkAnswer(sql( "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), - 1 + Row(1) ) checkAnswer(sql( "select key from (select * from testData) x limit 1"), - 1 + Row(1) ) checkAnswer(sql( "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"), - 1 + Row(1) ) } test("average") { checkAnswer( sql("SELECT AVG(a) FROM testData2"), - 2.0) + Row(2.0)) } test("average overflow") { checkAnswer( sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), - Seq((2147483645.0,1),(2.0,2))) + Seq(Row(2147483645.0,1), Row(2.0,2))) } test("count") { checkAnswer( sql("SELECT COUNT(*) FROM testData2"), - testData2.count()) + Row(testData2.count())) } test("count distinct") { checkAnswer( sql("SELECT COUNT(DISTINCT b) FROM testData2"), - 2) + Row(2)) } test("approximate count distinct") { checkAnswer( sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"), - 3) + Row(3)) } test("approximate count distinct with user provided standard deviation") { checkAnswer( sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"), - 3) + Row(3)) } test("null count") { checkAnswer( sql("SELECT a, COUNT(b) FROM testData3 GROUP BY a"), - Seq((1, 0), (2, 1))) + Seq(Row(1, 0), Row(2, 1))) checkAnswer( sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), - (2, 1, 2, 2, 1) :: Nil) + Row(2, 1, 2, 2, 1)) } test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d"))) + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d"))) } test("inner join ON, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d"))) + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d"))) } test("inner join, where, multiple matches") { @@ -363,10 +363,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM testData2 WHERE a = 1) x JOIN | (SELECT * FROM testData2 WHERE a = 1) y |WHERE x.a = y.a""".stripMargin), - (1,1,1,1) :: - (1,1,1,2) :: - (1,2,1,1) :: - (1,2,1,2) :: Nil) + Row(1,1,1,1) :: + 
Row(1,1,1,2) :: + Row(1,2,1,1) :: + Row(1,2,1,2) :: Nil) } test("inner join, no matches") { @@ -397,38 +397,38 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | SELECT * FROM testData) y |WHERE x.key = y.key""".stripMargin), testData.flatMap( - row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } ignore("cartesian product join") { checkAnswer( testData3.join(testData3), - (1, null, 1, null) :: - (1, null, 2, 2) :: - (2, 2, 1, null) :: - (2, 2, 2, 2) :: Nil) + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 2) :: Nil) } test("left outer join") { checkAnswer( sql("SELECT * FROM upperCaseData LEFT OUTER JOIN lowerCaseData ON n = N"), - (1, "A", 1, "a") :: - (2, "B", 2, "b") :: - (3, "C", 3, "c") :: - (4, "D", 4, "d") :: - (5, "E", null, null) :: - (6, "F", null, null) :: Nil) + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) } test("right outer join") { checkAnswer( sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"), - (1, "a", 1, "A") :: - (2, "b", 2, "B") :: - (3, "c", 3, "C") :: - (4, "d", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("full outer join") { @@ -440,12 +440,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM upperCaseData WHERE N >= 3) rightTable | ON leftTable.N = rightTable.N """.stripMargin), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row (4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("SPARK-3349 partitioning after limit") { @@ -457,12 +457,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .registerTempTable("subset2") checkAnswer( sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"), - (3, "c", 3) :: - (4, "d", 4) :: Nil) + Row(3, "c", 3) :: + Row(4, "d", 4) :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), - (1, "a", 1) :: - (2, "b", 2) :: Nil) + Row(1, "a", 1) :: + Row(2, "b", 2) :: Nil) } test("mixed-case keywords") { @@ -474,28 +474,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (sElEcT * FROM upperCaseData whERe N >= 3) rightTable | oN leftTable.N = rightTable.N """.stripMargin), - (1, "A", null, null) :: - (2, "B", null, null) :: - (3, "C", 3, "C") :: - (4, "D", 4, "D") :: - (null, null, 5, "E") :: - (null, null, 6, "F") :: Nil) + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("select with table name as qualifier") { checkAnswer( sql("SELECT testData.value FROM testData WHERE testData.key = 1"), - Seq(Seq("1"))) + Row("1")) } test("inner join ON with table name as qualifier") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"), Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d"))) + Row(1, "A", 1, "a"), + Row(2, 
"B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d"))) } test("qualified select with inner join ON with table name as qualifier") { @@ -503,72 +503,72 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT upperCaseData.N, upperCaseData.L FROM upperCaseData JOIN lowerCaseData " + "ON lowerCaseData.n = upperCaseData.N"), Seq( - (1, "A"), - (2, "B"), - (3, "C"), - (4, "D"))) + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"))) } test("system function upper()") { checkAnswer( sql("SELECT n,UPPER(l) FROM lowerCaseData"), Seq( - (1, "A"), - (2, "B"), - (3, "C"), - (4, "D"))) + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"))) checkAnswer( sql("SELECT n, UPPER(s) FROM nullStrings"), Seq( - (1, "ABC"), - (2, "ABC"), - (3, null))) + Row(1, "ABC"), + Row(2, "ABC"), + Row(3, null))) } test("system function lower()") { checkAnswer( sql("SELECT N,LOWER(L) FROM upperCaseData"), Seq( - (1, "a"), - (2, "b"), - (3, "c"), - (4, "d"), - (5, "e"), - (6, "f"))) + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(5, "e"), + Row(6, "f"))) checkAnswer( sql("SELECT n, LOWER(s) FROM nullStrings"), Seq( - (1, "abc"), - (2, "abc"), - (3, null))) + Row(1, "abc"), + Row(2, "abc"), + Row(3, null))) } test("UNION") { checkAnswer( sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"), - (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") :: - (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil) + Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: + Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"), - (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil) + Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"), - (1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") :: - (4, "d") :: (4, "d") :: Nil) + Row(1, "a") :: Row(1, "a") :: Row(2, "b") :: Row(2, "b") :: Row(3, "c") :: Row(3, "c") :: + Row(4, "d") :: Row(4, "d") :: Nil) } test("UNION with column mismatches") { // Column name mismatches are allowed. checkAnswer( sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"), - (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") :: - (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil) + Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: + Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) // Column type mismatches are not allowed, forcing a type coercion. checkAnswer( sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), - ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_))) + ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) // Column type mismatches where a coercion is not possible, in this case between integer // and array types, trigger a TreeNodeException. 
intercept[TreeNodeException[_]] { @@ -579,10 +579,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("EXCEPT") { checkAnswer( sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil) checkAnswer( @@ -592,10 +592,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("INTERSECT") { checkAnswer( sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"), - (1, "a") :: - (2, "b") :: - (3, "c") :: - (4, "d") :: Nil) + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM upperCaseData"), Nil) } @@ -613,25 +613,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(s"$testKey=$testVal")) + Row(s"$testKey=$testVal") ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(s"$testKey=$testVal"), - Seq(s"${testKey + testKey}=${testVal + testVal}")) + Row(s"$testKey=$testVal"), + Row(s"${testKey + testKey}=${testVal + testVal}")) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(s"$testKey=$testVal")) + Row(s"$testKey=$testVal") ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(s"$nonexistentKey=")) + Row(s"$nonexistentKey=") ) conf.clear() } @@ -655,17 +655,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { schemaRDD1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), - (1, "A1", true, null) :: - (2, "B2", false, null) :: - (3, "C3", true, null) :: - (4, "D4", true, 2147483644) :: Nil) + Row(1, "A1", true, null) :: + Row(2, "B2", false, null) :: + Row(3, "C3", true, null) :: + Row(4, "D4", true, 2147483644) :: Nil) checkAnswer( sql("SELECT f1, f4 FROM applySchema1"), - (1, null) :: - (2, null) :: - (3, null) :: - (4, 2147483644) :: Nil) + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) val schema2 = StructType( StructField("f1", StructType( @@ -685,17 +685,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { schemaRDD2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), - (Seq(1, true), Map("A1" -> null)) :: - (Seq(2, false), Map("B2" -> null)) :: - (Seq(3, true), Map("C3" -> null)) :: - (Seq(4, true), Map("D4" -> 2147483644)) :: Nil) + Row(Row(1, true), Map("A1" -> null)) :: + Row(Row(2, false), Map("B2" -> null)) :: + Row(Row(3, true), Map("C3" -> null)) :: + Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil) checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), - (1, null) :: - (2, null) :: - (3, null) :: - (4, 2147483644) :: Nil) + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) // The value of a MapType column can be a mutable map. 
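The applySchema2 and applySchema3 expectations above and below nest one Row inside another for struct-typed columns while leaving map-typed columns as plain Scala Maps. A small illustrative sketch of that shape (not taken from the patch itself):

    import org.apache.spark.sql.Row

    val nested = Row(Row(1, true), Map("A1" -> null))
    assert(nested(0) == Row(1, true))       // struct column: a nested Row
    assert(nested(1) == Map("A1" -> null))  // map column: an ordinary Scala Map
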
val rowRDD3 = unparsedStrings.map { r => @@ -711,26 +711,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), - (1, null) :: - (2, null) :: - (3, null) :: - (4, 2147483644) :: Nil) + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) } test("SPARK-3423 BETWEEN") { checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"), - Seq((5, "5"), (6, "6"), (7, "7")) + Seq(Row(5, "5"), Row(6, "6"), Row(7, "7")) ) checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"), - Seq((7, "7")) + Row(7, "7") ) checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"), - Seq() + Nil ) } @@ -738,7 +738,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // TODO Ensure true/false string letter casing is consistent with Hive in all cases. checkAnswer( sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), - ("true", "false") :: Nil) + Row("true", "false")) } test("metadata is propagated correctly") { @@ -768,17 +768,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3371 Renaming a function expression with group by gives error") { udf.register("len", (s: String) => s.length) checkAnswer( - sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1) + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), + Row(1)) } test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") { checkAnswer( - sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1) + sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), + Row(1)) } test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { checkAnswer( - sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) + sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), + Row(1)) } test("throw errors for non-aggregate attributes with aggregation") { @@ -808,130 +811,131 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Test to check we can use Long.MinValue") { checkAnswer( - sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Long.MinValue + sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue) ) checkAnswer( - sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), (1 to 100).map(Row(_)).toSeq + sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), + (1 to 100).map(Row(_)).toSeq ) } test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), 0.3 + sql("SELECT 0.3"), Row(0.3) ) checkAnswer( - sql("SELECT -0.8"), -0.8 + sql("SELECT -0.8"), Row(-0.8) ) checkAnswer( - sql("SELECT .5"), 0.5 + sql("SELECT .5"), Row(0.5) ) checkAnswer( - sql("SELECT -.18"), -0.18 + sql("SELECT -.18"), Row(-0.18) ) } test("Auto cast integer type") { checkAnswer( - sql(s"SELECT ${Int.MaxValue + 1L}"), Int.MaxValue + 1L + sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L) ) checkAnswer( - sql(s"SELECT ${Int.MinValue - 1L}"), Int.MinValue - 1L + sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L) ) checkAnswer( - sql("SELECT 9223372036854775808"), new java.math.BigDecimal("9223372036854775808") + sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) ) checkAnswer( - sql("SELECT -9223372036854775809"), new 
java.math.BigDecimal("-9223372036854775809") + sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) ) } test("Test to check we can apply sign to expression") { checkAnswer( - sql("SELECT -100"), -100 + sql("SELECT -100"), Row(-100) ) checkAnswer( - sql("SELECT +230"), 230 + sql("SELECT +230"), Row(230) ) checkAnswer( - sql("SELECT -5.2"), -5.2 + sql("SELECT -5.2"), Row(-5.2) ) checkAnswer( - sql("SELECT +6.8"), 6.8 + sql("SELECT +6.8"), Row(6.8) ) checkAnswer( - sql("SELECT -key FROM testData WHERE key = 2"), -2 + sql("SELECT -key FROM testData WHERE key = 2"), Row(-2) ) checkAnswer( - sql("SELECT +key FROM testData WHERE key = 3"), 3 + sql("SELECT +key FROM testData WHERE key = 3"), Row(3) ) checkAnswer( - sql("SELECT -(key + 1) FROM testData WHERE key = 1"), -2 + sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2) ) checkAnswer( - sql("SELECT - key + 1 FROM testData WHERE key = 10"), -9 + sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9) ) checkAnswer( - sql("SELECT +(key + 5) FROM testData WHERE key = 5"), 10 + sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10) ) checkAnswer( - sql("SELECT -MAX(key) FROM testData"), -100 + sql("SELECT -MAX(key) FROM testData"), Row(-100) ) checkAnswer( - sql("SELECT +MAX(key) FROM testData"), 100 + sql("SELECT +MAX(key) FROM testData"), Row(100) ) checkAnswer( - sql("SELECT - (-10)"), 10 + sql("SELECT - (-10)"), Row(10) ) checkAnswer( - sql("SELECT + (-key) FROM testData WHERE key = 32"), -32 + sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32) ) checkAnswer( - sql("SELECT - (+Max(key)) FROM testData"), -100 + sql("SELECT - (+Max(key)) FROM testData"), Row(-100) ) checkAnswer( - sql("SELECT - - 3"), 3 + sql("SELECT - - 3"), Row(3) ) checkAnswer( - sql("SELECT - + 20"), -20 + sql("SELECT - + 20"), Row(-20) ) checkAnswer( - sql("SELEcT - + 45"), -45 + sql("SELEcT - + 45"), Row(-45) ) checkAnswer( - sql("SELECT + + 100"), 100 + sql("SELECT + + 100"), Row(100) ) checkAnswer( - sql("SELECT - - Max(key) FROM testData"), 100 + sql("SELECT - - Max(key) FROM testData"), Row(100) ) checkAnswer( - sql("SELECT + - key FROM testData WHERE key = 33"), -33 + sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33) ) } @@ -943,7 +947,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { |JOIN testData b ON a.key = b.key |JOIN testData c ON a.key = c.key """.stripMargin), - (1 to 100).map(i => Seq(i, i, i))) + (1 to 100).map(i => Row(i, i, i))) } test("SPARK-3483 Special chars in column names") { @@ -953,19 +957,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3814 Support Bitwise & operator") { - checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), 1) + checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise | operator") { - checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), 1) + checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise ^ operator") { - checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), 1) + checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), Row(1)) } test("SPARK-3814 Support Bitwise ~ operator") { - checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), -2) + checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), Row(-2)) } test("SPARK-4120 Join of multiple tables does not work in SparkSQL") { @@ -975,40 +979,40 @@ class SQLQuerySuite extends QueryTest with 
BeforeAndAfterAll { |FROM testData a,testData b,testData c |where a.key = b.key and a.key = c.key """.stripMargin), - (1 to 100).map(i => Seq(i, i, i))) + (1 to 100).map(i => Row(i, i, i))) } test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), - (11 to 100).map(i => Seq(i))) + (11 to 100).map(i => Row(i))) } test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), - (1 to 99).map(i => Seq(i))) + (1 to 99).map(i => Row(i))) } test("SPARK-4322 Grouping field with struct field as sub expression") { jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") - checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1) + checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) dropTempTable("data") jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") - checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2) + checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { checkAnswer( sql("SELECT a + b FROM testData2 ORDER BY a"), - Seq(2, 3, 3 ,4 ,4 ,5).map(Seq(_)) + Seq(2, 3, 3 ,4 ,4 ,5).map(Row(_)) ) } test("oder by asc by default when not specify ascending and descending") { checkAnswer( sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), - Seq((3, 1), (3, 2), (2, 1), (2,2), (1, 1), (1, 2)) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2,2), Row(1, 1), Row(1, 2)) ) } @@ -1021,13 +1025,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { rdd2.registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), - (1 to 2).map(i => Seq(i))) + (1 to 2).map(i => Row(i))) } test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.registerTempTable("distinctData") - checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), 2) + checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ee381da491054..a015884bae282 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -82,7 +82,7 @@ class ScalaReflectionRelationSuite extends FunSuite { rdd.registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === - Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, + Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) } @@ -91,7 +91,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectNullData") - assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null)) + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } test("query case class RDD with 
Nones") { @@ -99,7 +99,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null)) + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 9be0b38e689ff..be2b34de077c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -42,8 +42,8 @@ class ColumnStatsSuite extends FunSuite { test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => - assert(actual === expected) + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + case (actual, expected) => assert(actual === expected) } } @@ -54,7 +54,7 @@ class ColumnStatsSuite extends FunSuite { val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType]) + val values = rows.take(10).map(_(0).asInstanceOf[T#JvmType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] val stats = columnStats.collectedStatistics diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index d94729ba92360..e61f3c39631da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -49,7 +49,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key - }.toSeq) + }.map(Row.fromTuple)) } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { @@ -63,49 +63,49 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1678 regression: compression must not lose repeated values") { checkAnswer( sql("SELECT * FROM repeatedData"), - repeatedData.collect().toSeq) + repeatedData.collect().toSeq.map(Row.fromTuple)) cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), - repeatedData.collect().toSeq) + repeatedData.collect().toSeq.map(Row.fromTuple)) } test("with null values") { checkAnswer( sql("SELECT * FROM nullableRepeatedData"), - nullableRepeatedData.collect().toSeq) + nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), - nullableRepeatedData.collect().toSeq) + nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) } test("SPARK-2729 regression: timestamp data type") { checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq) + timestamps.collect().toSeq.map(Row.fromTuple)) cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq) + timestamps.collect().toSeq.map(Row.fromTuple)) } test("SPARK-3320 regression: batched column 
buffer building should work with empty partitions") { checkAnswer( sql("SELECT * FROM withEmptyParts"), - withEmptyParts.collect().toSeq) + withEmptyParts.collect().toSeq.map(Row.fromTuple)) cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), - withEmptyParts.collect().toSeq) + withEmptyParts.collect().toSeq.map(Row.fromTuple)) } test("SPARK-4182 Caching complex types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 592cafbbdc203..c3a3f8ddc3ebf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -108,7 +108,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be val queryExecution = schemaRdd.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { - schemaRdd.collect().map(_.head).toArray + schemaRdd.collect().map(_(0)).toArray } val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index d9e488e0ffd16..8b518f094174c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -34,7 +34,7 @@ class BooleanBitSetSuite extends FunSuite { val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN)) - val values = rows.map(_.head) + val values = rows.map(_(0)) rows.foreach(builder.appendFrom(_, 0)) val buffer = builder.build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala index 2cab5e0c44d92..272c0d4cb2335 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala @@ -59,7 +59,7 @@ class TgfSuite extends QueryTest { checkAnswer( inputData.generate(ExampleTGF()), Seq( - "michael is 29 years old" :: Nil, - "Next year, michael will be 30 years old" :: Nil)) + Row("michael is 29 years old"), + Row("Next year, michael will be 30 years old"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 2bc9aede32f2a..94d14acccbb18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -229,13 +229,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - (new java.math.BigDecimal("92233720368547758070"), - true, - 1.7976931348623157E308, - 10, - 21474836470L, - null, - "this is a simple string.") :: Nil + Row(new java.math.BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") ) } @@ -271,48 +271,49 @@ class JsonSuite extends QueryTest { // Access elements of a primitive array. 
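The columnar and caching tests above collect their reference data as Scala tuples or case classes and convert it with Row.fromTuple before comparison, and element access on rows switches from .head to the Row's apply. A sketch of both, where TimestampField is a hypothetical stand-in for the test fixture's case class:

    import org.apache.spark.sql.Row

    case class TimestampField(time: java.sql.Timestamp)
    val collected = Seq(TimestampField(new java.sql.Timestamp(1L)))
    val expected  = collected.map(Row.fromTuple)

    assert(expected.head == Row(new java.sql.Timestamp(1L)))
    assert(expected.head(0) == new java.sql.Timestamp(1L))  // indexed access instead of .head on a Row
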
checkAnswer( sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), - ("str1", "str2", null) :: Nil + Row("str1", "str2", null) ) // Access an array of null values. checkAnswer( sql("select arrayOfNull from jsonTable"), - Seq(Seq(null, null, null, null)) :: Nil + Row(Seq(null, null, null, null)) ) // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), - (new java.math.BigDecimal("922337203685477580700"), - new java.math.BigDecimal("-922337203685477580800"), null) :: Nil + Row(new java.math.BigDecimal("922337203685477580700"), + new java.math.BigDecimal("-922337203685477580800"), null) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), - (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil + Row(Seq("1", "2", "3"), Seq("str1", "str2")) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), - (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil + Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) ) // Access elements of an array inside a filed with the type of ArrayType(ArrayType). checkAnswer( sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), - ("str2", 2.1) :: Nil + Row("str2", 2.1) ) // Access elements of an array of structs. checkAnswer( sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + "from jsonTable"), - (true :: "str1" :: null :: Nil, - false :: null :: null :: Nil, - null :: null :: null :: Nil, - null) :: Nil + Row( + Row(true, "str1", null), + Row(false, null, null), + Row(null, null, null), + null) ) // Access a struct and fields inside of it. @@ -327,13 +328,13 @@ class JsonSuite extends QueryTest { // Access an array field of a struct. checkAnswer( sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), - (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil + Row(Seq(4, 5, 6), Seq("str1", "str2")) ) // Access elements of an array field of a struct. checkAnswer( sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), - (5, null) :: Nil + Row(5, null) ) } @@ -344,14 +345,14 @@ class JsonSuite extends QueryTest { // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), - (true, "str1") :: Nil + Row(true, "str1") ) // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2. // Getting all values of a specific field from an array of structs. 
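The arrayOfStruct expectations above show the shape a JSON array of structs now takes in the expected answers: a Scala Seq of nested Rows, with null standing in for a missing struct, while extracting a single field across the array (as in the query just below) yields a plain Seq. A brief illustrative sketch:

    import org.apache.spark.sql.Row

    val arrayOfStruct = Seq(Row(true, "str1", null), Row(false, null, null), Row(null, null, null), null)
    assert(arrayOfStruct(1) == Row(false, null, null))
    assert(arrayOfStruct(3) == null)
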
checkAnswer( sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), - (Seq(true, false), Seq("str1", null)) :: Nil + Row(Seq(true, false), Seq("str1", null)) ) } @@ -372,57 +373,57 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - ("true", 11L, null, 1.1, "13.1", "str1") :: - ("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - ("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - (null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil + Row("true", 11L, null, 1.1, "13.1", "str1") :: + Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: + Row("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: + Row(null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil ) // Number and Boolean conflict: resolve the type as number in this query. checkAnswer( sql("select num_bool - 10 from jsonTable where num_bool > 11"), - 2 + Row(2) ) // Widening to LongType checkAnswer( sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), - Seq(21474836370L) :: Seq(21474836470L) :: Nil + Row(21474836370L) :: Row(21474836470L) :: Nil ) checkAnswer( sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), - Seq(-89) :: Seq(21474836370L) :: Seq(21474836470L) :: Nil + Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil ) // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Seq(new java.math.BigDecimal("21474836472.1")) :: Seq(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + Row(new java.math.BigDecimal("21474836472.1")) :: Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil ) // Widening to DoubleType checkAnswer( sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), - Seq(101.2) :: Seq(21474836471.2) :: Nil + Row(101.2) :: Row(21474836471.2) :: Nil ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - 92233720368547758071.2 + Row(92233720368547758071.2) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), - new java.math.BigDecimal("92233720368547758061.2").doubleValue + Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue) ) // String and Boolean conflict: resolve the type as string. checkAnswer( sql("select * from jsonTable where str_bool = 'str1'"), - ("true", 11L, null, 1.1, "13.1", "str1") :: Nil + Row("true", 11L, null, 1.1, "13.1", "str1") ) } @@ -434,24 +435,24 @@ class JsonSuite extends QueryTest { // Number and Boolean conflict: resolve the type as boolean in this query. checkAnswer( sql("select num_bool from jsonTable where NOT num_bool"), - false + Row(false) ) checkAnswer( sql("select str_bool from jsonTable where NOT str_bool"), - false + Row(false) ) // Right now, the analyzer does not know that num_bool should be treated as a boolean. // Number and Boolean conflict: resolve the type as boolean in this query. 
checkAnswer( sql("select num_bool from jsonTable where num_bool"), - true + Row(true) ) checkAnswer( sql("select str_bool from jsonTable where str_bool"), - false + Row(false) ) // The plan of the following DSL is @@ -464,7 +465,7 @@ class JsonSuite extends QueryTest { jsonSchemaRDD. where('num_str > BigDecimal("92233720368547758060")). select('num_str + 1.2 as Symbol("num")), - new java.math.BigDecimal("92233720368547758061.2") + Row(new java.math.BigDecimal("92233720368547758061.2")) ) // The following test will fail. The type of num_str is StringType. @@ -475,7 +476,7 @@ class JsonSuite extends QueryTest { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 13"), - Seq(14.3) :: Seq(92233720368547758071.2) :: Nil + Row(14.3) :: Row(92233720368547758071.2) :: Nil ) } @@ -496,10 +497,10 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - (Seq(), "11", "[1,2,3]", Seq(null), "[]") :: - (null, """{"field":false}""", null, null, "{}") :: - (Seq(4, 5, 6), null, "str", Seq(null), "[7,8,9]") :: - (Seq(7), "{}","[str1,str2,33]", Seq("str"), """{"field":true}""") :: Nil + Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: + Row(null, """{"field":false}""", null, null, "{}") :: + Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: + Row(Seq(7), "{}","[str1,str2,33]", Row("str"), """{"field":true}""") :: Nil ) } @@ -518,16 +519,16 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", - """{"field":"str"}"""), Seq(Seq(214748364700L), Seq(1)), null) :: - Seq(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: - Seq(null, null, Seq("1", "2", "3")) :: Nil + Row(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", + """{"field":"str"}"""), Seq(Row(214748364700L), Row(1)), null) :: + Row(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: + Row(null, null, Seq("1", "2", "3")) :: Nil ) // Treat an element as a number. 
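Many conversions in this suite drop the old single-element ":: Nil" lists because QueryTest now also accepts a bare Row. A sketch of the two equivalent call shapes, reusing the suite's sql helper and the jsonTable fixture from the query above (so it only runs inside this suite):

    checkAnswer(sql("select str_bool from jsonTable where NOT str_bool"), Row(false))
    checkAnswer(sql("select str_bool from jsonTable where NOT str_bool"), Seq(Row(false)))
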
checkAnswer( sql("select array1[0] + 1 from jsonTable where array1 is not null"), - 2 + Row(2) ) } @@ -568,13 +569,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) } @@ -594,13 +595,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTableSQL"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) } @@ -626,13 +627,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable1"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema) @@ -643,13 +644,13 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable2"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, null, - "this is a simple string.") :: Nil + "this is a simple string.") ) } @@ -659,7 +660,7 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), - (true, "str1") :: Nil + Row(true, "str1") ) checkAnswer( sql( @@ -667,7 +668,7 @@ class JsonSuite extends QueryTest { |select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] |from jsonTable """.stripMargin), - ("str2", 6) :: Nil + Row("str2", 6) ) } @@ -681,7 +682,7 @@ class JsonSuite extends QueryTest { |select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0] |from jsonTable """.stripMargin), - (5, 7, 8) :: Nil + Row(5, 7, 8) ) checkAnswer( sql( @@ -690,7 +691,7 @@ class JsonSuite extends QueryTest { |arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4 |from jsonTable """.stripMargin), - ("str1", Nil, "str4", 2) :: Nil + Row("str1", Nil, "str4", 2) ) } @@ -704,10 +705,10 @@ class JsonSuite extends QueryTest { |select a, b, c |from jsonTable """.stripMargin), - ("str_a_1", null, null) :: - ("str_a_2", null, null) :: - (null, "str_b_3", null) :: - ("str_a_4", "str_b_4", "str_c_4") :: Nil + Row("str_a_1", null, null) :: + Row("str_a_2", null, null) :: + Row(null, "str_b_3", null) :: + Row("str_a_4", "str_b_4", "str_c_4") :: Nil ) } @@ -734,12 +735,12 @@ class JsonSuite extends QueryTest { |SELECT a, b, c, _unparsed |FROM jsonTable """.stripMargin), - (null, null, null, "{") :: - (null, null, null, "") :: - (null, null, null, """{"a":1, b:2}""") :: - (null, null, null, """{"a":{, b:3}""") :: - ("str_a_4", "str_b_4", "str_c_4", null) :: - (null, null, null, "]") :: Nil + Row(null, null, null, "{") :: + Row(null, null, null, "") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil ) checkAnswer( @@ -749,7 +750,7 @@ class JsonSuite extends QueryTest { |FROM jsonTable |WHERE _unparsed IS NULL """.stripMargin), - ("str_a_4", "str_b_4", "str_c_4") :: Nil + 
Row("str_a_4", "str_b_4", "str_c_4") ) checkAnswer( @@ -759,11 +760,11 @@ class JsonSuite extends QueryTest { |FROM jsonTable |WHERE _unparsed IS NOT NULL """.stripMargin), - Seq("{") :: - Seq("") :: - Seq("""{"a":1, b:2}""") :: - Seq("""{"a":{, b:3}""") :: - Seq("]") :: Nil + Row("{") :: + Row("") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil ) TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) @@ -793,10 +794,10 @@ class JsonSuite extends QueryTest { |SELECT field1, field2, field3, field4 |FROM jsonTable """.stripMargin), - Seq(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) :: - Seq(null, Seq(null, Seq(Seq(1))), null, null) :: - Seq(null, null, Seq(Seq(null), Seq(Seq("2"))), null) :: - Seq(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil + Row(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) :: + Row(null, Seq(null, Seq(Row(1))), null, null) :: + Row(null, null, Seq(Seq(null), Seq(Row("2"))), null) :: + Row(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil ) } @@ -851,12 +852,12 @@ class JsonSuite extends QueryTest { primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), - (new java.math.BigDecimal("92233720368547758070"), + Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, 21474836470L, - "this is a simple string.") :: Nil + "this is a simple string.") ) val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1) @@ -865,38 +866,38 @@ class JsonSuite extends QueryTest { // Access elements of a primitive array. checkAnswer( sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), - ("str1", "str2", null) :: Nil + Row("str1", "str2", null) ) // Access an array of null values. checkAnswer( sql("select arrayOfNull from complexTable"), - Seq(Seq(null, null, null, null)) :: Nil + Row(Seq(null, null, null, null)) ) // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"), - (new java.math.BigDecimal("922337203685477580700"), - new java.math.BigDecimal("-922337203685477580800"), null) :: Nil + Row(new java.math.BigDecimal("922337203685477580700"), + new java.math.BigDecimal("-922337203685477580800"), null) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"), - (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil + Row(Seq("1", "2", "3"), Seq("str1", "str2")) ) // Access elements of an array of arrays. checkAnswer( sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"), - (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil + Row(Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) ) // Access elements of an array inside a filed with the type of ArrayType(ArrayType). checkAnswer( sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"), - ("str2", 2.1) :: Nil + Row("str2", 2.1) ) // Access a struct and fields inside of it. @@ -911,13 +912,13 @@ class JsonSuite extends QueryTest { // Access an array field of a struct. checkAnswer( sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"), - (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil + Row(Seq(4, 5, 6), Seq("str1", "str2")) ) // Access elements of an array field of a struct. 
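The structWithArrayFields expectation above illustrates how array-typed columns appear in the new Row-based answers: as ordinary Scala Seqs nested inside the Row. A short illustrative sketch, not from the patch:

    import org.apache.spark.sql.Row

    val structWithArrays = Row(Seq(4, 5, 6), Seq("str1", "str2"))
    assert(structWithArrays(0) == Seq(4, 5, 6))
    assert(structWithArrays(0).asInstanceOf[Seq[Int]](1) == 5)  // field1[1] in the query above
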
checkAnswer( sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"), - (5, null) :: Nil + Row(5, null) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4c3a04506ce42..4ad8c472007fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -46,7 +46,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { predicate: Predicate, filterClass: Class[_ <: FilterPredicate], checker: (SchemaRDD, Any) => Unit, - expectedResult: => Any): Unit = { + expectedResult: Any): Unit = { withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { val query = rdd.select(output.map(_.attr): _*).where(predicate) @@ -65,11 +65,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } } - private def checkFilterPushdown + private def checkFilterPushdown1 (rdd: SchemaRDD, output: Symbol*) (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) - (expectedResult: => Any): Unit = { - checkFilterPushdown(rdd, output, predicate, filterClass, checkAnswer _, expectedResult) + (expectedResult: => Seq[Row]): Unit = { + checkFilterPushdown(rdd, output, predicate, filterClass, + (query, expected) => checkAnswer(query, expected.asInstanceOf[Seq[Row]]), expectedResult) + } + + private def checkFilterPushdown + (rdd: SchemaRDD, output: Symbol*) + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate]) + (expectedResult: Int): Unit = { + checkFilterPushdown(rdd, output, predicate, filterClass, + (query, expected) => checkAnswer(query, expected.asInstanceOf[Seq[Row]]), Seq(Row(expectedResult))) } def checkBinaryFilterPushdown @@ -89,27 +98,25 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - boolean") { withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) { Seq(Row(true), Row(false)) } - checkFilterPushdown(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(true) - checkFilterPushdown(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]]) { - false - } + checkFilterPushdown1(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(Seq(Row(true))) + checkFilterPushdown1(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]])(Seq(Row(false))) } } test("filter pushdown - integer") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) { (1 to 4).map(Row.apply(_)) } checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[Integer]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) { + checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) { (2 to 4).map(Row.apply(_)) } @@ -126,7 +133,7 @@ class 
ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[Integer]])(4) checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { + checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { Seq(Row(1), Row(4)) } } @@ -134,13 +141,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - long") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) { (1 to 4).map(Row.apply(_)) } checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Long]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) { + checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) { (2 to 4).map(Row.apply(_)) } @@ -157,7 +164,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Long]])(4) checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { + checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { Seq(Row(1), Row(4)) } } @@ -165,13 +172,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - float") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) { (1 to 4).map(Row.apply(_)) } checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Float]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) { + checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) { (2 to 4).map(Row.apply(_)) } @@ -188,7 +195,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Float]])(4) checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { + checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { Seq(Row(1), Row(4)) } } @@ -196,13 +203,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - double") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) { (1 to 4).map(Row.apply(_)) } checkFilterPushdown(rdd, '_1)('_1 === 1, 
classOf[Eq[java.lang.Double]])(1) - checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) { + checkFilterPushdown1(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) { (2 to 4).map(Row.apply(_)) } @@ -219,7 +226,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPushdown(rdd, '_1)(!('_1 < 4), classOf[Operators.GtEq[java.lang.Double]])(4) checkFilterPushdown(rdd, '_1)('_1 > 2 && '_1 < 4, classOf[Operators.And])(3) - checkFilterPushdown(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { + checkFilterPushdown1(rdd, '_1)('_1 < 2 || '_1 > 3, classOf[Operators.Or]) { Seq(Row(1), Row(4)) } } @@ -227,30 +234,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - string") { withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { rdd => - checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) - checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { + checkFilterPushdown1(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) + checkFilterPushdown1(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { (1 to 4).map(i => Row.apply(i.toString)) } - checkFilterPushdown(rdd, '_1)('_1 === "1", classOf[Eq[String]])("1") - checkFilterPushdown(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) { + checkFilterPushdown1(rdd, '_1)('_1 === "1", classOf[Eq[String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) { (2 to 4).map(i => Row.apply(i.toString)) } - checkFilterPushdown(rdd, '_1)('_1 < "2", classOf[Lt [java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)('_1 > "3", classOf[Gt [java.lang.String]])("4") - checkFilterPushdown(rdd, '_1)('_1 <= "1", classOf[LtEq[java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)('_1 >= "4", classOf[GtEq[java.lang.String]])("4") + checkFilterPushdown1(rdd, '_1)('_1 < "2", classOf[Lt [java.lang.String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)('_1 > "3", classOf[Gt [java.lang.String]])(Seq(Row("4"))) + checkFilterPushdown1(rdd, '_1)('_1 <= "1", classOf[LtEq[java.lang.String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)('_1 >= "4", classOf[GtEq[java.lang.String]])(Seq(Row("4"))) - checkFilterPushdown(rdd, '_1)(Literal("1") === '_1, classOf[Eq [java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)(Literal("2") > '_1, classOf[Lt [java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)(Literal("3") < '_1, classOf[Gt [java.lang.String]])("4") - checkFilterPushdown(rdd, '_1)(Literal("1") >= '_1, classOf[LtEq[java.lang.String]])("1") - checkFilterPushdown(rdd, '_1)(Literal("4") <= '_1, classOf[GtEq[java.lang.String]])("4") + checkFilterPushdown1(rdd, '_1)(Literal("1") === '_1, classOf[Eq [java.lang.String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)(Literal("2") > '_1, classOf[Lt [java.lang.String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)(Literal("3") < '_1, classOf[Gt [java.lang.String]])(Seq(Row("4"))) + checkFilterPushdown1(rdd, '_1)(Literal("1") >= '_1, classOf[LtEq[java.lang.String]])(Seq(Row("1"))) + checkFilterPushdown1(rdd, '_1)(Literal("4") <= '_1, classOf[GtEq[java.lang.String]])(Seq(Row("4"))) - checkFilterPushdown(rdd, '_1)(!('_1 < "4"), classOf[Operators.GtEq[java.lang.String]])("4") - checkFilterPushdown(rdd, '_1)('_1 > "2" && '_1 < "4", classOf[Operators.And])("3") - checkFilterPushdown(rdd, '_1)('_1 < "2" || '_1 > "3", classOf[Operators.Or]) { + checkFilterPushdown1(rdd, '_1)(!('_1 < 
"4"), classOf[Operators.GtEq[java.lang.String]])(Seq(Row("4"))) + checkFilterPushdown1(rdd, '_1)('_1 > "2" && '_1 < "4", classOf[Operators.And])(Seq(Row("3"))) + checkFilterPushdown1(rdd, '_1)('_1 < "2" || '_1 > "3", classOf[Operators.Or]) { Seq(Row("1"), Row("4")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 973819aaa4d77..a57e4e85a35ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -68,8 +68,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest { /** * Writes `data` to a Parquet file, reads it back and check file contents. */ - protected def checkParquetFile[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetRDD(data)(checkAnswer(_, data)) + protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { + withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } test("basic data types (without binary)") { @@ -143,7 +143,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withParquetRDD(data) { rdd => // Structs are converted to `Row`s checkAnswer(rdd, data.map { case Tuple1(struct) => - Tuple1(Row(struct.productIterator.toSeq: _*)) + Row(Row(struct.productIterator.toSeq: _*)) }) } } @@ -153,7 +153,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withParquetRDD(data) { rdd => // Structs are converted to `Row`s checkAnswer(rdd, data.map { case Tuple1(struct) => - Tuple1(Row(struct.productIterator.toSeq: _*)) + Row(Row(struct.productIterator.toSeq: _*)) }) } } @@ -162,7 +162,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) withParquetRDD(data) { rdd => checkAnswer(rdd, data.map { case Tuple1(m) => - Tuple1(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) + Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) }) } } @@ -261,7 +261,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) checkAnswer(parquetFile(path.toString), (0 until 10).map { i => - (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 3a073a6b7057e..2c5345b1f9148 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -28,7 +28,7 @@ import parquet.hadoop.util.ContextUtil import parquet.io.api.Binary import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Row => _, _} import org.apache.spark.sql.catalyst.util.getTempFilePath import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -191,8 +191,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA parquetFile(path).registerTempTable("tmp") checkAnswer( sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + Row(5, "val_5") :: + Row(7, "val_7") :: Nil) Utils.deleteRecursively(file) @@ -207,8 
+207,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA parquetFile(path).registerTempTable("tmp") checkAnswer( sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + Row(5, "val_5") :: + Row(7, "val_7") :: Nil) Utils.deleteRecursively(file) @@ -223,8 +223,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA parquetFile(path).registerTempTable("tmp") checkAnswer( sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + Row(5, "val_5") :: + Row(7, "val_7") :: Nil) Utils.deleteRecursively(file) @@ -239,8 +239,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA parquetFile(path).registerTempTable("tmp") checkAnswer( sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + Row(5, "val_5") :: + Row(7, "val_7") :: Nil) Utils.deleteRecursively(file) @@ -255,8 +255,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA parquetFile(path).registerTempTable("tmp") checkAnswer( sql("SELECT key, value FROM tmp WHERE value = 'val_5' OR value = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + Row(5, "val_5") :: + Row(7, "val_7") :: Nil) Utils.deleteRecursively(file) @@ -303,7 +303,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result.size === 9, "self-join result has incorrect size") assert(result(0).size === 12, "result row has incorrect size") result.zipWithIndex.foreach { - case (row, index) => row.zipWithIndex.foreach { + case (row, index) => row.toSeq.zipWithIndex.foreach { case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column") } } @@ -423,7 +423,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val readFile = parquetFile(path) val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) + assert(rdd_saved(0) === Row(null, null, null, null, null)) Utils.deleteRecursively(file) assert(true) } @@ -438,7 +438,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val readFile = parquetFile(path) val rdd_saved = readFile.collect() - assert(rdd_saved(0) === Seq.fill(5)(null)) + assert(rdd_saved(0) === Row(null, null, null, null, null)) Utils.deleteRecursively(file) assert(true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala index 4c081fb4510b2..7b3f8c22af2db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala @@ -38,7 +38,7 @@ class ParquetQuerySuite2 extends QueryTest with ParquetTest { val data = (0 until 10).map(i => (i, i.toString)) withParquetTable(data, "t") { sql("INSERT INTO t SELECT * FROM t") - checkAnswer(table("t"), data ++ data) + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 264f6d94c4ed9..b1e0919b7aed1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -244,7 +244,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT count(*) FROM tableWithSchema", - 10) + Seq(Row(10))) sqlTest( "SELECT `string$%Field` FROM tableWithSchema", @@ -260,7 +260,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldSimple.key, arrayFieldSimple[1] FROM tableWithSchema a where int_Field=1", - Seq(Seq(1, 2))) + Seq(Row(1, 2))) sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 10833c113216a..3e26fe3675768 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -368,10 +368,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { .mkString("\t") } case command: ExecutedCommand => - command.executeCollect().map(_.head.toString) + command.executeCollect().map(_(0).toString) case other => - val result: Seq[Seq[Any]] = other.executeCollect().toSeq + val result: Seq[Seq[Any]] = other.executeCollect().map(_.toSeq).toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) // Reformat to match hive tab delimited output. @@ -395,7 +395,7 @@ private object HiveContext { protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => - struct.zip(fields).map { + struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => @@ -418,7 +418,7 @@ private object HiveContext { /** Hive outputs fields of structs slightly differently than top level attributes. */ protected def toHiveStructString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => - struct.zip(fields).map { + struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index eeabfdd857916..82dba99900df9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -348,7 +348,7 @@ private[hive] trait HiveInspectors { (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach { (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) } struct @@ -432,7 +432,7 @@ private[hive] trait HiveInspectors { } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Seq[_]] + val row = a.asInstanceOf[Row] // 1. 
create the pojo (most likely) object val result = x.create() var i = 0 @@ -448,7 +448,7 @@ private[hive] trait HiveInspectors { result case x: StructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Seq[_]] + val row = a.asInstanceOf[Row] val result = new java.util.ArrayList[AnyRef](fieldRefs.length) var i = 0 while (i < fieldRefs.length) { @@ -475,7 +475,7 @@ private[hive] trait HiveInspectors { } def wrap( - row: Seq[Any], + row: Row, inspectors: Seq[ObjectInspector], cache: Array[AnyRef]): Array[AnyRef] = { var i = 0 @@ -486,6 +486,18 @@ private[hive] trait HiveInspectors { cache } + def wrap( + row: Seq[Any], + inspectors: Seq[ObjectInspector], + cache: Array[AnyRef]): Array[AnyRef] = { + var i = 0 + while (i < inspectors.length) { + cache(i) = wrap(row(i), inspectors(i)) + i += 1 + } + cache + } + /** * @param dataType Catalyst data type * @return Hive java object inspector (recursively), not the Writable ObjectInspector diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d898b876c39f8..76d2140372197 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -360,7 +360,7 @@ private[hive] case class HiveUdafFunction( protected lazy val cached = new Array[AnyRef](exprs.length) def update(input: Row): Unit = { - val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray + val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, inspectors, cached)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index cc8bb3e172c6e..aae175e426ade 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -209,7 +209,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { val dynamicPartPath = dynamicPartColNames - .zip(row.takeRight(dynamicPartColNames.length)) + .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => val string = if (rawVal == null) null else String.valueOf(rawVal) s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index f89c49d292c6c..f320d732fb77a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util._ * So, we duplicate this code here. */ class QueryTest extends PlanTest { + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer @@ -56,17 +57,20 @@ class QueryTest extends PlanTest { * @param rdd the [[SchemaRDD]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. 
*/ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = { - val convertedAnswer = expectedAnswer match { - case s: Seq[_] if s.isEmpty => s - case s: Seq[_] if s.head.isInstanceOf[Product] && - !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq) - case s: Seq[_] => s - case singleItem => Seq(Seq(singleItem)) + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case o => o + }) + } + if (!isSorted) converted.sortBy(_.toString) else converted } - - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty - def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer val sparkAnswer = try rdd.collect().toSeq catch { case e: Exception => fail( @@ -74,11 +78,12 @@ class QueryTest extends PlanTest { |Exception thrown while executing query: |${rdd.queryExecution} |== Exception == - |${stackTraceToString(e)} + |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} """.stripMargin) } - if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { fail(s""" |Results do not match for query: |${rdd.logicalPlan} @@ -88,11 +93,22 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } + + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + checkAnswer(rdd, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 4864607252034..2d3ff680125ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -129,6 +129,12 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { } } + def checkValues(row1: Seq[Any], row2: Row): Unit = { + row1.zip(row2.toSeq).map { + case (r1, r2) => checkValue(r1, r2) + } + } + def checkValue(v1: Any, v2: Any): Unit = { (v1, v2) match { case (r1: Decimal, r2: Decimal) => @@ -198,7 +204,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) - checkValues(row, unwrap(wrap(row, toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) + checkValues(row, 
unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 7cfb875e05db3..0e6636d38ed3c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -43,7 +43,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq + testData.collect().toSeq.map(Row.fromTuple) ) // Add more data. @@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq ++ testData.collect().toSeq + testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq ) // Now overwrite. @@ -61,7 +61,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the registered table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq + testData.collect().toSeq.map(Row.fromTuple) ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 53d8aa7739bc2..7408c7ffd69e8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -155,7 +155,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), - ("a", "b") :: Nil) + Row("a", "b")) FileUtils.deleteDirectory(tempDir) sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -164,14 +164,14 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // will show. checkAnswer( sql("SELECT * FROM jsonTable"), - ("a1", "b1") :: Nil) + Row("a1", "b1")) refreshTable("jsonTable") // Check that the refresh worked checkAnswer( sql("SELECT * FROM jsonTable"), - ("a1", "b1", "c1") :: Nil) + Row("a1", "b1", "c1")) FileUtils.deleteDirectory(tempDir) } @@ -191,7 +191,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), - ("a", "b") :: Nil) + Row("a", "b")) FileUtils.deleteDirectory(tempDir) sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -210,7 +210,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // New table should reflect new schema. 
checkAnswer( sql("SELECT * FROM jsonTable"), - ("a", "b", "c") :: Nil) + Row("a", "b", "c")) FileUtils.deleteDirectory(tempDir) } @@ -253,6 +253,6 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |) """.stripMargin) - sql("DROP TABLE jsonTable").collect.foreach(println) + sql("DROP TABLE jsonTable").collect().foreach(println) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 0b4e76c9d3d2f..6f07fd5a879c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag -import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -141,7 +141,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { before: () => Unit, after: () => Unit, query: String, - expectedAnswer: Seq[Any], + expectedAnswer: Seq[Row], ct: ClassTag[_]) = { before() @@ -183,7 +183,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { /** Tests for MetastoreRelation */ val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key""" - val metastoreAnswer = Seq.fill(4)((238, "val_238", 238, "val_238")) + val metastoreAnswer = Seq.fill(4)(Row(238, "val_238", 238, "val_238")) mkTest( () => (), () => (), @@ -197,7 +197,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val leftSemiJoinQuery = """SELECT * FROM src a |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin - val answer = (86, "val_86") :: Nil + val answer = Row(86, "val_86") var rdd = sql(leftSemiJoinQuery) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index c14f0d24e0dc3..df72be7746ac6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -226,7 +226,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Jdk version leads to different query output for double, so not use createQueryTest here test("division") { val res = sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head - Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res).foreach( x => + Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res.toSeq).foreach( x => assert(x._1 == x._2.asInstanceOf[Double])) } @@ -235,7 +235,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") - assert(sql("SELECT 1").collect() === Array(Seq(1))) + assert(sql("SELECT 1").collect() === Array(Row(1))) setConf("spark.sql.dialect", "hiveql") } @@ -467,7 +467,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestData(2, "str2") :: Nil) testData.registerTempTable("REGisteredTABle") - assertResult(Array(Array(2, "str2"))) { + assertResult(Array(Row(2, "str2"))) { sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + "WHERE TableAliaS.a > 1").collect() } @@ -553,12 +553,12 @@ class HiveQuerySuite extends 
HiveComparisonTest with BeforeAndAfter { // Describe a table assertResult( Array( - Array("key", "int", null), - Array("value", "string", null), - Array("dt", "string", null), - Array("# Partition Information", "", ""), - Array("# col_name", "data_type", "comment"), - Array("dt", "string", null)) + Row("key", "int", null), + Row("value", "string", null), + Row("dt", "string", null), + Row("# Partition Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("dt", "string", null)) ) { sql("DESCRIBE test_describe_commands1") .select('col_name, 'data_type, 'comment) @@ -568,12 +568,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Describe a table with a fully qualified table name assertResult( Array( - Array("key", "int", null), - Array("value", "string", null), - Array("dt", "string", null), - Array("# Partition Information", "", ""), - Array("# col_name", "data_type", "comment"), - Array("dt", "string", null)) + Row("key", "int", null), + Row("value", "string", null), + Row("dt", "string", null), + Row("# Partition Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("dt", "string", null)) ) { sql("DESCRIBE default.test_describe_commands1") .select('col_name, 'data_type, 'comment) @@ -623,8 +623,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertResult( Array( - Array("a", "IntegerType", null), - Array("b", "StringType", null)) + Row("a", "IntegerType", null), + Row("b", "StringType", null)) ) { sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 5dafcd6c0a76a..f2374a215291b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -64,7 +64,7 @@ class HiveUdfSuite extends QueryTest { test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { checkAnswer( sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), - 8 + Row(8) ) } @@ -115,7 +115,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") checkAnswer( sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), - Seq(Seq("1"), Seq("2"))) + Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") TestHive.reset() @@ -131,7 +131,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") checkAnswer( sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), - Seq(Seq(0), Seq(2), Seq(13))) + Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") TestHive.reset() @@ -146,7 +146,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") checkAnswer( sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), - Seq(Seq("a,b,c"), Seq("d,e"))) + Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") TestHive.reset() @@ -160,7 +160,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") checkAnswer( sql("SELECT 
testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), - Seq(Seq("hello world"), Seq("hello goodbye"))) + Seq(Row("hello world"), Row("hello goodbye"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") TestHive.reset() @@ -177,7 +177,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") checkAnswer( sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), - Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13"))) + Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") TestHive.reset() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index d41eb9e870bf0..f6bf2dbb5d6e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -41,7 +41,7 @@ class SQLQuerySuite extends QueryTest { } test("CTAS with serde") { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( """CREATE TABLE ctas2 | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" @@ -51,23 +51,23 @@ class SQLQuerySuite extends QueryTest { | AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect + | ORDER BY key, value""".stripMargin).collect() sql( """CREATE TABLE ctas3 | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' | STORED AS textfile AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect + | ORDER BY key, value""".stripMargin).collect() // the table schema may like (key: integer, value: string) sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect + | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect() // do nothing cause the table ctas4 already existed. 
sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() checkAnswer( sql("SELECT k, value FROM ctas1 ORDER BY k, value"), @@ -89,7 +89,7 @@ class SQLQuerySuite extends QueryTest { intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] { sql( """CREATE TABLE ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() } checkAnswer( sql("SELECT key, value FROM ctas4 ORDER BY key, value"), @@ -126,7 +126,7 @@ class SQLQuerySuite extends QueryTest { sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested") checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), - 1) + Row(1)) checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), Seq.empty[Row]) checkAnswer( @@ -233,7 +233,7 @@ class SQLQuerySuite extends QueryTest { | (s struct, | innerArray:array, | innerMap: map>) - """.stripMargin).collect + """.stripMargin).collect() sql( """ @@ -243,7 +243,7 @@ class SQLQuerySuite extends QueryTest { checkAnswer( sql("SELECT * FROM nullValuesInInnerComplexTypes"), - Seq(Seq(Seq(null, null, null))) + Row(Row(null, null, null)) ) sql("DROP TABLE nullValuesInInnerComplexTypes") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 4bc14bad0ad5f..581f666399492 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -39,7 +39,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("SELECT on Parquet table") { val data = (1 to 4).map(i => (i, s"val_$i")) withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data) + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala index 8bbb7f2fdbf48..79fd99d9f89ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala @@ -177,81 +177,81 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll test(s"ordering of the partitioning columns $table") { checkAnswer( sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)((1, "part-1")) + Seq.fill(10)(Row(1, "part-1")) ) checkAnswer( sql(s"SELECT stringField, p FROM $table WHERE p = 1"), - Seq.fill(10)(("part-1", 1)) + Seq.fill(10)(Row("part-1", 1)) ) } test(s"project the partitioning column $table") { checkAnswer( sql(s"SELECT p, count(*) FROM $table group by p"), - (1, 10) :: - (2, 10) :: - (3, 10) :: - (4, 10) :: - (5, 10) :: - (6, 10) :: - (7, 10) :: - (8, 10) :: - (9, 10) :: - (10, 10) :: Nil + Row(1, 10) :: + Row(2, 10) :: + Row(3, 10) :: + Row(4, 10) :: + Row(5, 10) :: + Row(6, 10) :: + Row(7, 10) :: + Row(8, 10) :: + Row(9, 10) :: + Row(10, 10) :: Nil ) } test(s"project partitioning and non-partitioning columns $table") { checkAnswer( sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), - ("part-1", 1, 10) :: - ("part-2", 2, 10) :: - ("part-3", 3, 10) :: - ("part-4", 4, 10) :: - ("part-5", 5, 10) :: - ("part-6", 6, 10) :: 
- ("part-7", 7, 10) :: - ("part-8", 8, 10) :: - ("part-9", 9, 10) :: - ("part-10", 10, 10) :: Nil + Row("part-1", 1, 10) :: + Row("part-2", 2, 10) :: + Row("part-3", 3, 10) :: + Row("part-4", 4, 10) :: + Row("part-5", 5, 10) :: + Row("part-6", 6, 10) :: + Row("part-7", 7, 10) :: + Row("part-8", 8, 10) :: + Row("part-9", 9, 10) :: + Row("part-10", 10, 10) :: Nil ) } test(s"simple count $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table"), - 100) + Row(100)) } test(s"pruned count $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), - 10) + Row(10)) } test(s"non-existant partition $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), - 0) + Row(0)) } test(s"multi-partition pruned count $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), - 30) + Row(30)) } test(s"non-partition predicates $table") { checkAnswer( sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), - 30) + Row(30)) } test(s"sum $table") { checkAnswer( sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), - 1 + 2 + 3) + Row(1 + 2 + 3)) } test(s"hive udfs $table") { @@ -266,6 +266,6 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll test("non-part select(*)") { checkAnswer( sql("SELECT COUNT(*) FROM normal_parquet"), - 10) + Row(10)) } }