From edb1f0e3164b99b483b6878efe0b6ea5ae9f97ed Mon Sep 17 00:00:00 2001 From: akkomar Date: Fri, 13 Jun 2014 15:37:26 -0700 Subject: [PATCH 01/57] Small correction in Streaming Programming Guide doc Corrected description of `repartition` function under 'Level of Parallelism in Data Receiving'. Author: akkomar Closes #1079 from akkomar/streaming-guide-doc and squashes the following commits: 32dfc62 [akkomar] Corrected description of `repartition` function under 'Level of Parallelism in Data Receiving'. --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index bbee67f54c6b8..ce8e58d64a7ed 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -950,7 +950,7 @@ is 200 milliseconds. An alternative to receiving data with multiple input streams / receivers is to explicitly repartition the input data stream (using `inputStream.repartition()`). -This distributes the received batches of data across all the machines in the cluster +This distributes the received batches of data across specified number of machines in the cluster before further processing. ### Level of Parallelism in Data Processing From 891968509105d8d8cf5a608ad9473aeeed747089 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 13 Jun 2014 23:28:57 -0700 Subject: [PATCH 02/57] [Spark-2137][SQL] Timestamp UDFs broken https://issues.apache.org/jira/browse/SPARK-2137 Author: Yin Huai Closes #1081 from yhuai/SPARK-2137 and squashes the following commits: c04f910 [Yin Huai] Merge remote-tracking branch 'upstream/master' into SPARK-2137 205f17b [Yin Huai] Make Hive UDF wrapper support Timestamp. --- .../src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 2 +- .../src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 3 ++- .../golden/timestamp_udf-10-dbc23736a61d9482d13cacada02a7a09 | 1 + .../golden/timestamp_udf-11-442cf850a0cc1f1dcfdeaeffbffb2c35 | 1 + .../golden/timestamp_udf-12-51959036fd4ac4f1e24f4e06eb9b0b6 | 1 + .../golden/timestamp_udf-13-6ab3f356deaf807e8accc37e1f4849a | 1 + .../golden/timestamp_udf-14-c745a1016461403526d44928a269c1de | 1 + .../golden/timestamp_udf-15-7ab76c4458c7f78038c8b1df0fdeafbe | 1 + .../golden/timestamp_udf-16-b36e87e17ca24d82072220bff559c718 | 1 + .../golden/timestamp_udf-17-dad44d2d4a421286e9da080271bd2639 | 1 + .../golden/timestamp_udf-18-cb033ecad964a2623bc633ac1d3f752a | 1 + .../golden/timestamp_udf-19-79914c5347620c6e62a8e0b9a95984af | 0 .../golden/timestamp_udf-20-59fc1842a23369235d42ed040d45fb3d | 0 .../golden/timestamp_udf-4-80ce02ec84ee8abcb046367ca37279cc | 0 .../golden/timestamp_udf-5-1124399033bcadf3874fb48f593392d | 1 + .../golden/timestamp_udf-6-5810193ce35d38c23f4fc4b4979d60a4 | 1 + .../golden/timestamp_udf-7-250e640a6a818f989f3f3280b00f64f9 | 1 + .../golden/timestamp_udf-8-975df43df015d86422965af456f87a94 | 1 + .../golden/timestamp_udf-9-287614364eaa3fb82aad08c6b62cc938 | 1 + 19 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-10-dbc23736a61d9482d13cacada02a7a09 create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-11-442cf850a0cc1f1dcfdeaeffbffb2c35 create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-12-51959036fd4ac4f1e24f4e06eb9b0b6 create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-13-6ab3f356deaf807e8accc37e1f4849a create mode 100644 
sql/hive/src/test/resources/golden/timestamp_udf-14-c745a1016461403526d44928a269c1de create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-15-7ab76c4458c7f78038c8b1df0fdeafbe create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-16-b36e87e17ca24d82072220bff559c718 create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-17-dad44d2d4a421286e9da080271bd2639 create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-18-cb033ecad964a2623bc633ac1d3f752a create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-19-79914c5347620c6e62a8e0b9a95984af create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-20-59fc1842a23369235d42ed040d45fb3d create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-4-80ce02ec84ee8abcb046367ca37279cc create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-5-1124399033bcadf3874fb48f593392d create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-6-5810193ce35d38c23f4fc4b4979d60a4 create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-7-250e640a6a818f989f3f3280b00f64f9 create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-8-975df43df015d86422965af456f87a94 create mode 100644 sql/hive/src/test/resources/golden/timestamp_udf-9-287614364eaa3fb82aad08c6b62cc938 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 96e0ec5136331..cc95b7af0abf6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -250,7 +250,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DecimalType) + ShortType, DecimalType, TimestampType) protected def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 572902042337f..771d2bccf43a7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -187,7 +187,8 @@ private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) val primitiveClasses = Seq( Integer.TYPE, classOf[java.lang.Integer], classOf[java.lang.String], java.lang.Double.TYPE, classOf[java.lang.Double], java.lang.Long.TYPE, classOf[java.lang.Long], - classOf[HiveDecimal], java.lang.Byte.TYPE, classOf[java.lang.Byte] + classOf[HiveDecimal], java.lang.Byte.TYPE, classOf[java.lang.Byte], + classOf[java.sql.Timestamp] ) val matchingConstructor = argClass.getConstructors.find { c => c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head) diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-10-dbc23736a61d9482d13cacada02a7a09 b/sql/hive/src/test/resources/golden/timestamp_udf-10-dbc23736a61d9482d13cacada02a7a09 new file mode 100644 index 0000000000000..b3c4eec4c2209 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-10-dbc23736a61d9482d13cacada02a7a09 @@ -0,0 +1 @@ +2011-05-06 07:08:09.1234567 2011-05-06 02:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-11-442cf850a0cc1f1dcfdeaeffbffb2c35 
b/sql/hive/src/test/resources/golden/timestamp_udf-11-442cf850a0cc1f1dcfdeaeffbffb2c35 new file mode 100644 index 0000000000000..f69f13ed1fb94 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-11-442cf850a0cc1f1dcfdeaeffbffb2c35 @@ -0,0 +1 @@ +2011-05-06 07:08:09.1234567 2011-05-06 02:08:09.1234567 2011-05-06 07:08:09.1234567 2011-05-06 02:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-12-51959036fd4ac4f1e24f4e06eb9b0b6 b/sql/hive/src/test/resources/golden/timestamp_udf-12-51959036fd4ac4f1e24f4e06eb9b0b6 new file mode 100644 index 0000000000000..f14f17e692822 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-12-51959036fd4ac4f1e24f4e06eb9b0b6 @@ -0,0 +1 @@ +2011-05-06 07:08:09.1234567 2011-05-06 12:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-13-6ab3f356deaf807e8accc37e1f4849a b/sql/hive/src/test/resources/golden/timestamp_udf-13-6ab3f356deaf807e8accc37e1f4849a new file mode 100644 index 0000000000000..7881bff731be1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-13-6ab3f356deaf807e8accc37e1f4849a @@ -0,0 +1 @@ +2011-05-06 07:08:09.1234567 2011-05-06 12:08:09.1234567 2011-05-06 07:08:09.1234567 2011-05-06 12:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-14-c745a1016461403526d44928a269c1de b/sql/hive/src/test/resources/golden/timestamp_udf-14-c745a1016461403526d44928a269c1de new file mode 100644 index 0000000000000..2c5e9e9656202 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-14-c745a1016461403526d44928a269c1de @@ -0,0 +1 @@ +1304690889 2011 5 6 6 18 7 8 9 2011-05-06 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-15-7ab76c4458c7f78038c8b1df0fdeafbe b/sql/hive/src/test/resources/golden/timestamp_udf-15-7ab76c4458c7f78038c8b1df0fdeafbe new file mode 100644 index 0000000000000..19497254f8f7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-15-7ab76c4458c7f78038c8b1df0fdeafbe @@ -0,0 +1 @@ +2011-05-11 2011-04-26 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-16-b36e87e17ca24d82072220bff559c718 b/sql/hive/src/test/resources/golden/timestamp_udf-16-b36e87e17ca24d82072220bff559c718 new file mode 100644 index 0000000000000..816f56e43eaba --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-16-b36e87e17ca24d82072220bff559c718 @@ -0,0 +1 @@ +0 3333 -3333 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-17-dad44d2d4a421286e9da080271bd2639 b/sql/hive/src/test/resources/golden/timestamp_udf-17-dad44d2d4a421286e9da080271bd2639 new file mode 100644 index 0000000000000..a4182d1e39db9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-17-dad44d2d4a421286e9da080271bd2639 @@ -0,0 +1 @@ +2011-05-06 02:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-18-cb033ecad964a2623bc633ac1d3f752a b/sql/hive/src/test/resources/golden/timestamp_udf-18-cb033ecad964a2623bc633ac1d3f752a new file mode 100644 index 0000000000000..02ccd3a2e97ce --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-18-cb033ecad964a2623bc633ac1d3f752a @@ -0,0 +1 @@ +2011-05-06 12:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-19-79914c5347620c6e62a8e0b9a95984af b/sql/hive/src/test/resources/golden/timestamp_udf-19-79914c5347620c6e62a8e0b9a95984af new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-20-59fc1842a23369235d42ed040d45fb3d 
b/sql/hive/src/test/resources/golden/timestamp_udf-20-59fc1842a23369235d42ed040d45fb3d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-4-80ce02ec84ee8abcb046367ca37279cc b/sql/hive/src/test/resources/golden/timestamp_udf-4-80ce02ec84ee8abcb046367ca37279cc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-5-1124399033bcadf3874fb48f593392d b/sql/hive/src/test/resources/golden/timestamp_udf-5-1124399033bcadf3874fb48f593392d new file mode 100644 index 0000000000000..2c5e9e9656202 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-5-1124399033bcadf3874fb48f593392d @@ -0,0 +1 @@ +1304690889 2011 5 6 6 18 7 8 9 2011-05-06 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-6-5810193ce35d38c23f4fc4b4979d60a4 b/sql/hive/src/test/resources/golden/timestamp_udf-6-5810193ce35d38c23f4fc4b4979d60a4 new file mode 100644 index 0000000000000..19497254f8f7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-6-5810193ce35d38c23f4fc4b4979d60a4 @@ -0,0 +1 @@ +2011-05-11 2011-04-26 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-7-250e640a6a818f989f3f3280b00f64f9 b/sql/hive/src/test/resources/golden/timestamp_udf-7-250e640a6a818f989f3f3280b00f64f9 new file mode 100644 index 0000000000000..816f56e43eaba --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-7-250e640a6a818f989f3f3280b00f64f9 @@ -0,0 +1 @@ +0 3333 -3333 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-8-975df43df015d86422965af456f87a94 b/sql/hive/src/test/resources/golden/timestamp_udf-8-975df43df015d86422965af456f87a94 new file mode 100644 index 0000000000000..a4182d1e39db9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-8-975df43df015d86422965af456f87a94 @@ -0,0 +1 @@ +2011-05-06 02:08:09.1234567 diff --git a/sql/hive/src/test/resources/golden/timestamp_udf-9-287614364eaa3fb82aad08c6b62cc938 b/sql/hive/src/test/resources/golden/timestamp_udf-9-287614364eaa3fb82aad08c6b62cc938 new file mode 100644 index 0000000000000..02ccd3a2e97ce --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_udf-9-287614364eaa3fb82aad08c6b62cc938 @@ -0,0 +1 @@ +2011-05-06 12:08:09.1234567 From 2550533a28382664f8fd294b2caa494d12bfc7c1 Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Sat, 14 Jun 2014 13:17:22 -0700 Subject: [PATCH 03/57] [SPARK-2079] Support batching when serializing SchemaRDD to Python Added batching with default batch size 10 in SchemaRDD.javaToPython Author: Kan Zhang Closes #1023 from kanzhang/SPARK-2079 and squashes the following commits: 2d1915e [Kan Zhang] [SPARK-2079] Add batching in SchemaRDD.javaToPython 19b0c09 [Kan Zhang] [SPARK-2079] Removing unnecessary wrapping in SchemaRDD.javaToPython --- python/pyspark/sql.py | 4 +++- .../src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 9 ++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 960d0a82448aa..e344610b1fe4d 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -16,6 +16,7 @@ # from pyspark.rdd import RDD +from pyspark.serializers import BatchedSerializer, PickleSerializer from py4j.protocol import Py4JError @@ -346,7 +347,8 @@ def _toPython(self): # TODO: This is inefficient, we should construct the Python Row object # in Java land in the javaToPython function. 
May require a custom # pickle serializer in Pyrolite - return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d)) + return RDD(jrdd, self._sc, BatchedSerializer( + PickleSerializer())).map(lambda d: Row(d)) # We override the default cache/persist/checkpoint behavior as we want to cache the underlying # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 821ac850ac3f5..89eaba2d19aa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -347,16 +347,11 @@ class SchemaRDD( val pickle = new Pickler iter.map { row => val map: JMap[String, Any] = new java.util.HashMap - // TODO: We place the map in an ArrayList so that the object is pickled to a List[Dict]. - // Ideally we should be able to pickle an object directly into a Python collection so we - // don't have to create an ArrayList every time. - val arr: java.util.ArrayList[Any] = new java.util.ArrayList row.zip(fieldNames).foreach { case (obj, name) => map.put(name, obj) } - arr.add(map) - pickle.dumps(arr) - } + map + }.grouped(10).map(batched => pickle.dumps(batched.toArray)) } } From b52603b039cdfa0f8e58ef3c6229d79e732ffc58 Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Sat, 14 Jun 2014 13:22:30 -0700 Subject: [PATCH 04/57] [SPARK-2013] Documentation for saveAsPickleFile and pickleFile in Python Author: Kan Zhang Closes #983 from kanzhang/SPARK-2013 and squashes the following commits: 0e128bb [Kan Zhang] [SPARK-2013] minor update e728516 [Kan Zhang] [SPARK-2013] Documentation for saveAsPickleFile and pickleFile in Python --- docs/programming-guide.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 79784682bfd1b..ef0c0e34301f3 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -377,13 +377,15 @@ Some notes on reading files with Spark: * The `textFile` method also takes an optional second argument for controlling the number of slices of the file. By default, Spark creates one slice for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of slices by passing a larger value. Note that you cannot have fewer slices than blocks. -Apart from reading files as a collection of lines, -`SparkContext.wholeTextFiles` lets you read a directory containing multiple small text files, and returns each of them as (filename, content) pairs. This is in contrast with `textFile`, which would return one record per line in each file. +Apart from text files, Spark's Python API also supports several other data formats: -### SequenceFile and Hadoop InputFormats +* `SparkContext.wholeTextFiles` lets you read a directory containing multiple small text files, and returns each of them as (filename, content) pairs. This is in contrast with `textFile`, which would return one record per line in each file. + +* `RDD.saveAsPickleFile` and `SparkContext.pickleFile` support saving an RDD in a simple format consisting of pickled Python objects. Batching is used on pickle serialization, with default batch size 10. -In addition to reading text files, PySpark supports reading ```SequenceFile``` -and any arbitrary ```InputFormat```. +* Details on reading `SequenceFile` and arbitrary Hadoop `InputFormat` are given below. 
+ +### SequenceFile and Hadoop InputFormats **Note** this feature is currently marked ```Experimental``` and is intended for advanced users. It may be replaced in future with read/write support based on SparkSQL, in which case SparkSQL is the preferred approach. From 7dd9fc67a63985493ad0482d307edd56f3af0b9d Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Sat, 14 Jun 2014 14:31:28 -0700 Subject: [PATCH 05/57] [SPARK-1837] NumericRange should be partitioned in the same way as other... ... sequences Author: Kan Zhang Closes #776 from kanzhang/SPARK-1837 and squashes the following commits: e48f018 [Kan Zhang] [SPARK-1837] code refactoring 67c33b5 [Kan Zhang] minor change 403f9b1 [Kan Zhang] [SPARK-1837] NumericRange should be partitioned in the same way as other sequences --- .../spark/rdd/ParallelCollectionRDD.scala | 31 ++++++++++++------- .../rdd/ParallelCollectionSplitSuite.scala | 18 +++++++++++ 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 2425929fc73c5..66c71bf7e8bb5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -117,6 +117,15 @@ private object ParallelCollectionRDD { if (numSlices < 1) { throw new IllegalArgumentException("Positive number of slices required") } + // Sequences need to be sliced at the same set of index positions for operations + // like RDD.zip() to behave as expected + def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = { + (0 until numSlices).iterator.map(i => { + val start = ((i * length) / numSlices).toInt + val end = (((i + 1) * length) / numSlices).toInt + (start, end) + }) + } seq match { case r: Range.Inclusive => { val sign = if (r.step < 0) { @@ -128,18 +137,17 @@ private object ParallelCollectionRDD { r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) } case r: Range => { - (0 until numSlices).map(i => { - val start = ((i * r.length.toLong) / numSlices).toInt - val end = (((i + 1) * r.length.toLong) / numSlices).toInt - new Range(r.start + start * r.step, r.start + end * r.step, r.step) - }).asInstanceOf[Seq[Seq[T]]] + positions(r.length, numSlices).map({ + case (start, end) => + new Range(r.start + start * r.step, r.start + end * r.step, r.step) + }).toSeq.asInstanceOf[Seq[Seq[T]]] } case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc val slices = new ArrayBuffer[Seq[T]](numSlices) - val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything var r = nr - for (i <- 0 until numSlices) { + for ((start, end) <- positions(nr.length, numSlices)) { + val sliceSize = end - start slices += r.take(sliceSize).asInstanceOf[Seq[T]] r = r.drop(sliceSize) } @@ -147,11 +155,10 @@ private object ParallelCollectionRDD { } case _ => { val array = seq.toArray // To prevent O(n^2) operations for List etc - (0 until numSlices).map(i => { - val start = ((i * array.length.toLong) / numSlices).toInt - val end = (((i + 1) * array.length.toLong) / numSlices).toInt - array.slice(start, end).toSeq - }) + positions(array.length, numSlices).map({ + case (start, end) => + array.slice(start, end).toSeq + }).toSeq } } } diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 4df36558b6d4b..1b112f1a41ca9 100644 --- 
a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -111,6 +111,24 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.forall(_.isInstanceOf[Range])) } + test("identical slice sizes between Range and NumericRange") { + val r = ParallelCollectionRDD.slice(1 to 7, 4) + val nr = ParallelCollectionRDD.slice(1L to 7L, 4) + assert(r.size === 4) + for (i <- 0 until r.size) { + assert(r(i).size === nr(i).size) + } + } + + test("identical slice sizes between List and NumericRange") { + val r = ParallelCollectionRDD.slice(List(1, 2), 4) + val nr = ParallelCollectionRDD.slice(1L to 2L, 4) + assert(r.size === 4) + for (i <- 0 until r.size) { + assert(r(i).size === nr(i).size) + } + } + test("large ranges don't overflow") { val N = 100 * 1000 * 1000 val data = 0 until N From 269fc62b20ee5f9cd60a8f133c29f662d17071b1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 15 Jun 2014 11:28:34 +0200 Subject: [PATCH 06/57] [SQL] Support transforming TreeNodes with Option children. Thanks goes to @marmbrus for his implementation. Author: Michael Armbrust Author: Zongheng Yang Closes #1074 from concretevitamin/option-treenode and squashes the following commits: ef27b85 [Zongheng Yang] Merge pull request #1 from marmbrus/pr/1074 73133c2 [Michael Armbrust] TreeNodes can't be inner classes. ab78420 [Zongheng Yang] Add a test. 2ccb721 [Michael Armbrust] Add support for transformation of optional children. --- .../spark/sql/catalyst/trees/TreeNode.scala | 19 ++++++++++++- .../sql/catalyst/trees/TreeNodeSuite.scala | 27 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 0369129393a08..cd04bdf02cf84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -187,6 +187,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } + case Some(arg: TreeNode[_]) if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + if (!(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } case m: Map[_,_] => m case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -231,6 +239,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } + case Some(arg: TreeNode[_]) if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformUp(rule) + if (!(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } case m: Map[_,_] => m case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -273,7 +289,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName?") + this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? 
" + + s"Exception message: ${e.getMessage}.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 1ddc41a731ff5..6344874538d67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -22,6 +22,17 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NullType} + +case class Dummy(optKey: Option[Expression]) extends Expression { + def children = optKey.toSeq + def references = Set.empty[Attribute] + def nullable = true + def dataType = NullType + override lazy val resolved = true + type EvaluatedType = Any + def eval(input: Row) = null.asInstanceOf[Any] +} class TreeNodeSuite extends FunSuite { test("top node changed") { @@ -75,4 +86,20 @@ class TreeNodeSuite extends FunSuite { assert(expected === actual) } + + test("transform works on nodes with Option children") { + val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy2 = Dummy(None) + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + + var actual = dummy1 transformDown toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy1 transformUp toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy2 transform toZero + assert(actual === Dummy(None)) + } + } From ca5d9d43b93abd279079b3be8a06fdd78c595510 Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Sun, 15 Jun 2014 14:55:34 -0700 Subject: [PATCH 07/57] [SPARK-937] adding EXITED executor state and not relaunching cleanly exited executors There seems to be 2 issues. 1. When job is done, driver asks executor to shutdown. However, this clean exit was assigned FAILED executor state by Worker. I introduced EXITED executor state for executors who voluntarily exit (both normal and abnormal exit depending on the exit code). 2. When Master gets notified an executor has exited, it launches another one to replace it, regardless of reason why the executor had exited. When the reason was job has finished, the unnecessary replacement got subsequently killed when App disassociates. This launching and killing of unnecessary executors shows up in the log and is confusing to users. I added check for executor exit status and avoid launching (and subsequent killing) of unnecessary replacements when executors exit cleanly. One could ask the scheduler to tell Master job is done so that Master wouldn't launch the replacement executor. However, there is a race condition between App telling Master job is done and Worker telling Master an executor had exited. There is no guarantee the former will happen before the later. Instead, I chose to check the exit code when executor exits. If the exit code is 0, I assume executor has been asked to shutdown by driver and Master will not launch replacements. Due to race condition, it could also happen that (although didn't happen on my local cluster), Master detects App disassociation event before the executor exits by itself. In such cases, the executor will be rightfully killed and labeled as KILLED, while the App state will show FINISHED. 
Author: Kan Zhang Closes #306 from kanzhang/SPARK-1118 and squashes the following commits: cb0cc86 [Kan Zhang] [SPARK-937] adding EXITED executor state and not relaunching cleanly exited executors --- .../main/scala/org/apache/spark/deploy/ExecutorState.scala | 4 ++-- .../main/scala/org/apache/spark/deploy/master/Master.scala | 5 +++-- .../org/apache/spark/deploy/worker/ExecutorRunner.scala | 7 +++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index 37dfa7fec0831..9f34d01e6db48 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -19,9 +19,9 @@ package org.apache.spark.deploy private[spark] object ExecutorState extends Enumeration { - val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value + val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST, EXITED = Value type ExecutorState = Value - def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST).contains(state) + def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST, EXITED).contains(state) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index c6dec305bffcb..33ffcbd216954 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -303,10 +303,11 @@ private[spark] class Master( appInfo.removeExecutor(exec) exec.worker.removeExecutor(exec) + val normalExit = exitStatus.exists(_ == 0) // Only retry certain number of times so we don't go into an infinite loop. - if (appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) { + if (!normalExit && appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) { schedule() - } else { + } else if (!normalExit) { logError("Application %s with ID %s failed %d times, removing it".format( appInfo.desc.name, appInfo.id, appInfo.retryCount)) removeApplication(appInfo, ApplicationState.FAILED) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d09136de49807..6433aac1c23e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -154,11 +154,10 @@ private[spark] class ExecutorRunner( Files.write(header, stderr, Charsets.UTF_8) stderrAppender = FileAppender(process.getErrorStream, stderr, conf) - // Wait for it to exit; this is actually a bad thing if it happens, because we expect to run - // long-lived processes only. However, in the future, we might restart the executor a few - // times on the same machine. + // Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown) + // or with nonzero exit code val exitCode = process.waitFor() - state = ExecutorState.FAILED + state = ExecutorState.EXITED val message = "Command exited with code " + exitCode worker ! 
ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) } catch { From a63aa1adb2dfb19c8189167932ee8569840f96a0 Mon Sep 17 00:00:00 2001 From: CrazyJvm Date: Sun, 15 Jun 2014 23:23:26 -0700 Subject: [PATCH 08/57] SPARK-1999: StorageLevel in storage tab and RDD Storage Info never changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit StorageLevel in 'storage tab' and 'RDD Storage Info' never changes even if you call rdd.unpersist() and then you give the rdd another different storage level. Author: CrazyJvm Closes #968 from CrazyJvm/ui-storagelevel and squashes the following commits: 62555fa [CrazyJvm] change RDDInfo constructor param 'storageLevel' to var, so there's need to add another variable _storageLevel。 9f1571e [CrazyJvm] JIRA https://issues.apache.org/jira/browse/SPARK-1999 UI : StorageLevel in storage tab and RDD Storage Info never changes --- core/src/main/scala/org/apache/spark/storage/RDDInfo.scala | 6 +++--- .../main/scala/org/apache/spark/storage/StorageUtils.scala | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 023fd6e4d8baa..5a72e216872a6 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -26,7 +26,7 @@ class RDDInfo( val id: Int, val name: String, val numPartitions: Int, - val storageLevel: StorageLevel) + var storageLevel: StorageLevel) extends Ordered[RDDInfo] { var numCachedPartitions = 0 @@ -36,8 +36,8 @@ class RDDInfo( override def toString = { import Utils.bytesToString - ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; " + - "TachyonSize: %s; DiskSize: %s").format( + ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + + "MemorySize: %s; TachyonSize: %s; DiskSize: %s").format( name, id, storageLevel.toString, numCachedPartitions, numPartitions, bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize)) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 6f3252a2f6d31..f3bde1df45c79 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -89,10 +89,13 @@ private[spark] object StorageUtils { // Add up memory, disk and Tachyon sizes val persistedBlocks = blocks.filter { status => status.memSize + status.diskSize + status.tachyonSize > 0 } + val _storageLevel = + if (persistedBlocks.length > 0) persistedBlocks(0).storageLevel else StorageLevel.NONE val memSize = persistedBlocks.map(_.memSize).reduceOption(_ + _).getOrElse(0L) val diskSize = persistedBlocks.map(_.diskSize).reduceOption(_ + _).getOrElse(0L) val tachyonSize = persistedBlocks.map(_.tachyonSize).reduceOption(_ + _).getOrElse(0L) rddInfoMap.get(rddId).map { rddInfo => + rddInfo.storageLevel = _storageLevel rddInfo.numCachedPartitions = persistedBlocks.length rddInfo.memSize = memSize rddInfo.diskSize = diskSize From 9672ee07fb1c3583c70f23a699de3b2282eb0f98 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Sun, 15 Jun 2014 23:32:55 -0700 Subject: [PATCH 09/57] SPARK-2148 Add link to requirements for custom equals() and hashcode() methods https://issues.apache.org/jira/browse/SPARK-2148 Author: Andrew Ash Closes #1092 from ash211/SPARK-2148 and 
squashes the following commits: 93513df [Andrew Ash] SPARK-2148 Add link to requirements for custom equals() and hashcode() methods --- docs/programming-guide.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index ef0c0e34301f3..0b24a8b88b3cc 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -762,6 +762,11 @@ val counts = pairs.reduceByKey((a, b) => a + b) We could also use `counts.sortByKey()`, for example, to sort the pairs alphabetically, and finally `counts.collect()` to bring them back to the driver program as an array of objects. +**Note:** when using custom objects as the key in key-value pair operations, you must be sure that a +custom `equals()` method is accompanied with a matching `hashCode()` method. For full details, see +the contract outlined in the [Object.hashCode() +documentation](http://docs.oracle.com/javase/7/docs/api/java/lang/Object.html#hashCode()). +
@@ -794,6 +799,10 @@ JavaPairRDD counts = pairs.reduceByKey((a, b) -> a + b); We could also use `counts.sortByKey()`, for example, to sort the pairs alphabetically, and finally `counts.collect()` to bring them back to the driver program as an array of objects. +**Note:** when using custom objects as the key in key-value pair operations, you must be sure that a +custom `equals()` method is accompanied with a matching `hashCode()` method. For full details, see +the contract outlined in the [Object.hashCode() +documentation](http://docs.oracle.com/javase/7/docs/api/java/lang/Object.html#hashCode()).
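
For illustration only (this sketch is not part of the patch above): a minimal Scala example of a custom key type that keeps `equals()` and `hashCode()` consistent, as the added note requires for key-value pair operations. The `SensorKey` class, its fields, and the sample values are hypothetical, and the sketch assumes a local-mode `SparkContext`; a Scala `case class` would generate both methods automatically.

```
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // implicit conversions enabling pair-RDD operations

// Hypothetical key class: a custom key must override both equals() and
// hashCode() consistently, otherwise key-based operations may fail to
// group equal keys together. (A case class generates both automatically.)
class SensorKey(val building: String, val floor: Int) extends Serializable {
  override def equals(other: Any): Boolean = other match {
    case that: SensorKey => building == that.building && floor == that.floor
    case _               => false
  }
  override def hashCode(): Int = 31 * building.hashCode + floor
}

val sc = new SparkContext(new SparkConf().setAppName("KeyContract").setMaster("local[2]"))
val readings = sc.parallelize(Seq(
  (new SensorKey("A", 1), 10.0),
  (new SensorKey("A", 1), 12.0),
  (new SensorKey("B", 2), 7.5)))

// Grouping works only because equal SensorKeys also hash to the same value.
val totals = readings.reduceByKey(_ + _)
totals.collect().foreach { case (k, v) => println(s"${k.building}/${k.floor} -> $v") }
sc.stop()
```
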
From 119b06a04f6df3949b3b074a18f791bbc732ac31 Mon Sep 17 00:00:00 2001 From: Ali Ghodsi Date: Sun, 15 Jun 2014 23:44:30 -0700 Subject: [PATCH 10/57] Updating docs to include missing information about reducers and clarify ... ...how the OFFHEAP storage level works (there has been confusion around this). Author: Ali Ghodsi Closes #1089 from alig/master and squashes the following commits: ca8114d [Ali Ghodsi] Updating docs to include missing information about reducers and clarify how the OFFHEAP storage level works (there has been confusion around this). --- docs/programming-guide.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 0b24a8b88b3cc..65d75b85efda6 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -899,7 +899,7 @@ for details. reduceByKey(func, [numTasks]) - When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. + When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function func, which must be of type (V,V) => V. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) @@ -1067,7 +1067,10 @@ storage levels is: Store RDD in serialized format in Tachyon. Compared to MEMORY_ONLY_SER, OFF_HEAP reduces garbage collection overhead and allows executors to be smaller and to share a pool of memory, making it attractive in environments with - large heaps or multiple concurrent applications. + large heaps or multiple concurrent applications. Furthermore, as the RDDs reside in Tachyon, + the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory + in Tachyon is discardable. Thus, Tachyon does not attempt to reconstruct a block that it evicts + from memory. 
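
For illustration only (this sketch is not part of the patch above): a minimal Scala example of the clarified `reduceByKey` entry, showing a reduce function that must have type `(V, V) => V`, contrasted with `aggregateByKey`, whose result type may differ from the value type. The sample pairs and names are hypothetical, and the sketch assumes a local-mode `SparkContext`.

```
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // implicit conversions enabling pair-RDD operations

val sc = new SparkContext(new SparkConf().setAppName("ReduceByKey").setMaster("local[2]"))
val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 3)))

// The function passed to reduceByKey combines two values of the value type V
// into one, i.e. it has type (V, V) => V -- here (Int, Int) => Int.
val sum: (Int, Int) => Int = (a, b) => a + b
val counts = pairs.reduceByKey(sum)   // collect() yields ("a", 4) and ("b", 1), in some order

// By contrast, aggregateByKey allows a result type U different from V,
// so it takes a seqOp of type (U, V) => U and a combOp of type (U, U) => U.
val asSets = pairs.aggregateByKey(Set.empty[Int])(_ + _, _ ++ _)

counts.collect().foreach(println)
sc.stop()
```
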
From 716c88aa147762f7f617adf34a17edd681d9a4ff Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 15 Jun 2014 23:47:58 -0700 Subject: [PATCH 11/57] SPARK-2039: apply output dir existence checking for all output formats https://issues.apache.org/jira/browse/SPARK-2039 apply output dir existence checking for all output formats Author: CodingCat Closes #1088 from CodingCat/SPARK-2039 and squashes the following commits: c52747a [CodingCat] apply output dir existence checking for all output formats --- .../main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index b6ad9b6c3e168..fe36c80e0be84 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -787,8 +787,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance - if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true) && - jobFormat.isInstanceOf[NewFileOutputFormat[_, _]]) { + if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) { // FileOutputFormat ignores the filesystem parameter jobFormat.checkOutputSpecs(job) } @@ -854,8 +853,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") - if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true) && - outputFormatInstance.isInstanceOf[FileOutputFormat[_, _]]) { + if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) { // FileOutputFormat ignores the filesystem parameter val ignoredFs = FileSystem.get(conf) conf.getOutputFormat.checkOutputSpecs(ignoredFs, conf) From 4fdb491775bb9c4afa40477dc0069ff6fcadfe25 Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Mon, 16 Jun 2014 11:11:29 -0700 Subject: [PATCH 12/57] [SPARK-2010] Support for nested data in PySpark SQL JIRA issue https://issues.apache.org/jira/browse/SPARK-2010 This PR adds support for nested collection types in PySpark SQL, including array, dict, list, set, and tuple. Example, ``` >>> from array import array >>> from pyspark.sql import SQLContext >>> sqlCtx = SQLContext(sc) >>> rdd = sc.parallelize([ ... {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]) >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}] True >>> rdd = sc.parallelize([ ... {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]) >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.collect() == \ ... [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, ... 
{"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}] True ``` Author: Kan Zhang Closes #1041 from kanzhang/SPARK-2010 and squashes the following commits: 1b2891d [Kan Zhang] [SPARK-2010] minor doc change and adding a TODO 504f27e [Kan Zhang] [SPARK-2010] Support for nested data in PySpark SQL --- python/pyspark/sql.py | 22 +++++++++++++- .../org/apache/spark/sql/SQLContext.scala | 29 ++++++++++++------- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index e344610b1fe4d..c31d49ce837fc 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -77,12 +77,25 @@ def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{dict}s. We peek at the first row of the RDD to determine the fields names - and types, and then use that to extract all the dictionaries. + and types, and then use that to extract all the dictionaries. Nested + collections are supported, which include array, dict, list, set, and + tuple. >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, ... {"field1" : 3, "field2": "row3"}] True + + >>> from array import array + >>> srdd = sqlCtx.inferSchema(nestedRdd1) + >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, + ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}] + True + + >>> srdd = sqlCtx.inferSchema(nestedRdd2) + >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, + ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}] + True """ if (rdd.__class__ is SchemaRDD): raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__) @@ -413,6 +426,7 @@ def subtract(self, other, numPartitions=None): def _test(): import doctest + from array import array from pyspark.context import SparkContext globs = globals().copy() # The small batch size here ensures that we see multiple batches, @@ -422,6 +436,12 @@ def _test(): globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}]) + globs['nestedRdd1'] = sc.parallelize([ + {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, + {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]) + globs['nestedRdd2'] = sc.parallelize([ + {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, + {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]) (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 378ff54531118..131c130bbb3e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -298,19 +298,28 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * Peek at the first row of the RDD and infer its schema. - * TODO: We only support primitive types, add support for nested types. + * TODO: consolidate this with the type system developed in SPARK-2060. 
*/ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { + import scala.collection.JavaConversions._ + def typeFor(obj: Any): DataType = obj match { + case c: java.lang.String => StringType + case c: java.lang.Integer => IntegerType + case c: java.lang.Long => LongType + case c: java.lang.Double => DoubleType + case c: java.lang.Boolean => BooleanType + case c: java.util.List[_] => ArrayType(typeFor(c.head)) + case c: java.util.Set[_] => ArrayType(typeFor(c.head)) + case c: java.util.Map[_, _] => + val (key, value) = c.head + MapType(typeFor(key), typeFor(value)) + case c if c.getClass.isArray => + val elem = c.asInstanceOf[Array[_]].head + ArrayType(typeFor(elem)) + case c => throw new Exception(s"Object of type $c cannot be used") + } val schema = rdd.first().map { case (fieldName, obj) => - val dataType = obj.getClass match { - case c: Class[_] if c == classOf[java.lang.String] => StringType - case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType - case c: Class[_] if c == classOf[java.lang.Long] => LongType - case c: Class[_] if c == classOf[java.lang.Double] => DoubleType - case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType - case c => throw new Exception(s"Object of type $c cannot be used") - } - AttributeReference(fieldName, dataType, true)() + AttributeReference(fieldName, typeFor(obj), true)() }.toSeq val rowRdd = rdd.mapPartitions { iter => From cdf2b04570871848442ca9f9e2316a37e4aaaae0 Mon Sep 17 00:00:00 2001 From: witgo Date: Mon, 16 Jun 2014 14:27:31 -0500 Subject: [PATCH 13/57] [SPARK-1930] The Container is running beyond physical memory limits, so as to be killed Author: witgo Closes #894 from witgo/SPARK-1930 and squashes the following commits: 564307e [witgo] Update the running-on-yarn.md 3747515 [witgo] Merge branch 'master' of https://github.com/apache/spark into SPARK-1930 172647b [witgo] add memoryOverhead docs a0ff545 [witgo] leaving only two configs a17bda2 [witgo] Merge branch 'master' of https://github.com/apache/spark into SPARK-1930 478ca15 [witgo] Merge branch 'master' into SPARK-1930 d1244a1 [witgo] Merge branch 'master' into SPARK-1930 8b967ae [witgo] Merge branch 'master' into SPARK-1930 655a820 [witgo] review commit 71859a7 [witgo] Merge branch 'master' of https://github.com/apache/spark into SPARK-1930 e3c531d [witgo] review commit e16f190 [witgo] different memoryOverhead ffa7569 [witgo] review commit 5c9581f [witgo] Merge branch 'master' into SPARK-1930 9a6bcf2 [witgo] review commit 8fae45a [witgo] fix NullPointerException e0dcc16 [witgo] Adding configuration items b6a989c [witgo] Fix container memory beyond limit, were killed --- docs/running-on-yarn.md | 14 ++++++++++++++ .../org/apache/spark/deploy/yarn/Client.scala | 4 ++-- .../spark/deploy/yarn/ExecutorLauncher.scala | 4 +++- .../spark/deploy/yarn/YarnAllocationHandler.scala | 12 ++++++++---- .../org/apache/spark/deploy/yarn/ClientBase.scala | 14 +++++++++----- .../org/apache/spark/deploy/yarn/Client.scala | 4 ++-- .../spark/deploy/yarn/YarnAllocationHandler.scala | 12 ++++++++---- 7 files changed, 46 insertions(+), 18 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index af1788f2aa151..4243ef480ba39 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -67,6 +67,20 @@ Most of the configs are the same for Spark on YARN as for other deployment modes The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). 
Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + + spark.yarn.executor.memoryOverhead + 384 + + The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. + + + + spark.yarn.driver.memoryOverhead + 384 + + The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. + + By default, Spark on YARN will use a Spark jar installed locally, but the Spark JAR can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't need to be distributed each time an application runs. To point to a JAR on HDFS, `export SPARK_JAR=hdfs:///some/path`. diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4ccddc214c8ad..82f79d88a3009 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -71,7 +71,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] // Memory for the ApplicationMaster. - capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + capability.setMemory(args.amMemory + memoryOverhead) amContainer.setResource(capability) appContext.setQueue(args.amQueue) @@ -115,7 +115,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa val minResMemory = newApp.getMinimumResourceCapability().getMemory() val amMemory = ((args.amMemory / minResMemory) * minResMemory) + ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - - YarnAllocationHandler.MEMORY_OVERHEAD) + memoryOverhead) amMemory } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index b6ecae1e652fe..bfdb6232f5113 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -92,13 +92,15 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp appAttemptId = getApplicationAttemptId() resourceManager = registerWithResourceManager() + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() // Compute number of threads for akka val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() if (minimumMemory > 0) { - val mem = args.executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnAllocationHandler.MEMORY_OVERHEAD) val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) if (numCore > 0) { diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 856391e52b2df..80e0162e9f277 100644 --- 
a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -88,6 +88,10 @@ private[yarn] class YarnAllocationHandler( // Containers to be released in next request to RM private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] + // Additional memory overhead - in mb. + private def memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnAllocationHandler.MEMORY_OVERHEAD) + private val numExecutorsRunning = new AtomicInteger() // Used to generate a unique id per executor private val executorIdCounter = new AtomicInteger() @@ -99,7 +103,7 @@ private[yarn] class YarnAllocationHandler( def getNumExecutorsFailed: Int = numExecutorsFailed.intValue def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + container.getResource.getMemory >= (executorMemory + memoryOverhead) } def allocateContainers(executorsToRequest: Int) { @@ -229,7 +233,7 @@ private[yarn] class YarnAllocationHandler( val containerId = container.getId assert( container.getResource.getMemory >= - (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) + (executorMemory + memoryOverhead)) if (numExecutorsRunningNow > maxExecutors) { logInfo("""Ignoring container %s at host %s, since we already have the required number of @@ -450,7 +454,7 @@ private[yarn] class YarnAllocationHandler( if (numExecutors > 0) { logInfo("Allocating %d executor containers with %d of memory each.".format(numExecutors, - executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) + executorMemory + memoryOverhead)) } else { logDebug("Empty allocation req .. release : " + releasedContainerList) } @@ -505,7 +509,7 @@ private[yarn] class YarnAllocationHandler( val rsrcRequest = Records.newRecord(classOf[ResourceRequest]) val memCapability = Records.newRecord(classOf[Resource]) // There probably is some overhead here, let's reserve a bit more memory. - memCapability.setMemory(executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + memCapability.setMemory(executorMemory + memoryOverhead) rsrcRequest.setCapability(memCapability) val pri = Records.newRecord(classOf[Priority]) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 6861b503000ca..858bcaa95b409 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -65,6 +65,10 @@ trait ClientBase extends Logging { val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(Integer.parseInt("644", 8).toShort) + // Additional memory overhead - in mb. + protected def memoryOverhead: Int = sparkConf.getInt("spark.yarn.driver.memoryOverhead", + YarnAllocationHandler.MEMORY_OVERHEAD) + // TODO(harvey): This could just go in ClientArguments. 
def validateArgs() = { Map( @@ -72,10 +76,10 @@ trait ClientBase extends Logging { "Error: You must specify a user jar when running in standalone mode!"), (args.userClass == null) -> "Error: You must specify a user class!", (args.numExecutors <= 0) -> "Error: You must specify at least 1 executor!", - (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: AM memory size must be" + - "greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD), - (args.executorMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Executor memory size" + - "must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString) + (args.amMemory <= memoryOverhead) -> ("Error: AM memory size must be" + + "greater than: " + memoryOverhead), + (args.executorMemory <= memoryOverhead) -> ("Error: Executor memory size" + + "must be greater than: " + memoryOverhead.toString) ).foreach { case(cond, errStr) => if (cond) { logError(errStr) @@ -101,7 +105,7 @@ trait ClientBase extends Logging { logError(errorMessage) throw new IllegalArgumentException(errorMessage) } - val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val amMem = args.amMemory + memoryOverhead if (amMem > maxMem) { val errorMessage = "Required AM memory (%d) is above the max threshold (%d) of this cluster." diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 80a8bceb17269..15f3c4f180ea3 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -84,7 +84,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa // Memory for the ApplicationMaster. val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] - memoryResource.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + memoryResource.setMemory(args.amMemory + memoryOverhead) appContext.setResource(memoryResource) // Finally, submit and monitor the application. @@ -117,7 +117,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa // val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() // var amMemory = ((args.amMemory / minResMemory) * minResMemory) + // ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - - // YarnAllocationHandler.MEMORY_OVERHEAD) + // memoryOverhead ) args.amMemory } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index a979fe4d62630..29ccec2adcac3 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -90,6 +90,10 @@ private[yarn] class YarnAllocationHandler( // Containers to be released in next request to RM private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] + // Additional memory overhead - in mb. + private def memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnAllocationHandler.MEMORY_OVERHEAD) + // Number of container requests that have been sent to, but not yet allocated by the // ApplicationMaster. 
private val numPendingAllocate = new AtomicInteger() @@ -106,7 +110,7 @@ private[yarn] class YarnAllocationHandler( def getNumExecutorsFailed: Int = numExecutorsFailed.intValue def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + container.getResource.getMemory >= (executorMemory + memoryOverhead) } def releaseContainer(container: Container) { @@ -248,7 +252,7 @@ private[yarn] class YarnAllocationHandler( val executorHostname = container.getNodeId.getHost val containerId = container.getId - val executorMemoryOverhead = (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + val executorMemoryOverhead = (executorMemory + memoryOverhead) assert(container.getResource.getMemory >= executorMemoryOverhead) if (numExecutorsRunningNow > maxExecutors) { @@ -477,7 +481,7 @@ private[yarn] class YarnAllocationHandler( numPendingAllocate.addAndGet(numExecutors) logInfo("Will Allocate %d executor containers, each with %d memory".format( numExecutors, - (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD))) + (executorMemory + memoryOverhead))) } else { logDebug("Empty allocation request ...") } @@ -537,7 +541,7 @@ private[yarn] class YarnAllocationHandler( priority: Int ): ArrayBuffer[ContainerRequest] = { - val memoryRequest = executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val memoryRequest = executorMemory + memoryOverhead val resource = Resource.newInstance(memoryRequest, executorCores) val prioritySetting = Records.newRecord(classOf[Priority]) From 273afcb254fb5384204c56bdcb3b9b760bcfab3f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 16 Jun 2014 21:30:29 +0200 Subject: [PATCH 14/57] [SQL][SPARK-2094] Follow up of PR #1071 for Java API Updated `JavaSQLContext` and `JavaHiveContext` similar to what we've done to `SQLContext` and `HiveContext` in PR #1071. Added corresponding test case for Spark SQL Java API. 
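For illustration only, a hedged usage sketch of the updated Java API (not taken from this patch): it assumes an existing `SparkContext` named `sc` and a hypothetical registered table named `records`.

```
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.api.java.JavaSQLContext

// Wrap an existing SparkContext `sc` for the Java-friendly API.
val javaCtx = new JavaSparkContext(sc)
val javaSqlCtx = new JavaSQLContext(javaCtx)

// sql() now just wraps the parsed logical plan in a JavaSchemaRDD;
// execution is deferred until an action such as collect() is called.
val query = javaSqlCtx.sql("SELECT key, value FROM records")
val rows = query.collect()   // execution happens here
println(rows.size())
```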
Author: Cheng Lian Closes #1085 from liancheng/spark-2094-java and squashes the following commits: 29b8a51 [Cheng Lian] Avoided instantiating JavaSparkContext & JavaHiveContext to workaround test failure 92bb4fb [Cheng Lian] Marked test cases in JavaHiveQLSuite with "ignore" 22aec97 [Cheng Lian] Follow up of PR #1071 for Java API --- .../spark/sql/api/java/JavaSQLContext.scala | 16 +-- .../sql/hive/api/java/JavaHiveContext.scala | 10 +- .../sql/hive/api/java/JavaHiveQLSuite.scala | 101 ++++++++++++++++++ .../sql/hive/api/java/JavaHiveSuite.scala | 41 ------- .../sql/hive/execution/HiveQuerySuite.scala | 30 +++--- 5 files changed, 124 insertions(+), 74 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 6f7d431b9a819..352260fa15bbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -40,19 +40,13 @@ class JavaSQLContext(val sqlContext: SQLContext) { /** * Executes a query expressed in SQL, returning the result as a JavaSchemaRDD */ - def sql(sqlQuery: String): JavaSchemaRDD = { - val result = new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlQuery)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but do not perform any execution. - result.queryExecution.toRdd - result - } + def sql(sqlQuery: String): JavaSchemaRDD = + new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlQuery)) /** * :: Experimental :: * Creates an empty parquet file with the schema of class `beanClass`, which can be registered as - * a table. This registered table can be used as the target of future insertInto` operations. + * a table. This registered table can be used as the target of future `insertInto` operations. * * {{{ * JavaSQLContext sqlCtx = new JavaSQLContext(...) @@ -62,7 +56,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { * }}} * * @param beanClass A java bean class object that will be used to determine the schema of the - * parquet file. s + * parquet file. * @param path The path where the directory containing parquet metadata should be created. * Data inserted into this table will also be stored at this location. * @param allowExisting When false, an exception will be thrown if this directory already exists. @@ -100,14 +94,12 @@ class JavaSQLContext(val sqlContext: SQLContext) { new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) } - /** * Loads a parquet file, returning the result as a [[JavaSchemaRDD]]. */ def parquetFile(path: String): JavaSchemaRDD = new JavaSchemaRDD(sqlContext, ParquetRelation(path)) - /** * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only * during the lifetime of this instance of SQLContext. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala index 6df76fa825101..c9ee162191c96 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala @@ -31,12 +31,6 @@ class JavaHiveContext(sparkContext: JavaSparkContext) extends JavaSQLContext(spa /** * Executes a query expressed in HiveQL, returning the result as a JavaSchemaRDD. */ - def hql(hqlQuery: String): JavaSchemaRDD = { - val result = new JavaSchemaRDD(sqlContext, HiveQl.parseSql(hqlQuery)) - // We force query optimization to happen right away instead of letting it happen lazily like - // when using the query DSL. This is so DDL commands behave as expected. This is only - // generates the RDD lineage for DML queries, but do not perform any execution. - result.queryExecution.toRdd - result - } + def hql(hqlQuery: String): JavaSchemaRDD = + new JavaSchemaRDD(sqlContext, HiveQl.parseSql(hqlQuery)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala new file mode 100644 index 0000000000000..3b9cd8f52de4e --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.api.java + +import scala.util.Try + +import org.scalatest.FunSuite + +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.execution.ExplainCommand +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.TestSQLContext + +// Implicits +import scala.collection.JavaConversions._ + +class JavaHiveQLSuite extends FunSuite { + lazy val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) + + // There is a little trickery here to avoid instantiating two HiveContexts in the same JVM + lazy val javaHiveCtx = new JavaHiveContext(javaCtx) { + override val sqlContext = TestHive + } + + ignore("SELECT * FROM src") { + assert( + javaHiveCtx.hql("SELECT * FROM src").collect().map(_.getInt(0)) === + TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq) + } + + private val explainCommandClassName = + classOf[ExplainCommand].getSimpleName.stripSuffix("$") + + def isExplanation(result: JavaSchemaRDD) = { + val explanation = result.collect().map(_.getString(0)) + explanation.size == 1 && explanation.head.startsWith(explainCommandClassName) + } + + ignore("Query Hive native command execution result") { + val tableName = "test_native_commands" + + assertResult(0) { + javaHiveCtx.hql(s"DROP TABLE IF EXISTS $tableName").count() + } + + assertResult(0) { + javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + } + + javaHiveCtx.hql("SHOW TABLES").registerAsTable("show_tables") + + assert( + javaHiveCtx + .hql("SELECT result FROM show_tables") + .collect() + .map(_.getString(0)) + .contains(tableName)) + + assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { + javaHiveCtx.hql(s"DESCRIBE $tableName").registerAsTable("describe_table") + + javaHiveCtx + .hql("SELECT result FROM describe_table") + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + .toArray + } + + assert(isExplanation(javaHiveCtx.hql( + s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) + + TestHive.reset() + } + + ignore("Exactly once semantics for DDL and command statements") { + val tableName = "test_exactly_once" + val q0 = javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)") + + // If the table was not created, the following assertion would fail + assert(Try(TestHive.table(tableName)).isSuccess) + + // If the CREATE TABLE command got executed again, the following assertion would fail + assert(Try(q0.count()).isSuccess) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala deleted file mode 100644 index 9c5d7c81f7c09..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.api.java - -import org.scalatest.FunSuite - -import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.hive.test.TestHive - -// Implicits -import scala.collection.JavaConversions._ - -class JavaHiveSQLSuite extends FunSuite { - ignore("SELECT * FROM src") { - val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) - // There is a little trickery here to avoid instantiating two HiveContexts in the same JVM - val javaSqlCtx = new JavaHiveContext(javaCtx) { - override val sqlContext = TestHive - } - - assert( - javaSqlCtx.hql("SELECT * FROM src").collect().map(_.getInt(0)) === - TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 0d656c556965d..6e8d11b8a1300 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -184,25 +184,29 @@ class HiveQuerySuite extends HiveComparisonTest { test("Query Hive native command execution result") { val tableName = "test_native_commands" - val q0 = hql(s"DROP TABLE IF EXISTS $tableName") - assert(q0.count() == 0) + assertResult(0) { + hql(s"DROP TABLE IF EXISTS $tableName").count() + } - val q1 = hql(s"CREATE TABLE $tableName(key INT, value STRING)") - assert(q1.count() == 0) + assertResult(0) { + hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + } - val q2 = hql("SHOW TABLES") - val tables = q2.select('result).collect().map { case Row(table: String) => table } - assert(tables.contains(tableName)) + assert( + hql("SHOW TABLES") + .select('result) + .collect() + .map(_.getString(0)) + .contains(tableName)) - val q3 = hql(s"DESCRIBE $tableName") assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - q3.select('result).collect().map { case Row(fieldDesc: String) => - fieldDesc.split("\t").map(_.trim) - } + hql(s"DESCRIBE $tableName") + .select('result) + .collect() + .map(_.getString(0).split("\t").map(_.trim)) } - val q4 = hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key") - assert(isExplanation(q4)) + assert(isExplanation(hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) TestHive.reset() } From 237b96bc59ab1b54c31d06a5260cd77e1eb96116 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 16 Jun 2014 16:42:17 -0700 Subject: [PATCH 15/57] Minor fix: made "EXPLAIN" output to play well with JDBC output format Fixed the broken JDBC output. 
Test from Shark `beeline`: ``` beeline> !connect jdbc:hive2://localhost:10000/ scan complete in 2ms Connecting to jdbc:hive2://localhost:10000/ Enter username for jdbc:hive2://localhost:10000/: lian Enter password for jdbc:hive2://localhost:10000/: Connected to: Hive (version 0.12.0) Driver: Hive (version 0.12.0) Transaction isolation: TRANSACTION_REPEATABLE_READ 0: jdbc:hive2://localhost:10000/> 0: jdbc:hive2://localhost:10000/> explain select * from src; +-------------------------------------------------------------------------------+ | plan | +-------------------------------------------------------------------------------+ | ExplainCommand [plan#2:0] | | HiveTableScan [key#0,value#1], (MetastoreRelation default, src, None), None | +-------------------------------------------------------------------------------+ 2 rows selected (1.386 seconds) ``` Before this change, the output looked something like this: ``` +-------------------------------------------------------------------------------+ | plan | +-------------------------------------------------------------------------------+ | ExplainCommand [plan#2:0] HiveTableScan [key#0,value#1], (MetastoreRelation default, src, None), None | +-------------------------------------------------------------------------------+ ``` Author: Cheng Lian Closes #1097 from liancheng/multiLineExplain and squashes the following commits: eb37967 [Cheng Lian] Made output of "EXPLAIN" play well with JDBC output format --- .../main/scala/org/apache/spark/sql/execution/commands.scala | 4 ++-- .../org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala | 2 +- .../org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 0377290af5926..39b3246c875df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -83,8 +83,8 @@ case class ExplainCommand( override protected[sql] lazy val sideEffectResult: Seq[String] = this.toString.split("\n") def execute(): RDD[Row] = { - val explanation = sideEffectResult.mkString("\n") - context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](explanation))), 1) + val explanation = sideEffectResult.map(row => new GenericRow(Array[Any](row))) + context.sparkContext.parallelize(explanation, 1) } override def otherCopyArgs = context :: Nil diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala index 3b9cd8f52de4e..10c8069a624e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -49,7 +49,7 @@ class JavaHiveQLSuite extends FunSuite { def isExplanation(result: JavaSchemaRDD) = { val explanation = result.collect().map(_.getString(0)) - explanation.size == 1 && explanation.head.startsWith(explainCommandClassName) + explanation.size > 1 && explanation.head.startsWith(explainCommandClassName) } ignore("Query Hive native command execution result") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6e8d11b8a1300..04652587f9073 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -169,7 +169,7 @@ class HiveQuerySuite extends HiveComparisonTest { def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.size == 1 && explanation.head.startsWith(explainCommandClassName) + explanation.size > 1 && explanation.head.startsWith(explainCommandClassName) } test("SPARK-1704: Explain commands as a SchemaRDD") { From 7afa912e747c77ebfd10bddf7bda2e3190fdeb9c Mon Sep 17 00:00:00 2001 From: Anatoli Fomenko Date: Mon, 16 Jun 2014 23:10:36 -0700 Subject: [PATCH 16/57] MLlib documentation fix Synchronized mllib-optimization.md with Spark Scaladoc: removed reference to GradientDescent.runMiniBatchSGD method This is a temporary fix to remove a link from http://spark.apache.org/docs/latest/mllib-optimization.html to GradientDescent.runMiniBatchSGD which is not in the current online GradientDescent Scaladoc. FIXME: revert this commit after GradientDescent Scaladoc is updated. See images for details. ![mllib-docs-fix-1](https://cloud.githubusercontent.com/assets/1375501/3294410/ccf19bb8-f5a8-11e3-93f1-f593016209eb.png) ![mllib-docs-fix-2](https://cloud.githubusercontent.com/assets/1375501/3294411/d0b59a7e-f5a8-11e3-8fc8-329c177ef8c8.png) Author: Anatoli Fomenko Closes #1098 from afomenko/master and squashes the following commits: 5cb0758 [Anatoli Fomenko] MLlib documentation fix --- docs/mllib-optimization.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 97e8f4e9661b6..ae9ede58e8e60 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -147,9 +147,9 @@ are developed, see the linear methods section for example. -The SGD method -[GradientDescent.runMiniBatchSGD](api/scala/index.html#org.apache.spark.mllib.optimization.GradientDescent) -has the following parameters: +The SGD class +[GradientDescent](api/scala/index.html#org.apache.spark.mllib.optimization.GradientDescent) +sets the following parameters: * `Gradient` is a class that computes the stochastic gradient of the function being optimized, i.e., with respect to a single training example, at the @@ -171,7 +171,7 @@ each iteration, to compute the gradient direction. Available algorithms for gradient descent: -* [GradientDescent.runMiniBatchSGD](api/scala/index.html#org.apache.spark.mllib.optimization.GradientDescent) +* [GradientDescent](api/scala/index.html#org.apache.spark.mllib.optimization.GradientDescent) ### L-BFGS L-BFGS is currently only a low-level optimization primitive in `MLlib`. If you want to use L-BFGS in various From d81c08bac9756045865ed6490252fbb3f7591142 Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Mon, 16 Jun 2014 23:31:31 -0700 Subject: [PATCH 17/57] [SPARK-2130] End-user friendly String repr for StorageLevel in Python JIRA issue https://issues.apache.org/jira/browse/SPARK-2130 This PR adds an end-user friendly String representation for StorageLevel in Python, similar to ```StorageLevel.description``` in Scala. 
``` >>> rdd = sc.parallelize([1,2]) >>> storage_level = rdd.getStorageLevel() >>> storage_level StorageLevel(False, False, False, False, 1) >>> print(storage_level) Serialized 1x Replicated ``` Author: Kan Zhang Closes #1096 from kanzhang/SPARK-2130 and squashes the following commits: 7c8b98b [Kan Zhang] [SPARK-2130] Prettier epydoc output cc5bf45 [Kan Zhang] [SPARK-2130] End-user friendly String representation for StorageLevel in Python --- python/pyspark/rdd.py | 3 +++ python/pyspark/storagelevel.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ddd22850a819c..bb4d035edcdeb 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1448,9 +1448,12 @@ def toDebugString(self): def getStorageLevel(self): """ Get the RDD's current storage level. + >>> rdd1 = sc.parallelize([1,2]) >>> rdd1.getStorageLevel() StorageLevel(False, False, False, False, 1) + >>> print(rdd1.getStorageLevel()) + Serialized 1x Replicated """ java_storage_level = self._jrdd.getStorageLevel() storage_level = StorageLevel(java_storage_level.useDisk(), diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 7b6660eab231b..3a18ea54eae4c 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -36,6 +36,15 @@ def __repr__(self): return "StorageLevel(%s, %s, %s, %s, %s)" % ( self.useDisk, self.useMemory, self.useOffHeap, self.deserialized, self.replication) + def __str__(self): + result = "" + result += "Disk " if self.useDisk else "" + result += "Memory " if self.useMemory else "" + result += "Tachyon " if self.useOffHeap else "" + result += "Deserialized " if self.deserialized else "Serialized " + result += "%sx Replicated" % self.replication + return result + StorageLevel.DISK_ONLY = StorageLevel(True, False, False, False) StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, False, 2) StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, True) From 8cd04c3eecc2dd827ea163dcd5e08af9912fa323 Mon Sep 17 00:00:00 2001 From: Anant Date: Mon, 16 Jun 2014 23:42:27 -0700 Subject: [PATCH 18/57] SPARK-1990: added compatibility for python 2.6 for ssh_read command https://issues.apache.org/jira/browse/SPARK-1990 There were some posts on the lists that spark-ec2 does not work with Python 2.6. 
In addition, we should check the Python version at the top of the script and exit if it's too old Author: Anant Closes #941 from anantasty/SPARK-1990 and squashes the following commits: 4ca441d [Anant] Implmented check_optput withinthe module to work with python 2.6 c6ed85c [Anant] added compatibility for python 2.6 for ssh_read command --- ec2/spark_ec2.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 52a89cb2481ca..803caa0c480e7 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -689,9 +689,23 @@ def ssh(host, opts, command): time.sleep(30) tries = tries + 1 +# Backported from Python 2.7 for compatiblity with 2.6 (See SPARK-1990) +def _check_output(*popenargs, **kwargs): + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = subprocess.Popen(stdout=PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output + def ssh_read(host, opts, command): - return subprocess.check_output( + return _check_output( ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) From 23a12ce20c55653b08b16e6159ab31d2ca88acf1 Mon Sep 17 00:00:00 2001 From: Daniel Darabos Date: Tue, 17 Jun 2014 00:08:05 -0700 Subject: [PATCH 19/57] SPARK-2035: Store call stack for stages, display it on the UI. I'm not sure about the test -- I get a lot of unrelated failures for some reason. I'll try to sort it out. But hopefully the automation will test this for me if I send a pull request :). I'll attach a demo HTML in [Jira](https://issues.apache.org/jira/browse/SPARK-2035). Author: Daniel Darabos Author: Patrick Wendell Closes #981 from darabos/darabos-call-stack and squashes the following commits: f7c6bfa [Daniel Darabos] Fix bad merge. I undid 83c226d454 by Doris. 3d0a48d [Daniel Darabos] Merge remote-tracking branch 'upstream/master' into darabos-call-stack b857849 [Daniel Darabos] Style: Break long line. ecb5690 [Daniel Darabos] Include the last Spark method in the full stack trace. Otherwise it is not visible if the stage name is overridden. d00a85b [Patrick Wendell] Make call sites for stages non-optional and well defined b9eba24 [Daniel Darabos] Make StageInfo.details non-optional. Add JSON serialization code for the new field. Verify JSON backward compatibility. 4312828 [Daniel Darabos] Remove Mima excludes for CallSite. They should be unnecessary now, with SPARK-2070 fixed. 0920750 [Daniel Darabos] Merge remote-tracking branch 'upstream/master' into darabos-call-stack a4b1faf [Daniel Darabos] Add Mima exclusions for the CallSite changes it has picked up. They are private methods/classes, so we ought to be safe. 932f810 [Daniel Darabos] Use empty CallSite instead of null in DAGSchedulerSuite. Outside of testing, this parameter always originates in SparkContext.scala, and will never be null. ccd89d1 [Daniel Darabos] Fix long lines. ac173e4 [Daniel Darabos] Hide "show details" if there are no details to show. 6182da6 [Daniel Darabos] Set a configurable limit on maximum call stack depth. It can be useful in memory-constrained situations with large numbers of stages. 8fe2e34 [Daniel Darabos] Store call stack for stages, display it on the UI. 
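For illustration only, a hedged sketch (not taken from this patch) of how the new `StageInfo.details` field and the `spark.callstack.depth` limit introduced below could be consumed from user code; the listener class name is a placeholder.

```
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

// Hypothetical listener that prints the user call stack now carried in StageInfo.details.
class StageDetailsLogger extends SparkListener {
  override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
    val info = stageSubmitted.stageInfo
    println(s"Stage ${info.stageId} (${info.name}) submitted from:\n${info.details}")
  }
}

// Optionally cap the stored stack depth (the default of 20 comes from Utils.getCallSite below),
// then register the listener on an existing SparkContext `sc`.
System.setProperty("spark.callstack.depth", "10")
def register(sc: SparkContext): Unit = sc.addSparkListener(new StageDetailsLogger)
```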
--- .../org/apache/spark/ui/static/webui.css | 21 ++++++++++++ .../scala/org/apache/spark/SparkContext.scala | 18 ++++++----- .../main/scala/org/apache/spark/rdd/RDD.scala | 6 ++-- .../apache/spark/scheduler/ActiveJob.scala | 3 +- .../apache/spark/scheduler/DAGScheduler.scala | 24 +++++++------- .../spark/scheduler/DAGSchedulerEvent.scala | 3 +- .../org/apache/spark/scheduler/Stage.scala | 11 +++++-- .../apache/spark/scheduler/StageInfo.scala | 9 ++++-- .../org/apache/spark/ui/jobs/StageTable.scala | 10 +++++- .../org/apache/spark/util/JsonProtocol.scala | 4 ++- .../scala/org/apache/spark/util/Utils.scala | 32 +++++++++---------- .../apache/spark/SparkContextInfoSuite.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 7 ++-- .../ui/jobs/JobProgressListenerSuite.scala | 4 +-- .../apache/spark/util/JsonProtocolSuite.scala | 14 +++++++- 15 files changed, 115 insertions(+), 53 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 599c3ac9b57c0..a8bc141208a94 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -87,3 +87,24 @@ span.kill-link { span.kill-link a { color: gray; } + +span.expand-details { + font-size: 10pt; + cursor: pointer; + color: grey; + float: right; +} + +.stage-details { + max-height: 100px; + overflow-y: auto; + margin: 0; + transition: max-height 0.5s ease-out, padding 0.5s ease-out; +} + +.stage-details.collapsed { + max-height: 0; + padding-top: 0; + padding-bottom: 0; + border: none; +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 35970c2f50892..0678bdd02110e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -49,7 +49,7 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} +import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -1036,9 +1036,11 @@ class SparkContext(config: SparkConf) extends Logging { * Capture the current user callsite and return a formatted version for printing. If the user * has overridden the call site, this will return the user's version. 
*/ - private[spark] def getCallSite(): String = { - val defaultCallSite = Utils.getCallSiteInfo - Option(getLocalProperty("externalCallSite")).getOrElse(defaultCallSite.toString) + private[spark] def getCallSite(): CallSite = { + Option(getLocalProperty("externalCallSite")) match { + case Some(callSite) => CallSite(callSite, long = "") + case None => Utils.getCallSite + } } /** @@ -1058,11 +1060,11 @@ class SparkContext(config: SparkConf) extends Logging { } val callSite = getCallSite val cleanedFunc = clean(func) - logInfo("Starting job: " + callSite) + logInfo("Starting job: " + callSite.short) val start = System.nanoTime dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) - logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + logInfo("Job finished: " + callSite.short + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() } @@ -1143,11 +1145,11 @@ class SparkContext(config: SparkConf) extends Logging { evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { val callSite = getCallSite - logInfo("Starting job: " + callSite) + logInfo("Starting job: " + callSite.short) val start = System.nanoTime val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.get) - logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + logInfo("Job finished: " + callSite.short + ", took " + (System.nanoTime - start) / 1e9 + " s") result } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 446f369c9ea16..27cc60d775788 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -40,7 +40,7 @@ import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1189,8 +1189,8 @@ abstract class RDD[T: ClassTag]( private var storageLevel: StorageLevel = StorageLevel.NONE /** User code that created this RDD (e.g. `textFile`, `parallelize`). */ - @transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo - private[spark] def getCreationSite: String = Option(creationSiteInfo).getOrElse("").toString + @transient private[spark] val creationSite = Utils.getCallSite + private[spark] def getCreationSite: String = Option(creationSite).map(_.short).getOrElse("") private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 9257f48559c9e..b755d8fb15757 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.util.Properties import org.apache.spark.TaskContext +import org.apache.spark.util.CallSite /** * Tracks information about an active job in the DAGScheduler. 
@@ -29,7 +30,7 @@ private[spark] class ActiveJob( val finalStage: Stage, val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], - val callSite: String, + val callSite: CallSite, val listener: JobListener, val properties: Properties) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3c85b5a2ae776..b3ebaa547de0d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} -import org.apache.spark.util.{SystemClock, Clock, Utils} +import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -195,7 +195,9 @@ class DAGScheduler( case Some(stage) => stage case None => val stage = - newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) + newOrUsedStage( + shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, + shuffleDep.rdd.creationSite) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -212,7 +214,7 @@ class DAGScheduler( numTasks: Int, shuffleDep: Option[ShuffleDependency[_, _, _]], jobId: Int, - callSite: Option[String] = None) + callSite: CallSite) : Stage = { val id = nextStageId.getAndIncrement() @@ -235,7 +237,7 @@ class DAGScheduler( numTasks: Int, shuffleDep: ShuffleDependency[_, _, _], jobId: Int, - callSite: Option[String] = None) + callSite: CallSite) : Stage = { val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) @@ -413,7 +415,7 @@ class DAGScheduler( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - callSite: String, + callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties = null): JobWaiter[U] = @@ -443,7 +445,7 @@ class DAGScheduler( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - callSite: String, + callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties = null) @@ -452,7 +454,7 @@ class DAGScheduler( waiter.awaitResult() match { case JobSucceeded => {} case JobFailed(exception: Exception) => - logInfo("Failed to run " + callSite) + logInfo("Failed to run " + callSite.short) throw exception } } @@ -461,7 +463,7 @@ class DAGScheduler( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], - callSite: String, + callSite: CallSite, timeout: Long, properties: Properties = null) : PartialResult[R] = @@ -666,7 +668,7 @@ class DAGScheduler( func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], allowLocal: Boolean, - callSite: String, + callSite: CallSite, listener: JobListener, properties: Properties = null) { @@ -674,7 +676,7 @@ class DAGScheduler( try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. 
- finalStage = newStage(finalRDD, partitions.size, None, jobId, Some(callSite)) + finalStage = newStage(finalRDD, partitions.size, None, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) @@ -685,7 +687,7 @@ class DAGScheduler( val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format( - job.jobId, callSite, partitions.length, allowLocal)) + job.jobId, callSite.short, partitions.length, allowLocal)) logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 23f57441b4b11..2b6f7e4205c32 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -25,6 +25,7 @@ import scala.language.existentials import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD +import org.apache.spark.util.CallSite /** * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue @@ -40,7 +41,7 @@ private[scheduler] case class JobSubmitted( func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], allowLocal: Boolean, - callSite: String, + callSite: CallSite, listener: JobListener, properties: Properties = null) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 3bf9713f728c6..9a4be43ee219f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.CallSite /** * A stage is a set of independent tasks all computing the same function that need to run as part @@ -35,6 +36,11 @@ import org.apache.spark.storage.BlockManagerId * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. + * + * The callSite provides a location in user code which relates to the stage. For a shuffle map + * stage, the callSite gives the user code that created the RDD being shuffled. For a result + * stage, the callSite gives the user code that executes the associated action (e.g. count()). 
+ * */ private[spark] class Stage( val id: Int, @@ -43,7 +49,7 @@ private[spark] class Stage( val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, - callSite: Option[String]) + val callSite: CallSite) extends Logging { val isShuffleMap = shuffleDep.isDefined @@ -100,7 +106,8 @@ private[spark] class Stage( id } - val name = callSite.getOrElse(rdd.getCreationSite) + val name = callSite.short + val details = callSite.long override def toString = "Stage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index b42e231e11f91..7644e3f351b3c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -25,7 +25,12 @@ import org.apache.spark.storage.RDDInfo * Stores information about a stage to pass from the scheduler to SparkListeners. */ @DeveloperApi -class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo]) { +class StageInfo( + val stageId: Int, + val name: String, + val numTasks: Int, + val rddInfos: Seq[RDDInfo], + val details: String) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None /** Time when all tasks in the stage completed or when the stage was cancelled. */ @@ -52,6 +57,6 @@ private[spark] object StageInfo { def fromStage(stage: Stage): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos - new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos) + new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos, stage.details) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 153434a2032be..a3f824a4e1f57 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -91,9 +91,17 @@ private[ui] class StageTableBase( {s.name} + val details = if (s.details.nonEmpty) ( + + +show details + + + ) + listener.stageIdToDescription.get(s.stageId) .map(d =>
<div><em>{d}</em></div><div>{nameLink} {killLink}</div>)
-      .getOrElse(<div>{killLink}{nameLink}</div>)
+      .getOrElse(<div>{killLink} {nameLink} {details}</div>
) } protected def stageRow(s: StageInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 09825087bb048..7cecbfe62a382 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -184,6 +184,7 @@ private[spark] object JsonProtocol { ("Stage Name" -> stageInfo.name) ~ ("Number of Tasks" -> stageInfo.numTasks) ~ ("RDD Info" -> rddInfo) ~ + ("Details" -> stageInfo.details) ~ ("Submission Time" -> submissionTime) ~ ("Completion Time" -> completionTime) ~ ("Failure Reason" -> failureReason) ~ @@ -469,12 +470,13 @@ private[spark] object JsonProtocol { val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson) + val details = (json \ "Details").extractOpt[String].getOrElse("") val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) val emittedTaskSizeWarning = (json \ "Emitted Task Size Warning").extract[Boolean] - val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos) + val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos, details) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime stageInfo.failureReason = failureReason diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4ce28bb0cf059..a2454e120a8ab 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -43,6 +43,9 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.ExecutorUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +/** CallSite represents a place in user code. It can have a short and a long form. */ +private[spark] case class CallSite(val short: String, val long: String) + /** * Various utility methods used by Spark. */ @@ -799,21 +802,12 @@ private[spark] object Utils extends Logging { */ private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r - private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, - val firstUserLine: Int, val firstUserClass: String) { - - /** Returns a printable version of the call site info suitable for logs. */ - override def toString = { - "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) - } - } - /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. 
*/ - def getCallSiteInfo: CallSiteInfo = { + def getCallSite: CallSite = { val trace = Thread.currentThread.getStackTrace() .filterNot(_.getMethodName.contains("getStackTrace")) @@ -824,11 +818,11 @@ private[spark] object Utils extends Logging { var lastSparkMethod = "" var firstUserFile = "" var firstUserLine = 0 - var finished = false - var firstUserClass = "" + var insideSpark = true + var callStack = new ArrayBuffer[String]() :+ "" for (el <- trace) { - if (!finished) { + if (insideSpark) { if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName).isDefined) { lastSparkMethod = if (el.getMethodName == "") { // Spark method is a constructor; get its class name @@ -836,15 +830,21 @@ private[spark] object Utils extends Logging { } else { el.getMethodName } + callStack(0) = el.toString // Put last Spark method on top of the stack trace. } else { firstUserLine = el.getLineNumber firstUserFile = el.getFileName - firstUserClass = el.getClassName - finished = true + callStack += el.toString + insideSpark = false } + } else { + callStack += el.toString } } - new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) + val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt + CallSite( + short = "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine), + long = callStack.take(callStackDepth).mkString("\n")) } /** Return a string containing part of a file from byte 'start' to 'end'. */ diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index cd3887dcc7371..1fde4badda949 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -70,7 +70,7 @@ package object testPackage extends Assertions { def runCallSiteTest(sc: SparkContext) { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) val rddCreationSite = rdd.getCreationSite - val curCallSite = sc.getCallSite() // note: 2 lines after definition of "rdd" + val curCallSite = sc.getCallSite().short // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { case CALL_SITE_REGEX(func, file, line) => { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 7506d56d7e26d..45368328297d3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} +import org.apache.spark.util.CallSite class BuggyDAGEventProcessActor extends Actor { val state = 0 @@ -211,7 +212,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F allowLocal: Boolean = false, listener: JobListener = jobListener): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, CallSite("", ""), listener)) jobId } @@ -251,7 +252,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, 
jobComputeFunc, Array(0), true, null, jobListener)) + runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results === Map(0 -> 42)) assertDataStructuresEmpty } @@ -265,7 +266,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, jobListener)) + runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results.size == 0) assertDataStructuresEmpty } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 91b4c7b0dd962..c3a14f48de38e 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -32,12 +32,12 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val listener = new JobProgressListener(conf) def createStageStartEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, stageId.toString, 0, null) + val stageInfo = new StageInfo(stageId, stageId.toString, 0, null, "") SparkListenerStageSubmitted(stageInfo) } def createStageEndEvent(stageId: Int) = { - val stageInfo = new StageInfo(stageId, stageId.toString, 0, null) + val stageInfo = new StageInfo(stageId, stageId.toString, 0, null, "") SparkListenerStageCompleted(stageInfo) } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 3031015256ec9..f72389b6b323f 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -117,6 +117,17 @@ class JsonProtocolSuite extends FunSuite { testBlockId(StreamBlockId(1, 2L)) } + test("Backward compatibility") { + // StageInfo.details was added after 1.0.0. + val info = makeStageInfo(1, 2, 3, 4L, 5L) + assert(info.details.nonEmpty) + val newJson = JsonProtocol.stageInfoToJson(info) + val oldJson = newJson.removeField { case (field, _) => field == "Details" } + val newInfo = JsonProtocol.stageInfoFromJson(oldJson) + assert(info.name === newInfo.name) + assert("" === newInfo.details) + } + /** -------------------------- * | Helper test running methods | @@ -235,6 +246,7 @@ class JsonProtocolSuite extends FunSuite { (0 until info1.rddInfos.size).foreach { i => assertEquals(info1.rddInfos(i), info2.rddInfos(i)) } + assert(info1.details === info2.details) } private def assertEquals(info1: RDDInfo, info2: RDDInfo) { @@ -438,7 +450,7 @@ class JsonProtocolSuite extends FunSuite { private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { val rddInfos = (1 to a % 5).map { i => makeRddInfo(a % i, b % i, c % i, d % i, e % i) } - new StageInfo(a, "greetings", b, rddInfos) + new StageInfo(a, "greetings", b, rddInfos, "details") } private def makeTaskInfo(a: Long, b: Int, c: Long) = { From 09deb3eee090eb8ec1d9a0cd90825699748e3ffc Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 17 Jun 2014 01:28:22 -0700 Subject: [PATCH 20/57] [SPARK-2144] ExecutorsPage reports incorrect # of RDD blocks This is reproducible whenever we drop a block because of memory pressure. 
This is because StorageStatusListener actually never removes anything from the block maps of its StorageStatuses. Instead, when a block is dropped, it sets the block's storage level to `StorageLevel.NONE`, when it should just remove it from the map. This PR includes this simple fix. Author: Andrew Or Closes #1080 from andrewor14/ui-blocks and squashes the following commits: fcf9f1a [Andrew Or] Remove BlockStatus if it is no longer cached --- .../org/apache/spark/storage/StorageStatusListener.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index a6e6627d54e01..c694fc8c347ec 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -37,7 +37,11 @@ class StorageStatusListener extends SparkListener { val filteredStatus = storageStatusList.find(_.blockManagerId.executorId == execId) filteredStatus.foreach { storageStatus => updatedBlocks.foreach { case (blockId, updatedStatus) => - storageStatus.blocks(blockId) = updatedStatus + if (updatedStatus.storageLevel == StorageLevel.NONE) { + storageStatus.blocks.remove(blockId) + } else { + storageStatus.blocks(blockId) = updatedStatus + } } } } From f5a4049e534da3c55e1b495ce34155236dfb6dee Mon Sep 17 00:00:00 2001 From: Xi Liu Date: Tue, 17 Jun 2014 13:14:40 +0200 Subject: [PATCH 21/57] [SPARK-2164][SQL] Allow Hive UDF on columns of type struct Author: Xi Liu Closes #796 from xiliu82/sqlbug and squashes the following commits: 328dfc4 [Xi Liu] [Spark SQL] remove a temporary function after test 354386a [Xi Liu] [Spark SQL] add test suite for UDF on struct 8fc6f51 [Xi Liu] [SparkSQL] allow UDF on struct --- .../org/apache/spark/sql/hive/hiveUdfs.scala | 3 + .../resources/data/files/testUdf/part-00000 | Bin 0 -> 153 bytes .../sql/hive/execution/HiveUdfSuite.scala | 127 ++++++++++++++++++ 3 files changed, 130 insertions(+) create mode 100755 sql/hive/src/test/resources/data/files/testUdf/part-00000 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 771d2bccf43a7..ad5e24c62c621 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -335,6 +335,9 @@ private[hive] trait HiveInspectors { case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector + case StructType(fields) => + ObjectInspectorFactory.getStandardStructObjectInspector( + fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) } def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUdf/part-00000 new file mode 100755 index 0000000000000000000000000000000000000000..240a5c1a63c5c4016d096cbd13ddc8b787aee8da GIT binary patch literal 153 zcmWG`4P;ZyFG|--EJ#ewNY%?oOv%qL(96u%^DE8C2`|blNleN~)j?8GT##6ltyf%_ zqnD9cma3Opk(yjul9`{U7m`|B5|Ef#!~h0IwrxAkdN!tuT_U@F&YjICfr1 + |) + |PARTITIONED BY (partition 
STRING) + |ROW FORMAT SERDE '%s' + |STORED AS SEQUENCEFILE + """.stripMargin.format(classOf[PairSerDe].getName) + ) + + TestHive.hql( + "ALTER TABLE hiveUdfTestTable ADD IF NOT EXISTS PARTITION(partition='testUdf') LOCATION '%s'" + .format(this.getClass.getClassLoader.getResource("data/files/testUdf").getFile) + ) + + TestHive.hql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) + + TestHive.hql("SELECT testUdf(pair) FROM hiveUdfTestTable") + + TestHive.hql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") +} + +class TestPair(x: Int, y: Int) extends Writable with Serializable { + def this() = this(0, 0) + var entry: (Int, Int) = (x, y) + + override def write(output: DataOutput): Unit = { + output.writeInt(entry._1) + output.writeInt(entry._2) + } + + override def readFields(input: DataInput): Unit = { + val x = input.readInt() + val y = input.readInt() + entry = (x, y) + } +} + +class PairSerDe extends AbstractSerDe { + override def initialize(p1: Configuration, p2: Properties): Unit = {} + + override def getObjectInspector: ObjectInspector = { + ObjectInspectorFactory + .getStandardStructObjectInspector( + Seq("pair"), + Seq(ObjectInspectorFactory.getStandardStructObjectInspector( + Seq("id", "value"), + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector)) + )) + } + + override def getSerializedClass: Class[_ <: Writable] = classOf[TestPair] + + override def getSerDeStats: SerDeStats = null + + override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = null + + override def deserialize(value: Writable): AnyRef = { + val pair = value.asInstanceOf[TestPair] + + val row = new util.ArrayList[util.ArrayList[AnyRef]] + row.add(new util.ArrayList[AnyRef](2)) + row(0).add(Integer.valueOf(pair.entry._1)) + row(0).add(Integer.valueOf(pair.entry._2)) + + row + } +} + +class PairUdf extends GenericUDF { + override def initialize(p1: Array[ObjectInspector]): ObjectInspector = + ObjectInspectorFactory.getStandardStructObjectInspector( + Seq("id", "value"), + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector) + ) + + override def evaluate(args: Array[DeferredObject]): AnyRef = { + println("Type = %s".format(args(0).getClass.getName)) + Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2) + } + + override def getDisplayString(p1: Array[String]): String = "" +} + + + From e243c5ffacd70ecadaf5c91668955dcc8141e060 Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Tue, 17 Jun 2014 13:30:17 +0200 Subject: [PATCH 22/57] [SPARK-2053][SQL] Add Catalyst expressions for CASE WHEN. JIRA ticket: https://issues.apache.org/jira/browse/SPARK-2053 This PR adds support for two types of CASE statements present in Hive. The first type is of the form `CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END`, with the semantics like a chain of if statements. The second type is of the form `CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END`, with the semantics like a switch statement on key `a`. Both forms are implemented in `CaseWhen`. [This link](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions) contains more detailed descriptions on their semantics. Notes / Open issues: * Please check if any implicit contracts / invariants are broken in the implementations (especially for the operators). I am not very familiar with them and I currently find them tricky to spot. 
* We should decide whether or not a non-boolean condition is allowed in a branch of `CaseWhen`. Hive throws a `SemanticException` for this situation and I think it'd be good to mimic it -- the question is where in the whole Spark SQL pipeline should we signal an exception for such a query. Author: Zongheng Yang Closes #1055 from concretevitamin/caseWhen and squashes the following commits: 4226eb9 [Zongheng Yang] Comment. 79d26fc [Zongheng Yang] Merge branch 'master' into caseWhen caf9383 [Zongheng Yang] Update a FIXME. 9d26ab8 [Zongheng Yang] Add @transient marker. 788a0d9 [Zongheng Yang] Implement CastNulls, which fixes udf_case and udf_when. 7ef284f [Zongheng Yang] Refactors: remove redundant passes, improve toString, mark transient. f47ae7b [Zongheng Yang] Modify queries in tests to have shorter golden files. 1c1fbfc [Zongheng Yang] Cleanups per review comments. 7d2b7e2 [Zongheng Yang] Translate CaseKeyWhen to CaseWhen at parsing time. 47d406a [Zongheng Yang] Do toArray once and lazily outside of eval(). bb3d109 [Zongheng Yang] Update scaladoc of a method. aea3195 [Zongheng Yang] Fix bug that branchesArr is not used; remove unused import. 96870a8 [Zongheng Yang] Turn off scalastyle for some comments. 7392f3a [Zongheng Yang] Minor cleanup. 2cf08bb [Zongheng Yang] Merge branch 'master' into caseWhen 9f84b40 [Zongheng Yang] Add golden outputs from Hive. db51a85 [Zongheng Yang] Add allCondBooleans check; uncomment tests. 3f9ef0a [Zongheng Yang] Cleanups and bug fixes (mainly in eval() and resolved). be54bc8 [Zongheng Yang] Rewrite eval() to a low-level implementation. Separate two CASE stmts. f2bcb9d [Zongheng Yang] WIP 5906f75 [Zongheng Yang] WIP efd019b [Zongheng Yang] eval() and toString() bug fixes. 7d81e95 [Zongheng Yang] Clean up resolved. a31d782 [Zongheng Yang] Finish up Case. --- .../catalyst/analysis/HiveTypeCoercion.scala | 41 +++++++++- .../sql/catalyst/expressions/Expression.scala | 10 ++- .../sql/catalyst/expressions/predicates.scala | 76 ++++++++++++++++++- .../spark/sql/catalyst/util/package.scala | 2 +- .../ExpressionEvaluationSuite.scala | 2 +- .../org/apache/spark/sql/hive/HiveQl.scala | 17 +++++ ... key #1-0-36750f0f6727c287c471309689ff7563 | 14 ++++ ... key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 | 14 ++++ ... key #3-0-be5efc0574a97ec465e2686f4a724bd5 | 14 ++++ ... key #4-0-631f824a91b7230657bea7a05e393a1e | 14 ++++ ... key #1-0-616830b2011da0990e87a188fb609299 | 14 ++++ ... key #2-0-6c5b5a997949f9e5ab9676b60e95657b | 14 ++++ ... key #3-0-a241862582c47d9e98be95339d35c7c4 | 14 ++++ ... 
key #4-0-ea87ca38ead8858d2337792dcd430226 | 14 ++++ .../sql/hive/execution/HiveQuerySuite.scala | 38 ++++++++++ 15 files changed, 290 insertions(+), 8 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 create mode 100644 sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 create mode 100644 sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 create mode 100644 sql/hive/src/test/resources/golden/case statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e create mode 100644 sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 create mode 100644 sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b create mode 100644 sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 create mode 100644 sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 326feea6fee91..d291814c8aa7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -31,8 +31,16 @@ import org.apache.spark.sql.catalyst.types._ trait HiveTypeCoercion { val typeCoercionRules = - List(PropagateTypes, ConvertNaNs, WidenTypes, PromoteStrings, BooleanComparisons, BooleanCasts, - StringToIntegralCasts, FunctionArgumentConversion) + PropagateTypes :: + ConvertNaNs :: + WidenTypes :: + PromoteStrings :: + BooleanComparisons :: + BooleanCasts :: + StringToIntegralCasts :: + FunctionArgumentConversion :: + CastNulls :: + Nil /** * Applies any changes to [[catalyst.expressions.AttributeReference AttributeReference]] data @@ -282,4 +290,33 @@ trait HiveTypeCoercion { Average(Cast(e, DoubleType)) } } + + /** + * Ensures that NullType gets casted to some other types under certain circumstances. + */ + object CastNulls extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case cw @ CaseWhen(branches) => + val valueTypes = branches.sliding(2, 2).map { + case Seq(_, value) if value.resolved => Some(value.dataType) + case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType) + case _ => None + }.toSeq + if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) { + val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get + val transformedBranches = branches.sliding(2, 2).map { + case Seq(cond, value) if value.resolved && value.dataType == NullType => + Seq(cond, Cast(value, otherType)) + case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType => + Seq(Cast(elseVal, otherType)) + case s => s + }.reduce(_ ++ _) + CaseWhen(transformedBranches) + } else { + // It is possible to have more types due to the possibility of short-circuiting. 
+ cw + } + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 41398ff956edd..3912f5f4375fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -28,8 +28,6 @@ abstract class Expression extends TreeNode[Expression] { /** The narrowest possible type that is produced when this expression is evaluated. */ type EvaluatedType <: Any - def dataType: DataType - /** * Returns true when an expression is a candidate for static evaluation before the query is * executed. @@ -53,12 +51,18 @@ abstract class Expression extends TreeNode[Expression] { /** * Returns `true` if this expression and all its children have been resolved to a specific schema - * and `false` if it is still contains any unresolved placeholders. Implementations of expressions + * and `false` if it still contains any unresolved placeholders. Implementations of expressions * should override this if the resolution of this type of expression involves more than just * the resolution of its children. */ lazy val resolved: Boolean = childrenResolved + /** + * Returns the [[types.DataType DataType]] of the result of evaluating this expression. It is + * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false). + */ + def dataType: DataType + /** * Returns true if all the children of this expression have been resolved to a specific schema * and false if any still contains any unresolved placeholders. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d111578530506..2902906df2844 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.types.BooleanType @@ -202,3 +201,78 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def toString = s"if ($predicate) $trueValue else $falseValue" } + +// scalastyle:off +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + * + * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets + * translated to this form at parsing time. Namely, such a statement gets translated to + * "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END". + * + * Note that `branches` are considered in consecutive pairs (cond, val), and the optional last + * element is the value for the default catch-all case (if provided). Hence, `branches` consists of + * at least two elements, and can have an odd or even length. 
+ */ +// scalastyle:on +case class CaseWhen(branches: Seq[Expression]) extends Expression { + type EvaluatedType = Any + def children = branches + def references = children.flatMap(_.references).toSet + def dataType = { + if (!resolved) { + throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") + } + branches(1).dataType + } + + @transient private[this] lazy val branchesArr = branches.toArray + @transient private[this] lazy val predicates = + branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq + @transient private[this] lazy val values = + branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq + + override def nullable = { + // If no value is nullable and no elseValue is provided, the whole statement defaults to null. + values.exists(_.nullable) || (values.length % 2 == 0) + } + + override lazy val resolved = { + if (!childrenResolved) { + false + } else { + val allCondBooleans = predicates.forall(_.dataType == BooleanType) + val dataTypesEqual = values.map(_.dataType).distinct.size <= 1 + allCondBooleans && dataTypesEqual + } + } + + /** Written in imperative fashion for performance considerations. Same for CaseKeyWhen. */ + override def eval(input: Row): Any = { + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. + var res: Any = null + while (i < len - 1) { + if (branchesArr(i).eval(input) == true) { + res = branchesArr(i + 1).eval(input) + return res + } + i += 2 + } + if (i == len - 1) { + res = branchesArr(i).eval(input) + } + res + } + + override def toString = { + "CASE" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 49fc4f70fdfae..d8da45ae70c4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -115,7 +115,7 @@ package object util { } /* FIX ME - implicit class debugLogging(a: AnyRef) { + implicit class debugLogging(a: Any) { def debugLogging() { org.apache.log4j.Logger.getLogger(a.getClass.getName).setLevel(org.apache.log4j.Level.DEBUG) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 1132a30b42767..8c3b062d0f801 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -35,7 +35,7 @@ class ExpressionEvaluationSuite extends FunSuite { /** * Checks for three-valued-logic. Based on: * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 - * + * I.e. in flat cpo "False -> Unknown -> True", OR is lowest upper bound, AND is greatest lower bound. 
* p q p OR q p AND q p = q * True True True True True * True False True False False diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b745d8ffd8f17..844673f66d103 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -811,6 +811,8 @@ private[hive] object HiveQl { val IN = "(?i)IN".r val DIV = "(?i)DIV".r val BETWEEN = "(?i)BETWEEN".r + val WHEN = "(?i)WHEN".r + val CASE = "(?i)CASE".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ @@ -917,6 +919,21 @@ private[hive] object HiveQl { case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) + /* Case statements */ + case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => + CaseWhen(branches.map(nodeToExpr)) + case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => + val transformed = branches.drop(1).sliding(2, 2).map { + case Seq(condVal, value) => + // FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval(). + // Hence effectful / non-deterministic key expressions are *not* supported at the moment. + // We should consider adding new Expressions to get around this. + Seq(Equals(nodeToExpr(branches(0)), nodeToExpr(condVal)), + nodeToExpr(value)) + case Seq(elseVal) => Seq(nodeToExpr(elseVal)) + }.toSeq.reduce(_ ++ _) + CaseWhen(transformed) + /* Complex datatype manipulation */ case Token("[", child :: ordinal :: Nil) => GetItem(nodeToExpr(child), nodeToExpr(ordinal)) diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 b/sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 new file mode 100644 index 0000000000000..816fe57d162dc --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 @@ -0,0 +1,14 @@ +NULL +3 +3 +3 +NULL +NULL +3 +3 +3 +3 +NULL +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 b/sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 new file mode 100644 index 0000000000000..4cca081e6e294 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 @@ -0,0 +1,14 @@ +4 +3 +3 +3 +4 +4 +3 +3 +3 +3 +4 +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 b/sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 new file mode 100644 index 0000000000000..8d0416a8f8d9c --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 @@ -0,0 +1,14 @@ +2 +3 +3 +3 +2 +2 +3 +3 +3 +3 +NULL +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e b/sql/hive/src/test/resources/golden/case statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e new file mode 100644 index 0000000000000..6ed452bcd870d --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e @@ -0,0 +1,14 @@ +2 +3 +3 +3 +2 +2 +3 +3 +3 +3 +0 +3 +3 +3 diff --git 
a/sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 b/sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 new file mode 100644 index 0000000000000..3f5a2fbbe99fd --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 @@ -0,0 +1,14 @@ +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL diff --git a/sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b b/sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b new file mode 100644 index 0000000000000..e1ca6e76d1f8f --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b @@ -0,0 +1,14 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +3 +0 +0 +0 diff --git a/sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 b/sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 new file mode 100644 index 0000000000000..896207fdbcf3d --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 @@ -0,0 +1,14 @@ +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +3 +NULL +NULL +NULL diff --git a/sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 b/sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 new file mode 100644 index 0000000000000..e1ca6e76d1f8f --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 @@ -0,0 +1,14 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +3 +0 +0 +0 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 04652587f9073..fe698f0fc57b8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -164,6 +164,44 @@ class HiveQuerySuite extends HiveComparisonTest { hql("SELECT * FROM src").toString } + createQueryTest("case statements with key #1", + "SELECT (CASE 1 WHEN 2 THEN 3 END) FROM src where key < 15") + + createQueryTest("case statements with key #2", + "SELECT (CASE key WHEN 2 THEN 3 ELSE 0 END) FROM src WHERE key < 15") + + createQueryTest("case statements with key #3", + "SELECT (CASE key WHEN 2 THEN 3 WHEN NULL THEN 4 END) FROM src WHERE key < 15") + + createQueryTest("case statements with key #4", + "SELECT (CASE key WHEN 2 THEN 3 WHEN NULL THEN 4 ELSE 0 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #1", + "SELECT (CASE WHEN key > 2 THEN 3 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #2", + "SELECT (CASE WHEN key > 2 THEN 3 ELSE 4 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #3", + "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #4", + "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") + + test("implement identity function using case statement") { + val actual = hql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet + val 
expected = hql("SELECT key FROM src").collect().toSet + assert(actual === expected) + } + + // TODO: adopt this test when Spark SQL has the functionality / framework to report errors. + // See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion. + ignore("non-boolean conditions in a CaseWhen are illegal") { + intercept[Exception] { + hql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() + } + } + private val explainCommandClassName = classOf[execution.ExplainCommand].getSimpleName.stripSuffix("$") From b92d16b114fd49e881d09e7974ad57b2a0df2906 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Tue, 17 Jun 2014 11:47:48 -0700 Subject: [PATCH 23/57] SPARK-1063 Add .sortBy(f) method on RDD This never got merged from the apache/incubator-spark repo (which is now deleted) but there had been several rounds of code review on this PR there. I think this is ready for merging. Author: Andrew Ash This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #369 from ash211/sortby and squashes the following commits: d09147a [Andrew Ash] Fix Ordering import 43d0a53 [Andrew Ash] Fix missing .collect() 29a54ed [Andrew Ash] Re-enable test by converting to a closure 5a95348 [Andrew Ash] Add license for RDDSuiteUtils 64ed6e3 [Andrew Ash] Remove leaked diff d4de69a [Andrew Ash] Remove scar tissue 63638b5 [Andrew Ash] Add Python version of .sortBy() 45e0fde [Andrew Ash] Add Java version of .sortBy() adf84c5 [Andrew Ash] Re-indent to keep line lengths under 100 chars 9d9b9d8 [Andrew Ash] Use parentheses on .collect() calls 0457b69 [Andrew Ash] Ignore failing test 99f0baf [Andrew Ash] Merge branch 'master' into sortby 222ae97 [Andrew Ash] Try moving Ordering objects out to a different class 3fd0dd3 [Andrew Ash] Add (failing) test for sortByKey with explicit Ordering b8b5bbc [Andrew Ash] Align remove extra spaces that were used to align ='s in test code 8c53298 [Andrew Ash] Actually use ascending and numPartitions parameters 381eef2 [Andrew Ash] Correct silly typo 7db3e84 [Andrew Ash] Support ascending and numPartitions params in sortBy() 0f685fd [Andrew Ash] Merge remote-tracking branch 'origin/master' into sortby ca4490d [Andrew Ash] Add .sortBy(f) method on RDD --- .../org/apache/spark/api/java/JavaRDD.scala | 16 +++++ .../main/scala/org/apache/spark/rdd/RDD.scala | 12 ++++ .../java/org/apache/spark/JavaAPISuite.java | 33 +++++++++++ .../scala/org/apache/spark/rdd/RDDSuite.scala | 59 +++++++++++++++++-- .../org/apache/spark/rdd/RDDSuiteUtils.scala | 31 ++++++++++ python/pyspark/rdd.py | 12 ++++ 6 files changed, 159 insertions(+), 4 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 23d13710794af..86fb374bef1e3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -17,10 +17,13 @@ package org.apache.spark.api.java +import java.util.Comparator + import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -172,6 +175,19 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) rdd.setName(name) this } + + /** + * 
Return this RDD sorted by the given key function. + */ + def sortBy[S](f: JFunction[T, S], ascending: Boolean, numPartitions: Int): JavaRDD[T] = { + import scala.collection.JavaConverters._ + def fn = (x: T) => f.call(x) + import com.google.common.collect.Ordering // shadows scala.math.Ordering + implicit val ordering = Ordering.natural().asInstanceOf[Ordering[S]] + implicit val ctag: ClassTag[S] = fakeClassTag + wrapRDD(rdd.sortBy(fn, ascending, numPartitions)) + } + } object JavaRDD { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 27cc60d775788..cf915b870e0d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -442,6 +442,18 @@ abstract class RDD[T: ClassTag]( */ def ++(other: RDD[T]): RDD[T] = this.union(other) + /** + * Return this RDD sorted by the given key function. + */ + def sortBy[K]( + f: (T) ⇒ K, + ascending: Boolean = true, + numPartitions: Int = this.partitions.size) + (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = + this.keyBy[K](f) + .sortByKey(ascending, numPartitions) + .values + /** * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index ef41bfb88de9d..e46298c6a9e63 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -180,6 +180,39 @@ public void sortByKey() { Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); } + @Test + public void sortBy() { + List> pairs = new ArrayList>(); + pairs.add(new Tuple2(0, 4)); + pairs.add(new Tuple2(3, 2)); + pairs.add(new Tuple2(-1, 1)); + + JavaRDD> rdd = sc.parallelize(pairs); + + // compare on first value + JavaRDD> sortedRDD = rdd.sortBy(new Function, Integer>() { + public Integer call(Tuple2 t) throws Exception { + return t._1(); + } + }, true, 2); + + Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + List> sortedPairs = sortedRDD.collect(); + Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + + // compare on second value + sortedRDD = rdd.sortBy(new Function, Integer>() { + public Integer call(Tuple2 t) throws Exception { + return t._2(); + } + }, true, 2); + Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + sortedPairs = sortedRDD.collect(); + Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(2)); + } + @Test public void foreach() { final Accumulator accum = sc.accumulator(0); diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e94a1e76d410c..0e5625b7645d5 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark._ import org.apache.spark.SparkContext._ import org.apache.spark.util.Utils +import org.apache.spark.rdd.RDDSuiteUtils._ + class RDDSuite extends FunSuite with SharedSparkContext { test("basic operations") { @@ -585,14 +587,63 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("sortByKey") { + val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B")) + + val col1 = Array("4|60|C", "5|50|A", "6|40|B") + val 
col2 = Array("6|40|B", "5|50|A", "4|60|C") + val col3 = Array("5|50|A", "6|40|B", "4|60|C") + + assert(data.sortBy(_.split("\\|")(0)).collect() === col1) + assert(data.sortBy(_.split("\\|")(1)).collect() === col2) + assert(data.sortBy(_.split("\\|")(2)).collect() === col3) + } + + test("sortByKey ascending parameter") { + val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B")) + + val asc = Array("4|60|C", "5|50|A", "6|40|B") + val desc = Array("6|40|B", "5|50|A", "4|60|C") + + assert(data.sortBy(_.split("\\|")(0), true).collect() === asc) + assert(data.sortBy(_.split("\\|")(0), false).collect() === desc) + } + + test("sortByKey with explicit ordering") { + val data = sc.parallelize(Seq("Bob|Smith|50", + "Jane|Smith|40", + "Thomas|Williams|30", + "Karen|Williams|60")) + + val ageOrdered = Array("Thomas|Williams|30", + "Jane|Smith|40", + "Bob|Smith|50", + "Karen|Williams|60") + + // last name, then first name + val nameOrdered = Array("Bob|Smith|50", + "Jane|Smith|40", + "Karen|Williams|60", + "Thomas|Williams|30") + + val parse = (s: String) => { + val split = s.split("\\|") + Person(split(0), split(1), split(2).toInt) + } + + import scala.reflect.classTag + assert(data.sortBy(parse, true, 2)(AgeOrdering, classTag[Person]).collect() === ageOrdered) + assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered) + } + test("intersection") { val all = sc.parallelize(1 to 10) val evens = sc.parallelize(2 to 10 by 2) val intersection = Array(2, 4, 6, 8, 10) // intersection is commutative - assert(all.intersection(evens).collect.sorted === intersection) - assert(evens.intersection(all).collect.sorted === intersection) + assert(all.intersection(evens).collect().sorted === intersection) + assert(evens.intersection(all).collect().sorted === intersection) } test("intersection strips duplicates in an input") { @@ -600,8 +651,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { val b = sc.parallelize(Seq(1,1,2,3)) val intersection = Array(1,2,3) - assert(a.intersection(b).collect.sorted === intersection) - assert(b.intersection(a).collect.sorted === intersection) + assert(a.intersection(b).collect().sorted === intersection) + assert(b.intersection(a).collect().sorted === intersection) } test("zipWithIndex") { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala new file mode 100644 index 0000000000000..4762fc17855ce --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +object RDDSuiteUtils { + case class Person(first: String, last: String, age: Int) + + object AgeOrdering extends Ordering[Person] { + def compare(a:Person, b:Person) = a.age compare b.age + } + + object NameOrdering extends Ordering[Person] { + def compare(a:Person, b:Person) = + implicitly[Ordering[Tuple2[String,String]]].compare((a.last, a.first), (b.last, b.first)) + } +} diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index bb4d035edcdeb..65f63153cdff4 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -549,6 +549,18 @@ def mapFunc(iterator): .mapPartitions(mapFunc,preservesPartitioning=True) .flatMap(lambda x: x, preservesPartitioning=True)) + def sortBy(self, keyfunc, ascending=True, numPartitions=None): + """ + Sorts this RDD by the given keyfunc + + >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + >>> sc.parallelize(tmp).sortBy(lambda x: x[0]).collect() + [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] + >>> sc.parallelize(tmp).sortBy(lambda x: x[1]).collect() + [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + """ + return self.keyBy(keyfunc).sortByKey(ascending, numPartitions).values() + def glom(self): """ Return an RDD created by coalescing all elements within each partition From 2794990e9eb8712d76d3a0f0483063ddc295e639 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Tue, 17 Jun 2014 12:03:22 -0700 Subject: [PATCH 24/57] SPARK-2146. Fix takeOrdered doc Removes Python syntax in Scaladoc, corrects result in Scaladoc, and removes irrelevant cache() call in Python doc. Author: Sandy Ryza Closes #1086 from sryza/sandy-spark-2146 and squashes the following commits: 185ff18 [Sandy Ryza] Use Seq instead of Array c996120 [Sandy Ryza] SPARK-2146. Fix takeOrdered doc --- .../main/scala/org/apache/spark/rdd/RDD.scala | 16 ++++++++-------- python/pyspark/rdd.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index cf915b870e0d3..1633b185861b9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1074,11 +1074,11 @@ abstract class RDD[T: ClassTag]( * Returns the top K (largest) elements from this RDD as defined by the specified * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: * {{{ - * sc.parallelize([10, 4, 2, 12, 3]).top(1) - * // returns [12] + * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) + * // returns Array(12) * - * sc.parallelize([2, 3, 4, 5, 6]).top(2) - * // returns [6, 5] + * sc.parallelize(Seq(2, 3, 4, 5, 6)).top(2) + * // returns Array(6, 5) * }}} * * @param num the number of top elements to return @@ -1092,11 +1092,11 @@ abstract class RDD[T: ClassTag]( * implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]]. * For example: * {{{ - * sc.parallelize([10, 4, 2, 12, 3]).takeOrdered(1) - * // returns [12] + * sc.parallelize(Seq(10, 4, 2, 12, 3)).takeOrdered(1) + * // returns Array(2) * - * sc.parallelize([2, 3, 4, 5, 6]).takeOrdered(2) - * // returns [2, 3] + * sc.parallelize(Seq(2, 3, 4, 5, 6)).takeOrdered(2) + * // returns Array(2, 3) * }}} * * @param num the number of top elements to return diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 65f63153cdff4..a0b2c744f0e7f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -857,7 +857,7 @@ def top(self, num): Note: It returns the list sorted in descending order. 
>>> sc.parallelize([10, 4, 2, 12, 3]).top(1) [12] - >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2) + >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2) [6, 5] """ def topIterator(iterator): From 443f5e1bbcf9ec55e5ce6e4f738a002a47818100 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 17 Jun 2014 12:17:48 -0700 Subject: [PATCH 25/57] SPARK-2038: rename "conf" parameters in the saveAsHadoop functions to distinguish with SparkConf object https://issues.apache.org/jira/browse/SPARK-2038 Author: CodingCat Closes #1087 from CodingCat/SPARK-2038 and squashes the following commits: 763975f [CodingCat] style fix d91288d [CodingCat] rename "conf" parameters in the saveAsHadoop functions --- .../apache/spark/rdd/PairRDDFunctions.scala | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index fe36c80e0be84..bff77b4ecbf27 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -719,9 +719,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = self.context.hadoopConfiguration) + hadoopConf: Configuration = self.context.hadoopConfiguration) { - val job = new NewAPIHadoopJob(conf) + val job = new NewAPIHadoopJob(hadoopConf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) @@ -752,24 +752,25 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf(self.context.hadoopConfiguration), + hadoopConf: JobConf = new JobConf(self.context.hadoopConfiguration), codec: Option[Class[_ <: CompressionCodec]] = None) { - conf.setOutputKeyClass(keyClass) - conf.setOutputValueClass(valueClass) + hadoopConf.setOutputKeyClass(keyClass) + hadoopConf.setOutputValueClass(valueClass) // Doesn't work in Scala 2.9 due to what may be a generics bug // TODO: Should we uncomment this for Scala 2.10? // conf.setOutputFormat(outputFormatClass) - conf.set("mapred.output.format.class", outputFormatClass.getName) + hadoopConf.set("mapred.output.format.class", outputFormatClass.getName) for (c <- codec) { - conf.setCompressMapOutput(true) - conf.set("mapred.output.compress", "true") - conf.setMapOutputCompressorClass(c) - conf.set("mapred.output.compression.codec", c.getCanonicalName) - conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + hadoopConf.setCompressMapOutput(true) + hadoopConf.set("mapred.output.compress", "true") + hadoopConf.setMapOutputCompressorClass(c) + hadoopConf.set("mapred.output.compression.codec", c.getCanonicalName) + hadoopConf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath(conf, SparkHadoopWriter.createPathFromString(path, conf)) - saveAsHadoopDataset(conf) + hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath(hadoopConf, + SparkHadoopWriter.createPathFromString(path, hadoopConf)) + saveAsHadoopDataset(hadoopConf) } /** @@ -778,8 +779,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. 
a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. */ - def saveAsNewAPIHadoopDataset(conf: Configuration) { - val job = new NewAPIHadoopJob(conf) + def saveAsNewAPIHadoopDataset(hadoopConf: Configuration) { + val job = new NewAPIHadoopJob(hadoopConf) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id @@ -835,10 +836,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop * MapReduce job. */ - def saveAsHadoopDataset(conf: JobConf) { - val outputFormatInstance = conf.getOutputFormat - val keyClass = conf.getOutputKeyClass - val valueClass = conf.getOutputValueClass + def saveAsHadoopDataset(hadoopConf: JobConf) { + val outputFormatInstance = hadoopConf.getOutputFormat + val keyClass = hadoopConf.getOutputKeyClass + val valueClass = hadoopConf.getOutputValueClass if (outputFormatInstance == null) { throw new SparkException("Output format class not set") } @@ -848,18 +849,18 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) if (valueClass == null) { throw new SparkException("Output value class not set") } - SparkHadoopUtil.get.addCredentials(conf) + SparkHadoopUtil.get.addCredentials(hadoopConf) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) { // FileOutputFormat ignores the filesystem parameter - val ignoredFs = FileSystem.get(conf) - conf.getOutputFormat.checkOutputSpecs(ignoredFs, conf) + val ignoredFs = FileSystem.get(hadoopConf) + hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) } - val writer = new SparkHadoopWriter(conf) + val writer = new SparkHadoopWriter(hadoopConf) writer.preSetup() def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { From a14807e84cbda64e5a73babb7a28c69ee1ef3cbb Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 17 Jun 2014 12:25:55 -0700 Subject: [PATCH 26/57] [SPARK-2147 / 2161] Show removed executors on the UI This PR includes two changes - **[SPARK-2147]** When an application finishes cleanly (i.e. `sc.stop()` is called), all of its executors used to disappear from the Master UI. This no longer happens. - **[SPARK-2161]** This adds a "Removed Executors" table to Master UI, so the user can find out why their executors died from the logs, for instance. The equivalent table already existed in the Worker UI, but was hidden because of a bug (the comment `//scalastyle:off` disconnected the `Seq[Node]` that represents the HTML for table). This should go into 1.0.1 if possible. 
Author: Andrew Or Closes #1102 from andrewor14/remember-removed-executors and squashes the following commits: 2e2298f [Andrew Or] Add hash code method to ExecutorInfo (minor) abd72e0 [Andrew Or] Merge branch 'master' of github.com:apache/spark into remember-removed-executors 792f992 [Andrew Or] Add missing equals method in ExecutorInfo 3390b49 [Andrew Or] Add executor state column to WorkerPage 161f8a2 [Andrew Or] Display finished executors table (fix bug) fbb65b8 [Andrew Or] Removed unused method c89bb6e [Andrew Or] Add table for removed executors in MasterWebUI fe47402 [Andrew Or] Show exited executors on the Master UI --- .../spark/deploy/master/ApplicationInfo.scala | 4 + .../spark/deploy/master/ExecutorInfo.scala | 15 +++ .../deploy/master/ui/ApplicationPage.scala | 80 +++++++++------- .../spark/deploy/worker/ui/WorkerPage.scala | 95 ++++++++----------- 4 files changed, 107 insertions(+), 87 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 46b9f4dc7d3ba..72d0589689e71 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.master import java.util.Date import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import akka.actor.ActorRef @@ -36,6 +37,7 @@ private[spark] class ApplicationInfo( @transient var state: ApplicationState.Value = _ @transient var executors: mutable.HashMap[Int, ExecutorInfo] = _ + @transient var removedExecutors: ArrayBuffer[ExecutorInfo] = _ @transient var coresGranted: Int = _ @transient var endTime: Long = _ @transient var appSource: ApplicationSource = _ @@ -51,6 +53,7 @@ private[spark] class ApplicationInfo( endTime = -1L appSource = new ApplicationSource(this) nextExecutorId = 0 + removedExecutors = new ArrayBuffer[ExecutorInfo] } private def newExecutorId(useID: Option[Int] = None): Int = { @@ -74,6 +77,7 @@ private[spark] class ApplicationInfo( def removeExecutor(exec: ExecutorInfo) { if (executors.contains(exec.id)) { + removedExecutors += executors(exec.id) executors -= exec.id coresGranted -= exec.cores } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala index 76db61dd619c6..d417070c51016 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala @@ -34,4 +34,19 @@ private[spark] class ExecutorInfo( } def fullId: String = application.id + "/" + id + + override def equals(other: Any): Boolean = { + other match { + case info: ExecutorInfo => + fullId == info.fullId && + worker.id == info.worker.id && + cores == info.cores && + memory == info.memory + case _ => false + } + } + + override def toString: String = fullId + + override def hashCode: Int = toString.hashCode() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index b5cd4d2ea963f..34fa1429c86de 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -25,7 +25,7 @@ import scala.xml.Node import akka.pattern.ask import org.json4s.JValue -import 
org.apache.spark.deploy.JsonProtocol +import org.apache.spark.deploy.{ExecutorState, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorInfo import org.apache.spark.ui.{WebUIPage, UIUtils} @@ -57,43 +57,55 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app }) val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") - val executors = app.executors.values.toSeq - val executorTable = UIUtils.listingTable(executorHeaders, executorRow, executors) + val allExecutors = (app.executors.values ++ app.removedExecutors).toSet.toSeq + // This includes executors that are either still running or have exited cleanly + val executors = allExecutors.filter { exec => + !ExecutorState.isFinished(exec.state) || exec.state == ExecutorState.EXITED + } + val removedExecutors = allExecutors.diff(executors) + val executorsTable = UIUtils.listingTable(executorHeaders, executorRow, executors) + val removedExecutorsTable = UIUtils.listingTable(executorHeaders, executorRow, removedExecutors) val content = -
<div class="row-fluid"> - <div class="span12"> - <ul class="unstyled"> - <li><strong>ID:</strong> {app.id}</li> - <li><strong>Name:</strong> {app.desc.name}</li> - <li><strong>User:</strong> {app.desc.user}</li> - <li><strong>Cores:</strong> - { - if (app.desc.maxCores.isEmpty) { - "Unlimited (%s granted)".format(app.coresGranted) - } else { - "%s (%s granted, %s left)".format( - app.desc.maxCores.get, app.coresGranted, app.coresLeft) - } - } - </li> - <li> - <strong>Executor Memory:</strong> - {Utils.megabytesToString(app.desc.memoryPerSlave)} - </li> - <li><strong>Submit Date:</strong> {app.submitDate}</li> - <li><strong>State:</strong> {app.state}</li> - <li><strong><a href={app.desc.appUiUrl}>Application Detail UI</a></strong></li> - </ul> - </div> - </div>
+ <div class="row-fluid"> + <div class="span12"> + <ul class="unstyled"> + <li><strong>ID:</strong> {app.id}</li> + <li><strong>Name:</strong> {app.desc.name}</li> + <li><strong>User:</strong> {app.desc.user}</li> + <li><strong>Cores:</strong> + { + if (app.desc.maxCores.isEmpty) { + "Unlimited (%s granted)".format(app.coresGranted) + } else { + "%s (%s granted, %s left)".format( + app.desc.maxCores.get, app.coresGranted, app.coresLeft) + } + } + </li> + <li> + <strong>Executor Memory:</strong> + {Utils.megabytesToString(app.desc.memoryPerSlave)} + </li> + <li><strong>Submit Date:</strong> {app.submitDate}</li> + <li><strong>State:</strong> {app.state}</li> + <li><strong><a href={app.desc.appUiUrl}>Application Detail UI</a></strong></li> + </ul> + </div> + </div>
- <div class="row-fluid"> - <div class="span12"> - <h4> Executor Summary </h4> - {executorTable} - </div> - </div>;
+ <div class="row-fluid"> + <div class="span12"> + <h4> Executor Summary </h4> + {executorsTable} + { + if (removedExecutors.nonEmpty) { + <h4> Removed Executors </h4> ++ + removedExecutorsTable + } + } + </div> + </div>
; UIUtils.basicSparkPage(content, "Application: " + app.desc.name) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index d4513118ced05..327b905032800 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -46,74 +46,62 @@ private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] val workerState = Await.result(stateFuture, timeout) - val executorHeaders = Seq("ExecutorID", "Cores", "Memory", "Job Details", "Logs") + val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") + val runningExecutors = workerState.executors val runningExecutorTable = - UIUtils.listingTable(executorHeaders, executorRow, workerState.executors) + UIUtils.listingTable(executorHeaders, executorRow, runningExecutors) + val finishedExecutors = workerState.finishedExecutors val finishedExecutorTable = - UIUtils.listingTable(executorHeaders, executorRow, workerState.finishedExecutors) + UIUtils.listingTable(executorHeaders, executorRow, finishedExecutors) val driverHeaders = Seq("DriverID", "Main Class", "State", "Cores", "Memory", "Logs", "Notes") val runningDrivers = workerState.drivers.sortBy(_.driverId).reverse val runningDriverTable = UIUtils.listingTable(driverHeaders, driverRow, runningDrivers) val finishedDrivers = workerState.finishedDrivers.sortBy(_.driverId).reverse - def finishedDriverTable = UIUtils.listingTable(driverHeaders, driverRow, finishedDrivers) + val finishedDriverTable = UIUtils.listingTable(driverHeaders, driverRow, finishedDrivers) // For now we only show driver information if the user has submitted drivers to the cluster. // This is until we integrate the notion of drivers and applications in the UI. - def hasDrivers = runningDrivers.length > 0 || finishedDrivers.length > 0 val content = -
<div class="row-fluid"> - <div class="span12"> - <ul class="unstyled"> - <li><strong>ID:</strong> {workerState.workerId}</li> - <li><strong> - Master URL:</strong> {workerState.masterUrl} - </li> - <li><strong>Cores:</strong> {workerState.cores} ({workerState.coresUsed} Used)</li> - <li><strong>Memory:</strong> {Utils.megabytesToString(workerState.memory)} - ({Utils.megabytesToString(workerState.memoryUsed)} Used)</li> - </ul> - <p><a href={workerState.masterWebUiUrl}>Back to Master</a></p> - </div> - </div>
+ <div class="row-fluid"> + <div class="span12"> + <ul class="unstyled"> + <li><strong>ID:</strong> {workerState.workerId}</li> + <li><strong> + Master URL:</strong> {workerState.masterUrl} + </li> + <li><strong>Cores:</strong> {workerState.cores} ({workerState.coresUsed} Used)</li> + <li><strong>Memory:</strong> {Utils.megabytesToString(workerState.memory)} + ({Utils.megabytesToString(workerState.memoryUsed)} Used)</li> + </ul> + <p><a href={workerState.masterWebUiUrl}>Back to Master</a></p> + </div> + </div>
- <div class="row-fluid"> - <div class="span12"> - <h4> Running Executors {workerState.executors.size} </h4> - {runningExecutorTable} - </div> - </div> - // scalastyle:off - <div> - {if (hasDrivers) - <div class="row-fluid"> - <div class="span12"> - <h4> Running Drivers {workerState.drivers.size} </h4> - {runningDriverTable} - </div> - </div> - } - </div>
+ <div class="row-fluid"> + <div class="span12"> + <h4> Running Executors ({runningExecutors.size}) </h4> + {runningExecutorTable} + { + if (runningDrivers.nonEmpty) { + <h4> Running Drivers ({runningDrivers.size}) </h4> ++ + runningDriverTable + } + }
- <div class="row-fluid"> - <div class="span12"> - <h4> Finished Executors </h4> - {finishedExecutorTable} - </div> - </div>
- <div> - {if (hasDrivers) - <div class="row-fluid"> - <div class="span12"> - <h4> Finished Drivers </h4> - {finishedDriverTable} - </div> - </div> - } - </div>; - // scalastyle:on
+ { + if (finishedExecutors.nonEmpty) { + <h4> Finished Executors ({finishedExecutors.size}) </h4> ++ + finishedExecutorTable + } + }
+ { + if (finishedDrivers.nonEmpty) { + <h4> Finished Drivers ({finishedDrivers.size}) </h4> ++ + finishedDriverTable + } + } + </div> + </div>
; UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( workerState.host, workerState.port)) } @@ -122,6 +110,7 @@ private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { {executor.execId} {executor.cores} + {executor.state} {Utils.megabytesToString(executor.memory)} From b2ebf429e24566c29850c570f8d76943151ad78c Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 17 Jun 2014 15:09:24 -0700 Subject: [PATCH 27/57] HOTFIX: bug caused by #941 This patch should have qualified the use of PIPE. This needs to be back ported into 0.9 and 1.0. Author: Patrick Wendell Closes #1108 from pwendell/hotfix and squashes the following commits: 711c58d [Patrick Wendell] HOTFIX: bug caused by #941 --- ec2/spark_ec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 803caa0c480e7..a40311d9fcf02 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -693,7 +693,7 @@ def ssh(host, opts, command): def _check_output(*popenargs, **kwargs): if 'stdout' in kwargs: raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=PIPE, *popenargs, **kwargs) + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) output, unused_err = process.communicate() retcode = process.poll() if retcode: From d2f4f30b12f99358953e2781957468e2cfe3c916 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 17 Jun 2014 19:14:59 -0700 Subject: [PATCH 28/57] [SPARK-2060][SQL] Querying JSON Datasets with SQL and DSL in Spark SQL JIRA: https://issues.apache.org/jira/browse/SPARK-2060 Programming guide: http://yhuai.github.io/site/sql-programming-guide.html Scala doc of SQLContext: http://yhuai.github.io/site/api/scala/index.html#org.apache.spark.sql.SQLContext Author: Yin Huai Closes #999 from yhuai/newJson and squashes the following commits: 227e89e [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson ce8eedd [Yin Huai] rxin's comments. bc9ac51 [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson 94ffdaa [Yin Huai] Remove "get" from method names. ce31c81 [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson e2773a6 [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson 79ea9ba [Yin Huai] Fix typos. 5428451 [Yin Huai] Newline 1f908ce [Yin Huai] Remove extra line. d7a005c [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson 7ea750e [Yin Huai] marmbrus's comments. 6a5f5ef [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson 83013fb [Yin Huai] Update Java Example. e7a6c19 [Yin Huai] SchemaRDD.javaToPython should convert a field with the StructType to a Map. 6d20b85 [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson 4fbddf0 [Yin Huai] Programming guide. 9df8c5a [Yin Huai] Python API. 7027634 [Yin Huai] Java API. cff84cc [Yin Huai] Use a SchemaRDD for a JSON dataset. d0bd412 [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson ab810b0 [Yin Huai] Make JsonRDD private. 6df0891 [Yin Huai] Apache header. 8347f2e [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson 66f9e76 [Yin Huai] Update docs and use the entire dataset to infer the schema. 8ffed79 [Yin Huai] Update the example. a5a4b52 [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson 4325475 [Yin Huai] If a sampled dataset is used for schema inferring, update the schema of the JsonTable after first execution. 65b87f0 [Yin Huai] Fix sampling... 
8846af5 [Yin Huai] API doc. 52a2275 [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson 0387523 [Yin Huai] Address PR comments. 666b957 [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson a2313a6 [Yin Huai] Address PR comments. f3ce176 [Yin Huai] After type conflict resolution, if a NullType is found, StringType is used. 0576406 [Yin Huai] Add Apache license header. af91b23 [Yin Huai] Merge remote-tracking branch 'upstream/master' into newJson f45583b [Yin Huai] Infer the schema of a JSON dataset (a text file with one JSON object per line or a RDD[String] with one JSON object per string) and returns a SchemaRDD. f31065f [Yin Huai] A query plan or a SchemaRDD can print out its schema. --- .rat-excludes | 1 + docs/sql-programming-guide.md | 290 +++++++--- .../spark/examples/sql/JavaSparkSQL.java | 78 ++- examples/src/main/resources/people.json | 3 + project/SparkBuild.scala | 22 +- python/pyspark/sql.py | 64 ++- sql/catalyst/pom.xml | 28 + .../catalyst/analysis/HiveTypeCoercion.scala | 25 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 51 ++ .../optimizer/CombiningLimitsSuite.scala | 3 +- .../optimizer/ConstantFoldingSuite.scala | 3 +- .../optimizer/FilterPushdownSuite.scala | 5 +- ...mplifyCaseConversionExpressionsSuite.scala | 3 +- .../PlanTest.scala} | 9 +- sql/core/pom.xml | 12 + .../org/apache/spark/sql/SQLContext.scala | 45 +- .../org/apache/spark/sql/SchemaRDD.scala | 38 +- .../org/apache/spark/sql/SchemaRDDLike.scala | 6 + .../spark/sql/api/java/JavaSQLContext.scala | 20 + .../org/apache/spark/sql/json/JsonRDD.scala | 397 ++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 4 +- .../spark/sql/api/java/JavaSQLSuite.scala | 45 ++ .../org/apache/spark/sql/json/JsonSuite.scala | 519 ++++++++++++++++++ .../apache/spark/sql/json/TestJsonData.scala | 84 +++ 24 files changed, 1644 insertions(+), 111 deletions(-) create mode 100644 examples/src/main/resources/people.json rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/{optimizer/OptimizerTest.scala => plans/PlanTest.scala} (88%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala diff --git a/.rat-excludes b/.rat-excludes index 52b2dfac5cf2b..15344dfb292db 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -22,6 +22,7 @@ spark-env.sh.template log4j-defaults.properties sorttable.js .*txt +.*json .*data .*log cloudpickle.py diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4623bb4247d77..522c83884ef42 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -17,20 +17,20 @@ Spark. At the core of this component is a new type of RDD, [Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects along with a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table in a traditional relational database. A SchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io) -file, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`.
-Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using +Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using Spark. At the core of this component is a new type of RDD, [JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed [Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects along with a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io) -file, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
@@ -41,7 +41,7 @@ Spark. At the core of this component is a new type of RDD, [Row](api/python/pyspark.sql.Row-class.html) objects along with a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table in a traditional relational database. A SchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io) -file, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). All of the examples on this page use sample data included in the Spark distribution and can be run in the `pyspark` shell.
@@ -64,8 +64,8 @@ descendants. To create a basic SQLContext, all you need is a SparkContext. val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// Importing the SQL context gives access to all the public SQL functions and implicit conversions. -import sqlContext._ +// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. +import sqlContext.createSchemaRDD {% endhighlight %}
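The hunk above narrows the blanket `import sqlContext._` to the single implicit the examples actually rely on. As a quick illustration (a sketch, not part of the patch), the conversion can be obtained implicitly through `createSchemaRDD` or requested explicitly; `Person` here stands in for the case class defined later in this guide:

{% highlight scala %}
// Sketch only: what `import sqlContext.createSchemaRDD` provides.
case class Person(name: String, age: Int)  // defined later in the guide
val peopleRDD = sc.parallelize(Seq(Person("Michael", 29), Person("Andy", 30)))

// With the implicit in scope, an RDD of case classes is converted to a SchemaRDD on demand:
import sqlContext.createSchemaRDD
peopleRDD.registerAsTable("people")

// The same conversion can also be requested explicitly, without the import:
val peopleSchemaRDD = sqlContext.createSchemaRDD(peopleRDD)
peopleSchemaRDD.registerAsTable("people_explicit")
{% endhighlight %}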
@@ -77,8 +77,8 @@ The entry point into all relational functionality in Spark is the of its descendants. To create a basic JavaSQLContext, all you need is a JavaSparkContext. {% highlight java %} -JavaSparkContext ctx = ...; // An existing JavaSparkContext. -JavaSQLContext sqlCtx = new org.apache.spark.sql.api.java.JavaSQLContext(ctx); +JavaSparkContext sc = ...; // An existing JavaSparkContext. +JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); {% endhighlight %} @@ -91,14 +91,33 @@ of its descendants. To create a basic SQLContext, all you need is a SparkContext. {% highlight python %} from pyspark.sql import SQLContext -sqlCtx = SQLContext(sc) +sqlContext = SQLContext(sc) {% endhighlight %} -## Running SQL on RDDs +# Data Sources + +
+
+Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. +Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. +
+ +
+Spark SQL supports operating on a variety of data sources through the `JavaSchemaRDD` interface. +Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. +
+ +
+Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. +Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. +
+
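The three tabs above all make the same claim: once a data source has been loaded into a (Java)SchemaRDD it can be registered as a table and even joined with data from other sources. A minimal sketch of that workflow (not part of the patch; the table and column names are hypothetical, and it assumes both sources carry a `name` column):

{% highlight scala %}
// Load the same logical data set from two different sources...
val fromJson    = sqlContext.jsonFile("examples/src/main/resources/people.json")
val fromParquet = sqlContext.parquetFile("people.parquet")

// ...register each one as a table...
fromJson.registerAsTable("jsonPeople")
fromParquet.registerAsTable("parquetPeople")

// ...and join the two sources in a single SQL statement.
val joined = sqlContext.sql(
  "SELECT j.name, p.age FROM jsonPeople j JOIN parquetPeople p ON j.name = p.name")
joined.collect().foreach(println)
{% endhighlight %}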
+ +## RDDs
@@ -111,8 +130,10 @@ types such as Sequences or Arrays. This RDD can be implicitly converted to a Sch registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} +// sc is an existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -import sqlContext._ +// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. +import sqlContext.createSchemaRDD // Define the schema using a case class. // Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, @@ -124,7 +145,7 @@ val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(" people.registerAsTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. -val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -170,12 +191,11 @@ A schema can be applied to an existing RDD by calling `applySchema` and providin for the JavaBean. {% highlight java %} - -JavaSparkContext ctx = ...; // An existing JavaSparkContext. -JavaSQLContext sqlCtx = new org.apache.spark.sql.api.java.JavaSQLContext(ctx) +// sc is an existing JavaSparkContext. +JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc) // Load a text file and convert each line to a JavaBean. -JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt").map( +JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").map( new Function() { public Person call(String line) throws Exception { String[] parts = line.split(","); @@ -189,11 +209,11 @@ JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt"). }); // Apply a schema to an RDD of JavaBeans and register it as a table. -JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); +JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); schemaPeople.registerAsTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -215,6 +235,10 @@ row. Any RDD of dictionaries can converted to a SchemaRDD and then registered as can be used in subsequent SQL statements. {% highlight python %} +# sc is an existing SparkContext. +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) + # Load a text file and convert each line to a dictionary. lines = sc.textFile("examples/src/main/resources/people.txt") parts = lines.map(lambda l: l.split(",")) @@ -223,14 +247,16 @@ people = parts.map(lambda p: {"name": p[0], "age": int(p[1])}) # Infer the schema, and register the SchemaRDD as a table. # In future versions of PySpark we would like to add support for registering RDDs with other # datatypes as tables -peopleTable = sqlCtx.inferSchema(people) -peopleTable.registerAsTable("people") +schemaPeople = sqlContext.inferSchema(people) +schemaPeople.registerAsTable("people") # SQL can be run over SchemaRDDs that have been registered as a table. 
-teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") # The results of SQL queries are RDDs and support all the normal RDD operations. teenNames = teenagers.map(lambda p: "Name: " + p.name) +for teenName in teenNames.collect(): + print teenName {% endhighlight %}
@@ -241,7 +267,7 @@ teenNames = teenagers.map(lambda p: "Name: " + p.name) Users that want a more complete dialect of SQL should look at the HiveQL support provided by `HiveContext`. -## Using Parquet +## Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema @@ -252,22 +278,23 @@ of the original data. Using the data from the above example:
{% highlight scala %} -val sqlContext = new org.apache.spark.sql.SQLContext(sc) -import sqlContext._ +// sqlContext from the previous example is used in this example. +// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. +import sqlContext.createSchemaRDD val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. -// The RDD is implicitly converted to a SchemaRDD, allowing it to be stored using Parquet. +// The RDD is implicitly converted to a SchemaRDD by createSchemaRDD, allowing it to be stored using Parquet. people.saveAsParquetFile("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a Parquet file is also a JavaSchemaRDD. +// The result of loading a Parquet file is also a SchemaRDD. val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerAsTable("parquetFile") -val teenagers = sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") -teenagers.collect().foreach(println) +val teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %}
@@ -275,6 +302,7 @@ teenagers.collect().foreach(println)
{% highlight java %} +// sqlContext from the previous example is used in this example. JavaSchemaRDD schemaPeople = ... // The JavaSchemaRDD from the previous example. @@ -283,13 +311,16 @@ schemaPeople.saveAsParquetFile("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a JavaSchemaRDD. -JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); +JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerAsTable("parquetFile"); -JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); - - +JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); +List teenagerNames = teenagers.map(new Function() { + public String call(Row row) { + return "Name: " + row.getString(0); + } +}).collect(); {% endhighlight %}
@@ -297,50 +328,149 @@ JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >=
{% highlight python %} +# sqlContext from the previous example is used in this example. -peopleTable # The SchemaRDD from the previous example. +schemaPeople # The SchemaRDD from the previous example. # SchemaRDDs can be saved as Parquet files, maintaining the schema information. -peopleTable.saveAsParquetFile("people.parquet") +schemaPeople.saveAsParquetFile("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a SchemaRDD. -parquetFile = sqlCtx.parquetFile("people.parquet") +parquetFile = sqlContext.parquetFile("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerAsTable("parquetFile"); -teenagers = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") - +teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +teenNames = teenagers.map(lambda p: "Name: " + p.name) +for teenName in teenNames.collect(): + print teenName {% endhighlight %}
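The Parquet examples above assert that Parquet files are self-describing and that the schema survives the save/load round trip. A quick way to check that claim (a sketch building on the Scala example, using the `printSchema()` method introduced elsewhere in this patch; the printed tree is illustrative):

{% highlight scala %}
// Reload the file written by people.saveAsParquetFile("people.parquet") above
// and inspect the schema that was recovered from the file itself.
val reloaded = sqlContext.parquetFile("people.parquet")
reloaded.printSchema()
// root
//  |-- name: StringType
//  |-- age: IntegerType
{% endhighlight %}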
-## Writing Language-Integrated Relational Queries +## JSON Datasets +
-**Language-Integrated queries are currently only supported in Scala.** +
+ +Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. +This conversion can be done using one of two methods in a SQLContext: -Spark SQL also supports a domain specific language for writing queries. Once again, -using the data from the above examples: +* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. +* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. {% highlight scala %} +// sc is an existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -import sqlContext._ -val people: RDD[Person] = ... // An RDD of case class objects, from the first example. -// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' -val teenagers = people.where('age >= 10).where('age <= 19).select('name) +// A JSON dataset is pointed to by path. +// The path can be either a single text file or a directory storing text files. +val path = "examples/src/main/resources/people.json" +// Create a SchemaRDD from the file(s) pointed to by path +val people = sqlContext.jsonFile(path) + +// The inferred schema can be visualized using the printSchema() method. +people.printSchema() +// root +// |-- age: IntegerType +// |-- name: StringType + +// Register this SchemaRDD as a table. +people.registerAsTable("people") + +// SQL statements can be run by using the sql methods provided by sqlContext. +val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") + +// Alternatively, a SchemaRDD can be created for a JSON dataset represented by +// an RDD[String] storing one JSON object per string. +val anotherPeopleRDD = sc.parallelize( + """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) +val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) {% endhighlight %} -The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers -prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are -evaluated by the SQL execution engine. A full list of the functions supported can be found in the -[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). +
- +
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a JavaSchemaRDD. +This conversion can be done using one of two methods in a JavaSQLContext: -# Hive Support +* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. +* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. + +{% highlight java %} +// sc is an existing JavaSparkContext. +JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); + +// A JSON dataset is pointed to by path. +// The path can be either a single text file or a directory storing text files. +String path = "examples/src/main/resources/people.json"; +// Create a JavaSchemaRDD from the file(s) pointed to by path +JavaSchemaRDD people = sqlContext.jsonFile(path); + +// The inferred schema can be visualized using the printSchema() method. +people.printSchema(); +// root +// |-- age: IntegerType +// |-- name: StringType + +// Register this JavaSchemaRDD as a table. +people.registerAsTable("people"); + +// SQL statements can be run by using the sql methods provided by sqlContext. +JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + +// Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by +// an RDD[String] storing one JSON object per string. +List jsonData = Arrays.asList( + "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); +JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); +JavaSchemaRDD anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +{% endhighlight %} +
+ +
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. +This conversion can be done using one of two methods in a SQLContext: + +* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. +* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. + +{% highlight python %} +# sc is an existing SparkContext. +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) + +# A JSON dataset is pointed to by path. +# The path can be either a single text file or a directory storing text files. +path = "examples/src/main/resources/people.json" +# Create a SchemaRDD from the file(s) pointed to by path +people = sqlContext.jsonFile(path) + +# The inferred schema can be visualized using the printSchema() method. +people.printSchema() +# root +# |-- age: IntegerType +# |-- name: StringType + +# Register this SchemaRDD as a table. +people.registerAsTable("people") + +# SQL statements can be run by using the sql methods provided by sqlContext. +teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") + +# Alternatively, a SchemaRDD can be created for a JSON dataset represented by +# an RDD[String] storing one JSON object per string. +anotherPeopleRDD = sc.parallelize([ + '{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}']) +anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) +{% endhighlight %} +
+ +
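The `anotherPeopleRDD` examples above stop after creating the SchemaRDD. The JavaSparkSQL example later in this patch goes one step further and queries the nested `address` struct with a dotted column name; the same idea in Scala (a sketch mirroring that example, with the printed tree shown as an illustrative comment):

{% highlight scala %}
// Nested JSON objects are inferred as struct-typed columns.
val anotherPeopleRDD = sc.parallelize(
  """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil)
val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)

anotherPeople.printSchema()
// root
//  |-- address: StructType
//  |    |-- city: StringType
//  |    |-- state: StringType
//  |-- name: StringType

// Nested fields are addressed with dotted names in SQL.
anotherPeople.registerAsTable("people2")
val peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2")
peopleWithCity.map(row => "Name: " + row(0) + ", City: " + row(1)).collect().foreach(println)
{% endhighlight %}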
+ +## Hive Tables Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. @@ -362,17 +492,14 @@ which is similar to `HiveContext`, but creates a local copy of the `metastore` a automatically. {% highlight scala %} -val sc: SparkContext // An existing SparkContext. +// sc is an existing SparkContext. val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc) -// Importing the SQL context gives access to all the public SQL functions and implicit conversions. -import hiveContext._ - -hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL -hql("FROM src SELECT key, value").collect().foreach(println) +hiveContext.hql("FROM src SELECT key, value").collect().foreach(println) {% endhighlight %} @@ -385,14 +512,14 @@ the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allow expressed in HiveQL. {% highlight java %} -JavaSparkContext ctx = ...; // An existing JavaSparkContext. -JavaHiveContext hiveCtx = new org.apache.spark.sql.hive.api.java.HiveContext(ctx); +// sc is an existing JavaSparkContext. +JavaHiveContext hiveContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); -hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); -hiveCtx.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); +hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); +hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); // Queries are expressed in HiveQL. -Row[] results = hiveCtx.hql("FROM src SELECT key, value").collect(); +Row[] results = hiveContext.hql("FROM src SELECT key, value").collect(); {% endhighlight %} @@ -406,17 +533,44 @@ the `sql` method a `HiveContext` also provides an `hql` methods, which allows qu expressed in HiveQL. {% highlight python %} - +# sc is an existing SparkContext. from pyspark.sql import HiveContext -hiveCtx = HiveContext(sc) +hiveContext = HiveContext(sc) -hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveCtx.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hiveContext.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hiveContext.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results = hiveCtx.hql("FROM src SELECT key, value").collect() +results = hiveContext.hql("FROM src SELECT key, value").collect() {% endhighlight %} + + +# Writing Language-Integrated Relational Queries + +**Language-Integrated queries are currently only supported in Scala.** + +Spark SQL also supports a domain specific language for writing queries. Once again, +using the data from the above examples: + +{% highlight scala %} +// sc is an existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) +// Importing the SQL context gives access to all the public SQL functions and implicit conversions. +import sqlContext._ +val people: RDD[Person] = ... // An RDD of case class objects, from the first example. 
+ +// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' +val teenagers = people.where('age >= 10).where('age <= 19).select('name) +teenagers.map(t => "Name: " + t(0)).collect().foreach(println) +{% endhighlight %} + +The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers +prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are +evaluated by the SQL execution engine. A full list of the functions supported can be found in the +[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). + + \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index ad5ec84b71e69..607df3eddd550 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.sql; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import org.apache.spark.SparkConf; @@ -56,6 +57,7 @@ public static void main(String[] args) throws Exception { JavaSparkContext ctx = new JavaSparkContext(sparkConf); JavaSQLContext sqlCtx = new JavaSQLContext(ctx); + System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt").map( new Function() { @@ -84,16 +86,88 @@ public String call(Row row) { return "Name: " + row.getString(0); } }).collect(); + for (String name: teenagerNames) { + System.out.println(name); + } + System.out.println("=== Data source: Parquet File ==="); // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); - // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. + // Read in the parquet file created above. + // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a JavaSchemaRDD. JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerAsTable("parquetFile"); - JavaSchemaRDD teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + JavaSchemaRDD teenagers2 = + sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + teenagerNames = teenagers2.map(new Function() { + public String call(Row row) { + return "Name: " + row.getString(0); + } + }).collect(); + for (String name: teenagerNames) { + System.out.println(name); + } + + System.out.println("=== Data source: JSON Dataset ==="); + // A JSON dataset is pointed by path. + // The path can be either a single text file or a directory storing text files. + String path = "examples/src/main/resources/people.json"; + // Create a JavaSchemaRDD from the file(s) pointed by path + JavaSchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); + + // Because the schema of a JSON dataset is automatically inferred, to write queries, + // it is better to take a look at what is the schema. + peopleFromJsonFile.printSchema(); + // The schema of people is ... + // root + // |-- age: IntegerType + // |-- name: StringType + + // Register this JavaSchemaRDD as a table. 
+ peopleFromJsonFile.registerAsTable("people"); + + // SQL statements can be run by using the sql methods provided by sqlCtx. + JavaSchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + + // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. + // The columns of a row in the result can be accessed by ordinal. + teenagerNames = teenagers3.map(new Function() { + public String call(Row row) { return "Name: " + row.getString(0); } + }).collect(); + for (String name: teenagerNames) { + System.out.println(name); + } + + // Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by + // a RDD[String] storing one JSON object per string. + List jsonData = Arrays.asList( + "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); + JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); + JavaSchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD); + + // Take a look at the schema of this new JavaSchemaRDD. + peopleFromJsonRDD.printSchema(); + // The schema of anotherPeople is ... + // root + // |-- address: StructType + // | |-- city: StringType + // | |-- state: StringType + // |-- name: StringType + + peopleFromJsonRDD.registerAsTable("people2"); + + JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + List nameAndCity = peopleWithCity.map(new Function() { + public String call(Row row) { + return "Name: " + row.getString(0) + ", City: " + row.getString(1); + } + }).collect(); + for (String name: nameAndCity) { + System.out.println(name); + } } } diff --git a/examples/src/main/resources/people.json b/examples/src/main/resources/people.json new file mode 100644 index 0000000000000..50a859cbd7ee8 --- /dev/null +++ b/examples/src/main/resources/people.json @@ -0,0 +1,3 @@ +{"name":"Michael"} +{"name":"Andy", "age":30} +{"name":"Justin", "age":19} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2d60a44f04f6f..7bb39dc77120b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -76,7 +76,7 @@ object SparkBuild extends Build { lazy val catalyst = Project("catalyst", file("sql/catalyst"), settings = catalystSettings) dependsOn(core) - lazy val sql = Project("sql", file("sql/core"), settings = sqlCoreSettings) dependsOn(core, catalyst) + lazy val sql = Project("sql", file("sql/core"), settings = sqlCoreSettings) dependsOn(core) dependsOn(catalyst % "compile->compile;test->test") lazy val hive = Project("hive", file("sql/hive"), settings = hiveSettings) dependsOn(sql) @@ -501,9 +501,23 @@ object SparkBuild extends Build { def sqlCoreSettings = sharedSettings ++ Seq( name := "spark-sql", libraryDependencies ++= Seq( - "com.twitter" % "parquet-column" % parquetVersion, - "com.twitter" % "parquet-hadoop" % parquetVersion - ) + "com.twitter" % "parquet-column" % parquetVersion, + "com.twitter" % "parquet-hadoop" % parquetVersion, + "com.fasterxml.jackson.core" % "jackson-databind" % "2.3.0" // json4s-jackson 3.2.6 requires jackson-databind 2.3.0. 
+ ), + initialCommands in console := + """ + |import org.apache.spark.sql.catalyst.analysis._ + |import org.apache.spark.sql.catalyst.dsl._ + |import org.apache.spark.sql.catalyst.errors._ + |import org.apache.spark.sql.catalyst.expressions._ + |import org.apache.spark.sql.catalyst.plans.logical._ + |import org.apache.spark.sql.catalyst.rules._ + |import org.apache.spark.sql.catalyst.types._ + |import org.apache.spark.sql.catalyst.util._ + |import org.apache.spark.sql.execution + |import org.apache.spark.sql.test.TestSQLContext._ + |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin ) // Since we don't include hive in the main assembly this project also acts as an alternative diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index c31d49ce837fc..5051c82da32a7 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,7 +15,7 @@ # limitations under the License. # -from pyspark.rdd import RDD +from pyspark.rdd import RDD, PipelinedRDD from pyspark.serializers import BatchedSerializer, PickleSerializer from py4j.protocol import Py4JError @@ -137,6 +137,53 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path) return SchemaRDD(jschema_rdd, self) + + def jsonFile(self, path): + """Loads a text file storing one JSON object per line, + returning the result as a L{SchemaRDD}. + It goes through the entire dataset once to determine the schema. + + >>> import tempfile, shutil + >>> jsonFile = tempfile.mkdtemp() + >>> shutil.rmtree(jsonFile) + >>> ofn = open(jsonFile, 'w') + >>> for json in jsonStrings: + ... print>>ofn, json + >>> ofn.close() + >>> srdd = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2, field3 as f3 from table1") + >>> srdd2.collect() == [{"f1": 1, "f2": "row1", "f3":{"field4":11}}, + ... {"f1": 2, "f2": "row2", "f3":{"field4":22}}, + ... {"f1": 3, "f2": "row3", "f3":{"field4":33}}] + True + """ + jschema_rdd = self._ssql_ctx.jsonFile(path) + return SchemaRDD(jschema_rdd, self) + + def jsonRDD(self, rdd): + """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}. + It goes through the entire dataset once to determine the schema. + + >>> srdd = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2, field3 as f3 from table1") + >>> srdd2.collect() == [{"f1": 1, "f2": "row1", "f3":{"field4":11}}, + ... {"f1": 2, "f2": "row2", "f3":{"field4":22}}, + ... {"f1": 3, "f2": "row3", "f3":{"field4":33}}] + True + """ + def func(split, iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + yield x.encode("utf-8") + keyed = PipelinedRDD(rdd, func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._jvm.BytesToString()) + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + return SchemaRDD(jschema_rdd, self) + def sql(self, sqlQuery): """Return a L{SchemaRDD} representing the result of the given query. @@ -265,7 +312,7 @@ class SchemaRDD(RDD): For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the L{SchemaRDD} is not operated on directly, as it's underlying - implementation is a RDD composed of Java objects. Instead it is + implementation is an RDD composed of Java objects. Instead it is converted to a PythonRDD in the JVM, on which Python operations can be done. 
""" @@ -337,6 +384,14 @@ def saveAsTable(self, tableName): """Creates a new table with the contents of this SchemaRDD.""" self._jschema_rdd.saveAsTable(tableName) + def schemaString(self): + """Returns the output schema in the tree format.""" + return self._jschema_rdd.schemaString() + + def printSchema(self): + """Prints out the schema in the tree format.""" + print self.schemaString() + def count(self): """Return the number of elements in this RDD. @@ -436,6 +491,11 @@ def _test(): globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}]) + jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field2": "row2", "field3":{"field4":22}}', + '{"field1" : 3, "field2": "row3", "field3":{"field4":33}}'] + globs['jsonStrings'] = jsonStrings + globs['json'] = sc.parallelize(jsonStrings) globs['nestedRdd1'] = sc.parallelize([ {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 6c78c34486010..01d7b569080ea 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -66,6 +66,34 @@ org.scalatest scalatest-maven-plugin + + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + + + test-jar-on-compile + compile + + test-jar + + + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index d291814c8aa7c..66bff660cadc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -22,6 +22,16 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types._ +object HiveTypeCoercion { + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. + // The conversion for integral and floating point types have a linear widening hierarchy: + val numericPrecedence = + Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) + // Boolean is only wider than Void + val booleanPrecedence = Seq(NullType, BooleanType) + val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil +} + /** * A collection of [[catalyst.rules.Rule Rules]] that can be used to coerce differing types that * participate in operations into compatible ones. Most of these rules are based on Hive semantics, @@ -116,19 +126,18 @@ trait HiveTypeCoercion { * * Additionally, all types when UNION-ed with strings will be promoted to strings. * Other string conversions are handled by PromoteStrings. + * + * Widening types might result in loss of precision in the following cases: + * - IntegerType to FloatType + * - LongType to FloatType + * - LongType to DoubleType */ object WidenTypes extends Rule[LogicalPlan] { - // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. 
- // The conversion for integral and floating point types have a linear widening hierarchy: - val numericPrecedence = - Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) - // Boolean is only wider than Void - val booleanPrecedence = Seq(NullType, BooleanType) - val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { // Try and find a promotion rule that contains both types in question. - val applicableConversion = allPromotions.find(p => p.contains(t1) && p.contains(t2)) + val applicableConversion = + HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) // If found return the widest common type, otherwise None applicableConversion.map(_.filter(t => t == t1 || t == t2).last) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 8199a80f5d6bd..00e2d3bc24be9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType} abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { self: PlanType with Product => @@ -123,4 +125,53 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case other => Nil }.toSeq } + + protected def generateSchemaString(schema: Seq[Attribute]): String = { + val builder = new StringBuilder + builder.append("root\n") + val prefix = " |" + schema.foreach { attribute => + val name = attribute.name + val dataType = attribute.dataType + dataType match { + case fields: StructType => + builder.append(s"$prefix-- $name: $StructType\n") + generateSchemaString(fields, s"$prefix |", builder) + case ArrayType(fields: StructType) => + builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") + generateSchemaString(fields, s"$prefix |", builder) + case ArrayType(elementType: DataType) => + builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") + case _ => builder.append(s"$prefix-- $name: $dataType\n") + } + } + + builder.toString() + } + + protected def generateSchemaString( + schema: StructType, + prefix: String, + builder: StringBuilder): StringBuilder = { + schema.fields.foreach { + case StructField(name, fields: StructType, _) => + builder.append(s"$prefix-- $name: $StructType\n") + generateSchemaString(fields, s"$prefix |", builder) + case StructField(name, ArrayType(fields: StructType), _) => + builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") + generateSchemaString(fields, s"$prefix |", builder) + case StructField(name, ArrayType(elementType: DataType), _) => + builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") + case StructField(name, fieldType: DataType, _) => + builder.append(s"$prefix-- $name: $fieldType\n") + } + + builder + } + + /** Returns the output schema in the tree format. 
*/ + def schemaString: String = generateSchemaString(output) + + /** Prints out the schema in the tree format */ + def printSchema(): Unit = println(schemaString) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 714f01843c0f5..4896f1b955f01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -class CombiningLimitsSuite extends OptimizerTest { +class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 6efc0e211eb21..cea97c584f7e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.types._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -class ConstantFoldingSuite extends OptimizerTest { +class ConstantFoldingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 1f67c80e54906..ebb123c1f909e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.RightOuter +import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -class FilterPushdownSuite extends OptimizerTest { +class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala index df1409fe7baee..22992fb6f50d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class SimplifyCaseConversionExpressionsSuite extends OptimizerTest { +class SimplifyCaseConversionExpressionsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala similarity index 88% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 89982d5cd8d74..7e9f47ef21df8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -15,19 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.sql.catalyst.plans import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ /** - * Provides helper methods for comparing plans produced by optimization rules with the expected - * result + * Provides helper methods for comparing plans. 
*/ -class OptimizerTest extends FunSuite { +class PlanTest extends FunSuite { /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/core/pom.xml b/sql/core/pom.xml index e65ca6be485e3..8210fd1f210d1 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -43,6 +43,13 @@ spark-catalyst_${scala.binary.version} ${project.version} + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + com.twitter parquet-column @@ -53,6 +60,11 @@ parquet-hadoop ${parquet.version} + + com.fasterxml.jackson.core + jackson-databind + 2.3.0 + org.scalatest scalatest_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 131c130bbb3e8..f7e03323bed33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -22,24 +22,22 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD - import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.{ScalaReflection, dsl} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor - import org.apache.spark.sql.columnar.InMemoryRelation - import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies - +import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.SparkContext /** * :: AlphaComponent :: @@ -53,7 +51,7 @@ import org.apache.spark.sql.parquet.ParquetRelation class SQLContext(@transient val sparkContext: SparkContext) extends Logging with SQLConf - with dsl.ExpressionConversions + with ExpressionConversions with Serializable { self => @@ -98,6 +96,39 @@ class SQLContext(@transient val sparkContext: SparkContext) def parquetFile(path: String): SchemaRDD = new SchemaRDD(this, parquet.ParquetRelation(path)) + /** + * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. + * It goes through the entire dataset once to determine the schema. + * + * @group userf + */ + def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + + /** + * :: Experimental :: + */ + @Experimental + def jsonFile(path: String, samplingRatio: Double): SchemaRDD = { + val json = sparkContext.textFile(path) + jsonRDD(json, samplingRatio) + } + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[SchemaRDD]]. + * It goes through the entire dataset once to determine the schema. + * + * @group userf + */ + def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) + + /** + * :: Experimental :: + */ + @Experimental + def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = + new SchemaRDD(this, JsonRDD.inferSchema(json, samplingRatio)) + /** * :: Experimental :: * Creates an empty parquet file with the schema of class `A`, which can be registered as a table. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 89eaba2d19aa1..7c0efb4566610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.types.BooleanType +import org.apache.spark.sql.catalyst.types.{DataType, StructType, BooleanType} import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} import org.apache.spark.api.java.JavaRDD import java.util.{Map => JMap} @@ -41,8 +41,10 @@ import java.util.{Map => JMap} * whose elements are scala case classes into a SchemaRDD. This conversion can also be done * explicitly using the `createSchemaRDD` function on a [[SQLContext]]. * - * A `SchemaRDD` can also be created by loading data in from external sources, for example, - * by using the `parquetFile` method on [[SQLContext]]. + * A `SchemaRDD` can also be created by loading data in from external sources. + * Examples are loading data from Parquet files by using by using the + * `parquetFile` method on [[SQLContext]], and loading JSON datasets + * by using `jsonFile` and `jsonRDD` methods on [[SQLContext]]. * * == SQL Queries == * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once @@ -341,14 +343,38 @@ class SchemaRDD( */ def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan) + /** + * Converts a JavaRDD to a PythonRDD. It is used by pyspark. + */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldNames: Seq[String] = this.queryExecution.analyzed.output.map(_.name) + def rowToMap(row: Row, structType: StructType): JMap[String, Any] = { + val fields = structType.fields.map(field => (field.name, field.dataType)) + val map: JMap[String, Any] = new java.util.HashMap + row.zip(fields).foreach { + case (obj, (name, dataType)) => + dataType match { + case struct: StructType => map.put(name, rowToMap(obj.asInstanceOf[Row], struct)) + case other => map.put(name, obj) + } + } + + map + } + + // TODO: Actually, the schema of a row should be represented by a StructType instead of + // a Seq[Attribute]. Once we have finished that change, we can just use rowToMap to + // construct the Map for python. 
+ val fields: Seq[(String, DataType)] = this.queryExecution.analyzed.output.map( + field => (field.name, field.dataType)) this.mapPartitions { iter => val pickle = new Pickler iter.map { row => val map: JMap[String, Any] = new java.util.HashMap - row.zip(fieldNames).foreach { case (obj, name) => - map.put(name, obj) + row.zip(fields).foreach { case (obj, (name, dataType)) => + dataType match { + case struct: StructType => map.put(name, rowToMap(obj.asInstanceOf[Row], struct)) + case other => map.put(name, obj) + } } map }.grouped(10).map(batched => pickle.dumps(batched.toArray)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 656be965a8fd9..fe81721943202 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -122,4 +122,10 @@ private[sql] trait SchemaRDDLike { @Experimental def saveAsTable(tableName: String): Unit = sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd + + /** Returns the output schema in the tree format. */ + def schemaString: String = queryExecution.analyzed.schemaString + + /** Prints out the schema in the tree format. */ + def printSchema(): Unit = println(schemaString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 352260fa15bbc..ff9842267ffe0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.catalyst.types._ @@ -100,6 +101,25 @@ class JavaSQLContext(val sqlContext: SQLContext) { def parquetFile(path: String): JavaSchemaRDD = new JavaSchemaRDD(sqlContext, ParquetRelation(path)) + /** + * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. + * It goes through the entire dataset once to determine the schema. + * + * @group userf + */ + def jsonFile(path: String): JavaSchemaRDD = + jsonRDD(sqlContext.sparkContext.textFile(path)) + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[JavaSchemaRDD]]. + * It goes through the entire dataset once to determine the schema. + * + * @group userf + */ + def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = + new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(json, 1.0)) + /** * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only * during the lifetime of this instance of SQLContext. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala new file mode 100644 index 0000000000000..edf86775579d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import scala.collection.JavaConversions._ +import scala.math.BigDecimal + +import com.fasterxml.jackson.databind.ObjectMapper + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.Logging + +private[sql] object JsonRDD extends Logging { + + private[sql] def inferSchema( + json: RDD[String], + samplingRatio: Double = 1.0): LogicalPlan = { + require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") + val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) + val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) + val baseSchema = createSchema(allKeys) + + createLogicalPlan(json, baseSchema) + } + + private def createLogicalPlan( + json: RDD[String], + baseSchema: StructType): LogicalPlan = { + val schema = nullTypeToStringType(baseSchema) + + SparkLogicalPlan(ExistingRdd(asAttributes(schema), parseJson(json).map(asRow(_, schema)))) + } + + private def createSchema(allKeys: Set[(String, DataType)]): StructType = { + // Resolve type conflicts + val resolved = allKeys.groupBy { + case (key, dataType) => key + }.map { + // Now, keys and types are organized in the format of + // key -> Set(type1, type2, ...). + case (key, typeSet) => { + val fieldName = key.substring(1, key.length - 1).split("`.`").toSeq + val dataType = typeSet.map { + case (_, dataType) => dataType + }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) + + (fieldName, dataType) + } + } + + def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = { + val (topLevel, structLike) = values.partition(_.size == 1) + val topLevelFields = topLevel.filter { + name => resolved.get(prefix ++ name).get match { + case ArrayType(StructType(Nil)) => false + case ArrayType(_) => true + case struct: StructType => false + case _ => true + } + }.map { + a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true) + } + + val structFields: Seq[StructField] = structLike.groupBy(_(0)).map { + case (name, fields) => { + val nestedFields = fields.map(_.tail) + val structType = makeStruct(nestedFields, prefix :+ name) + val dataType = resolved.get(prefix :+ name).get + dataType match { + case array: ArrayType => Some(StructField(name, ArrayType(structType), nullable = true)) + case struct: StructType => Some(StructField(name, structType, nullable = true)) + // dataType is StringType means that we have resolved type conflicts involving + // primitive types and complex types. 
So, the type of name has been relaxed to + // StringType. Also, this field should have already been put in topLevelFields. + case StringType => None + } + } + }.flatMap(field => field).toSeq + + StructType( + (topLevelFields ++ structFields).sortBy { + case StructField(name, _, _) => name + }) + } + + makeStruct(resolved.keySet.toSeq, Nil) + } + + /** + * Returns the most general data type for two given data types. + */ + private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { + // Try and find a promotion rule that contains both types in question. + val applicableConversion = HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p + .contains(t2)) + + // If found return the widest common type, otherwise None + val returnType = applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + + if (returnType.isDefined) { + returnType.get + } else { + // t1 or t2 is a StructType, ArrayType, or an unexpected type. + (t1, t2) match { + case (other: DataType, NullType) => other + case (NullType, other: DataType) => other + case (StructType(fields1), StructType(fields2)) => { + val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { + case (name, fieldTypes) => { + val dataType = fieldTypes.map(field => field.dataType).reduce( + (type1: DataType, type2: DataType) => compatibleType(type1, type2)) + StructField(name, dataType, true) + } + } + StructType(newFields.toSeq.sortBy { + case StructField(name, _, _) => name + }) + } + case (ArrayType(elementType1), ArrayType(elementType2)) => + ArrayType(compatibleType(elementType1, elementType2)) + // TODO: We should use JsonObjectStringType to mark that values of field will be + // strings and every string is a Json object. + case (_, _) => StringType + } + } + } + + private def typeOfPrimitiveValue(value: Any): DataType = { + value match { + case value: java.lang.String => StringType + case value: java.lang.Integer => IntegerType + case value: java.lang.Long => LongType + // Since we do not have a data type backed by BigInteger, + // when we see a Java BigInteger, we use DecimalType. + case value: java.math.BigInteger => DecimalType + case value: java.lang.Double => DoubleType + case value: java.math.BigDecimal => DecimalType + case value: java.lang.Boolean => BooleanType + case null => NullType + // Unexpected data type. + case _ => StringType + } + } + + /** + * Returns the element type of an JSON array. We go through all elements of this array + * to detect any possible type conflict. We use [[compatibleType]] to resolve + * type conflicts. Right now, when the element of an array is another array, we + * treat the element as String. + */ + private def typeOfArray(l: Seq[Any]): ArrayType = { + val elements = l.flatMap(v => Option(v)) + if (elements.isEmpty) { + // If this JSON array is empty, we use NullType as a placeholder. + // If this array is not empty in other JSON objects, we can resolve + // the type after we have passed through all JSON objects. + ArrayType(NullType) + } else { + val elementType = elements.map { + e => e match { + case map: Map[_, _] => StructType(Nil) + // We have an array of arrays. If those element arrays do not have the same + // element types, we will return ArrayType[StringType]. 
+ case seq: Seq[_] => typeOfArray(seq) + case value => typeOfPrimitiveValue(value) + } + }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) + + ArrayType(elementType) + } + } + + /** + * Figures out all key names and data types of values from a parsed JSON object + * (in the format of Map[String, Any]). When the value of a key is a JSON object, we + * only use a placeholder (StructType(Nil)) to mark that it should be a struct + * instead of getting all fields of this struct because a field that does not appear + * in this JSON object may appear in other JSON objects. + */ + private def allKeysWithValueTypes(m: Map[String, Any]): Set[(String, DataType)] = { + m.map{ + // Quote the key with backticks to handle cases which have dots + // in the field name. + case (key, dataType) => (s"`$key`", dataType) + }.flatMap { + case (key: String, struct: Map[String, Any]) => { + // The value associated with the key is a JSON object. + allKeysWithValueTypes(struct).map { + case (k, dataType) => (s"$key.$k", dataType) + } ++ Set((key, StructType(Nil))) + } + case (key: String, array: List[Any]) => { + // The value associated with the key is an array. + typeOfArray(array) match { + case ArrayType(StructType(Nil)) => { + // The elements of this array are structs. + array.asInstanceOf[List[Map[String, Any]]].flatMap { + element => allKeysWithValueTypes(element) + }.map { + case (k, dataType) => (s"$key.$k", dataType) + } :+ (key, ArrayType(StructType(Nil))) + } + case ArrayType(elementType) => (key, ArrayType(elementType)) :: Nil + } + } + case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil + }.toSet + } + + /** + * Converts a Java Map/List to a Scala Map/List. + * We do not use Jackson's scala module here because + * DefaultScalaModule in jackson-module-scala will make + * the parsing very slow. + */ + private def scalafy(obj: Any): Any = obj match { + case map: java.util.Map[String, Object] => + // .map(identity) is used as a workaround of non-serializable Map + // generated by .mapValues. + // This issue is documented at https://issues.scala-lang.org/browse/SI-7005 + map.toMap.mapValues(scalafy).map(identity) + case list: java.util.List[Object] => + list.toList.map(scalafy) + case atom => atom + } + + private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = { + // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], + // ObjectMapper will not return BigDecimal when + // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled + // (see NumberDeserializer.deserialize for the logic). + // But, we do not want to enable this feature because it will use BigDecimal + // for every float number, which will be slow. + // So, right now, we will have Infinity for those BigDecimal numbers. + // TODO: Support BigDecimal. + json.mapPartitions(iter => { + // When there is a key appearing multiple times (a duplicate key), + // the ObjectMapper will take the last value associated with this duplicate key. + // For example: for {"key": 1, "key":2}, we will get "key"->2.
+ val mapper = new ObjectMapper() + iter.map(record => mapper.readValue(record, classOf[java.util.Map[String, Any]])) + }).map(scalafy).map(_.asInstanceOf[Map[String, Any]]) + } + + private def toLong(value: Any): Long = { + value match { + case value: java.lang.Integer => value.asInstanceOf[Int].toLong + case value: java.lang.Long => value.asInstanceOf[Long] + } + } + + private def toDouble(value: Any): Double = { + value match { + case value: java.lang.Integer => value.asInstanceOf[Int].toDouble + case value: java.lang.Long => value.asInstanceOf[Long].toDouble + case value: java.lang.Double => value.asInstanceOf[Double] + } + } + + private def toDecimal(value: Any): BigDecimal = { + value match { + case value: java.lang.Integer => BigDecimal(value) + case value: java.lang.Long => BigDecimal(value) + case value: java.math.BigInteger => BigDecimal(value) + case value: java.lang.Double => BigDecimal(value) + case value: java.math.BigDecimal => BigDecimal(value) + } + } + + private def toJsonArrayString(seq: Seq[Any]): String = { + val builder = new StringBuilder + builder.append("[") + var count = 0 + seq.foreach { + element => + if (count > 0) builder.append(",") + count += 1 + builder.append(toString(element)) + } + builder.append("]") + + builder.toString() + } + + private def toJsonObjectString(map: Map[String, Any]): String = { + val builder = new StringBuilder + builder.append("{") + var count = 0 + map.foreach { + case (key, value) => + if (count > 0) builder.append(",") + count += 1 + builder.append(s"""\"${key}\":${toString(value)}""") + } + builder.append("}") + + builder.toString() + } + + private def toString(value: Any): String = { + value match { + case value: Map[String, Any] => toJsonObjectString(value) + case value: Seq[Any] => toJsonArrayString(value) + case value => Option(value).map(_.toString).orNull + } + } + + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ + if (value == null) { + null + } else { + desiredType match { + case ArrayType(elementType) => + value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) + case StringType => toString(value) + case IntegerType => value.asInstanceOf[IntegerType.JvmType] + case LongType => toLong(value) + case DoubleType => toDouble(value) + case DecimalType => toDecimal(value) + case BooleanType => value.asInstanceOf[BooleanType.JvmType] + case NullType => null + } + } + } + + private def asRow(json: Map[String,Any], schema: StructType): Row = { + val row = new GenericMutableRow(schema.fields.length) + schema.fields.zipWithIndex.foreach { + // StructType + case (StructField(name, fields: StructType, _), i) => + row.update(i, json.get(name).flatMap(v => Option(v)).map( + v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull) + + // ArrayType(StructType) + case (StructField(name, ArrayType(structType: StructType), _), i) => + row.update(i, + json.get(name).flatMap(v => Option(v)).map( + v => v.asInstanceOf[Seq[Any]].map( + e => asRow(e.asInstanceOf[Map[String, Any]], structType))).orNull) + + // Other cases + case (StructField(name, dataType, _), i) => + row.update(i, json.get(name).flatMap(v => Option(v)).map( + enforceCorrectType(_, dataType)).getOrElse(null)) + } + + row + } + + private def nullTypeToStringType(struct: StructType): StructType = { + val fields = struct.fields.map { + case StructField(fieldName, dataType, nullable) => { + val newType = dataType match { + case NullType => StringType + case ArrayType(NullType) => ArrayType(StringType) + case struct: StructType => 
nullTypeToStringType(struct) + case other: DataType => other + } + StructField(fieldName, newType, nullable) + } + } + + StructType(fields) + } + + private def asAttributes(struct: StructType): Seq[AttributeReference] = { + struct.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) + } + + private def asStruct(attributes: Seq[AttributeReference]): StructType = { + val fields = attributes.map { + case AttributeReference(name, dataType, nullable) => StructField(name, dataType, nullable) + } + + StructType(fields) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d7f6abaf5d381..ef84ead2e6e8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql -import org.scalatest.FunSuite - import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ -class QueryTest extends FunSuite { +class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. * @param rdd the [[SchemaRDD]] to be executed diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index 9fff7222fe840..020baf0c7ec6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -22,6 +22,7 @@ import scala.beans.BeanProperty import org.scalatest.FunSuite import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.TestSQLContext // Implicits @@ -111,4 +112,48 @@ class JavaSQLSuite extends FunSuite { """.stripMargin).collect.head.row === Seq.fill(8)(null)) } + + test("loads JSON datasets") { + val jsonString = + """{"string":"this is a simple string.", + "integer":10, + "long":21474836470, + "bigInteger":92233720368547758070, + "double":1.7976931348623157E308, + "boolean":true, + "null":null + }""".replaceAll("\n", " ") + val rdd = javaCtx.parallelize(jsonString :: Nil) + + var schemaRDD = javaSqlCtx.jsonRDD(rdd) + + schemaRDD.registerAsTable("jsonTable1") + + assert( + javaSqlCtx.sql("select * from jsonTable1").collect.head.row === + Seq(BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.")) + + val file = getTempFilePath("json") + val path = file.toString + rdd.saveAsTextFile(path) + schemaRDD = javaSqlCtx.jsonFile(path) + + schemaRDD.registerAsTable("jsonTable2") + + assert( + javaSqlCtx.sql("select * from jsonTable2").collect.head.row === + Seq(BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala new file mode 100644 index 0000000000000..10bd9f08f0238 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -0,0 +1,519 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.TestSQLContext._ + +protected case class Schema(output: Seq[Attribute]) extends LeafNode + +class JsonSuite extends QueryTest { + import TestJsonData._ + TestJsonData + + test("Type promotion") { + def checkTypePromotion(expected: Any, actual: Any) { + assert(expected.getClass == actual.getClass, + s"Failed to promote ${actual.getClass} to ${expected.getClass}.") + assert(expected == actual, + s"Promoted value ${actual}(${actual.getClass}) does not equal the expected value " + + s"${expected}(${expected.getClass}).") + } + + val intNumber: Int = 2147483647 + checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType)) + checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) + checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) + checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType)) + + val longNumber: Long = 9223372036854775807L + checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) + checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) + checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType)) + + val doubleNumber: Double = 1.7976931348623157E308d + checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) + checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) + } + + test("Get compatible type") { + def checkDataType(t1: DataType, t2: DataType, expected: DataType) { + var actual = compatibleType(t1, t2) + assert(actual == expected, + s"Expected $expected as the most general data type for $t1 and $t2, found $actual") + actual = compatibleType(t2, t1) + assert(actual == expected, + s"Expected $expected as the most general data type for $t1 and $t2, found $actual") + } + + // NullType + checkDataType(NullType, BooleanType, BooleanType) + checkDataType(NullType, IntegerType, IntegerType) + checkDataType(NullType, LongType, LongType) + checkDataType(NullType, DoubleType, DoubleType) + checkDataType(NullType, DecimalType, DecimalType) + checkDataType(NullType, StringType, StringType) + checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) + checkDataType(NullType, StructType(Nil), StructType(Nil)) + checkDataType(NullType, NullType, NullType) + + // BooleanType + checkDataType(BooleanType, BooleanType, BooleanType) + checkDataType(BooleanType, IntegerType, StringType) + checkDataType(BooleanType, LongType, StringType) + 
checkDataType(BooleanType, DoubleType, StringType) + checkDataType(BooleanType, DecimalType, StringType) + checkDataType(BooleanType, StringType, StringType) + checkDataType(BooleanType, ArrayType(IntegerType), StringType) + checkDataType(BooleanType, StructType(Nil), StringType) + + // IntegerType + checkDataType(IntegerType, IntegerType, IntegerType) + checkDataType(IntegerType, LongType, LongType) + checkDataType(IntegerType, DoubleType, DoubleType) + checkDataType(IntegerType, DecimalType, DecimalType) + checkDataType(IntegerType, StringType, StringType) + checkDataType(IntegerType, ArrayType(IntegerType), StringType) + checkDataType(IntegerType, StructType(Nil), StringType) + + // LongType + checkDataType(LongType, LongType, LongType) + checkDataType(LongType, DoubleType, DoubleType) + checkDataType(LongType, DecimalType, DecimalType) + checkDataType(LongType, StringType, StringType) + checkDataType(LongType, ArrayType(IntegerType), StringType) + checkDataType(LongType, StructType(Nil), StringType) + + // DoubleType + checkDataType(DoubleType, DoubleType, DoubleType) + checkDataType(DoubleType, DecimalType, DecimalType) + checkDataType(DoubleType, StringType, StringType) + checkDataType(DoubleType, ArrayType(IntegerType), StringType) + checkDataType(DoubleType, StructType(Nil), StringType) + + // DoubleType + checkDataType(DecimalType, DecimalType, DecimalType) + checkDataType(DecimalType, StringType, StringType) + checkDataType(DecimalType, ArrayType(IntegerType), StringType) + checkDataType(DecimalType, StructType(Nil), StringType) + + // StringType + checkDataType(StringType, StringType, StringType) + checkDataType(StringType, ArrayType(IntegerType), StringType) + checkDataType(StringType, StructType(Nil), StringType) + + // ArrayType + checkDataType(ArrayType(IntegerType), ArrayType(IntegerType), ArrayType(IntegerType)) + checkDataType(ArrayType(IntegerType), ArrayType(LongType), ArrayType(LongType)) + checkDataType(ArrayType(IntegerType), ArrayType(StringType), ArrayType(StringType)) + checkDataType(ArrayType(IntegerType), StructType(Nil), StringType) + + // StructType + checkDataType(StructType(Nil), StructType(Nil), StructType(Nil)) + checkDataType( + StructType(StructField("f1", IntegerType, true) :: Nil), + StructType(StructField("f1", IntegerType, true) :: Nil), + StructType(StructField("f1", IntegerType, true) :: Nil)) + checkDataType( + StructType(StructField("f1", IntegerType, true) :: Nil), + StructType(Nil), + StructType(StructField("f1", IntegerType, true) :: Nil)) + checkDataType( + StructType( + StructField("f1", IntegerType, true) :: + StructField("f2", IntegerType, true) :: Nil), + StructType(StructField("f1", LongType, true) :: Nil) , + StructType( + StructField("f1", LongType, true) :: + StructField("f2", IntegerType, true) :: Nil)) + checkDataType( + StructType( + StructField("f1", IntegerType, true) :: Nil), + StructType( + StructField("f2", IntegerType, true) :: Nil), + StructType( + StructField("f1", IntegerType, true) :: + StructField("f2", IntegerType, true) :: Nil)) + checkDataType( + StructType( + StructField("f1", IntegerType, true) :: Nil), + DecimalType, + StringType) + } + + test("Primitive field and type inferring") { + val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) + + val expectedSchema = + AttributeReference("bigInteger", DecimalType, true)() :: + AttributeReference("boolean", BooleanType, true)() :: + AttributeReference("double", DoubleType, true)() :: + AttributeReference("integer", IntegerType, true)() :: + AttributeReference("long", 
LongType, true)() :: + AttributeReference("null", StringType, true)() :: + AttributeReference("string", StringType, true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } + + test("Complex field and type inferring") { + val jsonSchemaRDD = jsonRDD(complexFieldAndType) + + val expectedSchema = + AttributeReference("arrayOfArray1", ArrayType(ArrayType(StringType)), true)() :: + AttributeReference("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true)() :: + AttributeReference("arrayOfBigInteger", ArrayType(DecimalType), true)() :: + AttributeReference("arrayOfBoolean", ArrayType(BooleanType), true)() :: + AttributeReference("arrayOfDouble", ArrayType(DoubleType), true)() :: + AttributeReference("arrayOfInteger", ArrayType(IntegerType), true)() :: + AttributeReference("arrayOfLong", ArrayType(LongType), true)() :: + AttributeReference("arrayOfNull", ArrayType(StringType), true)() :: + AttributeReference("arrayOfString", ArrayType(StringType), true)() :: + AttributeReference("arrayOfStruct", ArrayType( + StructType(StructField("field1", BooleanType, true) :: + StructField("field2", StringType, true) :: Nil)), true)() :: + AttributeReference("struct", StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType, true) :: Nil), true)() :: + AttributeReference("structWithArrayFields", StructType( + StructField("field1", ArrayType(IntegerType), true) :: + StructField("field2", ArrayType(StringType), true) :: Nil), true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + // Access elements of a primitive array. + checkAnswer( + sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), + ("str1", "str2", null) :: Nil + ) + + // Access an array of null values. + checkAnswer( + sql("select arrayOfNull from jsonTable"), + Seq(Seq(null, null, null, null)) :: Nil + ) + + // Access elements of a BigInteger array (we use DecimalType internally). + checkAnswer( + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), + (BigDecimal("922337203685477580700"), BigDecimal("-922337203685477580800"), null) :: Nil + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), + (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), + (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil + ) + + // Access elements of an array inside a filed with the type of ArrayType(ArrayType). + checkAnswer( + sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), + ("str2", 2.1) :: Nil + ) + + // Access elements of an array of structs. + checkAnswer( + sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2] from jsonTable"), + (true :: "str1" :: Nil, false :: null :: Nil, null) :: Nil + ) + + // Access a struct and fields inside of it. 
+ checkAnswer( + sql("select struct, struct.field1, struct.field2 from jsonTable"), + ( + Seq(true, BigDecimal("92233720368547758070")), + true, + BigDecimal("92233720368547758070")) :: Nil + ) + + // Access an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), + (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil + ) + + // Access elements of an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), + (5, null) :: Nil + ) + } + + ignore("Complex field and type inferring (Ignored)") { + val jsonSchemaRDD = jsonRDD(complexFieldAndType) + jsonSchemaRDD.registerAsTable("jsonTable") + + // Right now, "field1" and "field2" are treated as aliases. We should fix it. + checkAnswer( + sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + (true, "str1") :: Nil + ) + + // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2. + // Getting all values of a specific field from an array of structs. + checkAnswer( + sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), + (Seq(true, false), Seq("str1", null)) :: Nil + ) + } + + test("Type conflict in primitive field values") { + val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) + + val expectedSchema = + AttributeReference("num_bool", StringType, true)() :: + AttributeReference("num_num_1", LongType, true)() :: + AttributeReference("num_num_2", DecimalType, true)() :: + AttributeReference("num_num_3", DoubleType, true)() :: + AttributeReference("num_str", StringType, true)() :: + AttributeReference("str_bool", StringType, true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + ("true", 11L, null, 1.1, "13.1", "str1") :: + ("12", null, BigDecimal("21474836470.9"), null, null, "true") :: + ("false", 21474836470L, BigDecimal("92233720368547758070"), 100, "str1", "false") :: + (null, 21474836570L, BigDecimal(1.1), 21474836470L, "92233720368547758070", null) :: Nil + ) + + // Number and Boolean conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_bool - 10 from jsonTable where num_bool > 11"), + 2 + ) + + // Widening to LongType + checkAnswer( + sql("select num_num_1 - 100 from jsonTable where num_num_1 > 11"), + Seq(21474836370L) :: Seq(21474836470L) :: Nil + ) + + checkAnswer( + sql("select num_num_1 - 100 from jsonTable where num_num_1 > 10"), + Seq(-89) :: Seq(21474836370L) :: Seq(21474836470L) :: Nil + ) + + // Widening to DecimalType + checkAnswer( + sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), + Seq(BigDecimal("21474836472.1")) :: Seq(BigDecimal("92233720368547758071.2")) :: Nil + ) + + // Widening to DoubleType + checkAnswer( + sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), + Seq(101.2) :: Seq(21474836471.2) :: Nil + ) + + // Number and String conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str > 14"), + 92233720368547758071.2 + ) + + // String and Boolean conflict: resolve the type as string. 
+ checkAnswer( + sql("select * from jsonTable where str_bool = 'str1'"), + ("true", 11L, null, 1.1, "13.1", "str1") :: Nil + ) + } + + ignore("Type conflict in primitive field values (Ignored)") { + val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) + jsonSchemaRDD.registerAsTable("jsonTable") + + // Right now, the analyzer does not promote strings in a boolean expreesion. + // Number and Boolean conflict: resolve the type as boolean in this query. + checkAnswer( + sql("select num_bool from jsonTable where NOT num_bool"), + false + ) + + checkAnswer( + sql("select str_bool from jsonTable where NOT str_bool"), + false + ) + + // Right now, the analyzer does not know that num_bool should be treated as a boolean. + // Number and Boolean conflict: resolve the type as boolean in this query. + checkAnswer( + sql("select num_bool from jsonTable where num_bool"), + true + ) + + checkAnswer( + sql("select str_bool from jsonTable where str_bool"), + false + ) + + // Right now, we have a parsing error. + // Number and String conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), + BigDecimal("92233720368547758061.2") + ) + + // The plan of the following DSL is + // Project [(CAST(num_str#65:4, DoubleType) + 1.2) AS num#78] + // Filter (CAST(CAST(num_str#65:4, DoubleType), DecimalType) > 92233720368547758060) + // ExistingRdd [num_bool#61,num_num_1#62L,num_num_2#63,num_num_3#64,num_str#65,str_bool#66] + // We should directly cast num_str to DecimalType and also need to do the right type promotion + // in the Project. + checkAnswer( + jsonSchemaRDD. + where('num_str > BigDecimal("92233720368547758060")). + select('num_str + 1.2 as Symbol("num")), + BigDecimal("92233720368547758061.2") + ) + + // The following test will fail. The type of num_str is StringType. + // So, to evaluate num_str + 1.2, we first need to use Cast to convert the type. + // In our test data, one value of num_str is 13.1. + // The result of (CAST(num_str#65:4, DoubleType) + 1.2) for this value is 14.299999999999999, + // which is not 14.3. + // Number and String conflict: resolve the type as number in this query. 
+ checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str > 13"), + Seq(14.3) :: Seq(92233720368547758071.2) :: Nil + ) + } + + test("Type conflict in complex field values") { + val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict) + + val expectedSchema = + AttributeReference("array", ArrayType(IntegerType), true)() :: + AttributeReference("num_struct", StringType, true)() :: + AttributeReference("str_array", StringType, true)() :: + AttributeReference("struct", StructType( + StructField("field", StringType, true) :: Nil), true)() :: + AttributeReference("struct_array", StringType, true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + (Seq(), "11", "[1,2,3]", Seq(null), "[]") :: + (null, """{"field":false}""", null, null, "{}") :: + (Seq(4, 5, 6), null, "str", Seq(null), "[7,8,9]") :: + (Seq(7), "{}","[str1,str2,33]", Seq("str"), """{"field":true}""") :: Nil + ) + } + + test("Type conflict in array elements") { + val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict) + + val expectedSchema = + AttributeReference("array", ArrayType(StringType), true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", + """{"field":str}""")) :: Nil + ) + + // Treat an element as a number. + checkAnswer( + sql("select array[0] + 1 from jsonTable"), + 2 + ) + } + + test("Handling missing fields") { + val jsonSchemaRDD = jsonRDD(missingFields) + + val expectedSchema = + AttributeReference("a", BooleanType, true)() :: + AttributeReference("b", LongType, true)() :: + AttributeReference("c", ArrayType(IntegerType), true)() :: + AttributeReference("d", StructType( + StructField("field", BooleanType, true) :: Nil), true)() :: + AttributeReference("e", StringType, true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + } + + test("Loading a JSON dataset from a text file") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + val jsonSchemaRDD = jsonFile(path) + + val expectedSchema = + AttributeReference("bigInteger", DecimalType, true)() :: + AttributeReference("boolean", BooleanType, true)() :: + AttributeReference("double", DoubleType, true)() :: + AttributeReference("integer", IntegerType, true)() :: + AttributeReference("long", LongType, true)() :: + AttributeReference("null", StringType, true)() :: + AttributeReference("string", StringType, true)() :: Nil + + comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + + jsonSchemaRDD.registerAsTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala new file mode 100644 index 0000000000000..065e04046e8a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import org.apache.spark.sql.test.TestSQLContext + +object TestJsonData { + + val primitiveFieldAndType = + TestSQLContext.sparkContext.parallelize( + """{"string":"this is a simple string.", + "integer":10, + "long":21474836470, + "bigInteger":92233720368547758070, + "double":1.7976931348623157E308, + "boolean":true, + "null":null + }""" :: Nil) + + val complexFieldAndType = + TestSQLContext.sparkContext.parallelize( + """{"struct":{"field1": true, "field2": 92233720368547758070}, + "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, + "arrayOfString":["str1", "str2"], + "arrayOfInteger":[1, 2147483647, -2147483648], + "arrayOfLong":[21474836470, 9223372036854775807, -9223372036854775808], + "arrayOfBigInteger":[922337203685477580700, -922337203685477580800], + "arrayOfDouble":[1.2, 1.7976931348623157E308, 4.9E-324, 2.2250738585072014E-308], + "arrayOfBoolean":[true, false, true], + "arrayOfNull":[null, null, null, null], + "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}], + "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], + "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] + }""" :: Nil) + + val primitiveFieldValueTypeConflict = + TestSQLContext.sparkContext.parallelize( + """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, + "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: + """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, + "num_bool":12, "num_str":null, "str_bool":true}""" :: + """{"num_num_1":21474836470, "num_num_2":92233720368547758070, "num_num_3": 100, + "num_bool":false, "num_str":"str1", "str_bool":false}""" :: + """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, + "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) + + val complexFieldValueTypeConflict = + TestSQLContext.sparkContext.parallelize( + """{"num_struct":11, "str_array":[1, 2, 3], + "array":[], "struct_array":[], "struct": {}}""" :: + """{"num_struct":{"field":false}, "str_array":null, + "array":null, "struct_array":{}, "struct": null}""" :: + """{"num_struct":null, "str_array":"str", + "array":[4, 5, 6], "struct_array":[7, 8, 9], "struct": {"field":null}}""" :: + """{"num_struct":{}, "str_array":["str1", "str2", 33], + "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) + + val arrayElementTypeConflict = + TestSQLContext.sparkContext.parallelize( + """{"array": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}]}""" :: Nil) + + val missingFields = + TestSQLContext.sparkContext.parallelize( + """{"a":true}""" :: + """{"b":21474836470}""" :: + """{"c":[33, 44]}""" :: + """{"d":{"field":true}}""" :: + """{"e":"str"}""" :: Nil) +} From 9e4b4bd0837cfc4ef1af1edcbc56290821e49e92 Mon Sep 17 00:00:00 2001 From: Patrick 
Wendell Date: Tue, 17 Jun 2014 19:34:17 -0700 Subject: [PATCH 29/57] Revert "SPARK-2038: rename "conf" parameters in the saveAsHadoop functions" This reverts commit 443f5e1bbcf9ec55e5ce6e4f738a002a47818100. This commit unfortunately would break source compatibility if users have named the hadoopConf parameter. --- .../apache/spark/rdd/PairRDDFunctions.scala | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index bff77b4ecbf27..fe36c80e0be84 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -719,9 +719,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - hadoopConf: Configuration = self.context.hadoopConfiguration) + conf: Configuration = self.context.hadoopConfiguration) { - val job = new NewAPIHadoopJob(hadoopConf) + val job = new NewAPIHadoopJob(conf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) @@ -752,25 +752,24 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - hadoopConf: JobConf = new JobConf(self.context.hadoopConfiguration), + conf: JobConf = new JobConf(self.context.hadoopConfiguration), codec: Option[Class[_ <: CompressionCodec]] = None) { - hadoopConf.setOutputKeyClass(keyClass) - hadoopConf.setOutputValueClass(valueClass) + conf.setOutputKeyClass(keyClass) + conf.setOutputValueClass(valueClass) // Doesn't work in Scala 2.9 due to what may be a generics bug // TODO: Should we uncomment this for Scala 2.10? // conf.setOutputFormat(outputFormatClass) - hadoopConf.set("mapred.output.format.class", outputFormatClass.getName) + conf.set("mapred.output.format.class", outputFormatClass.getName) for (c <- codec) { - hadoopConf.setCompressMapOutput(true) - hadoopConf.set("mapred.output.compress", "true") - hadoopConf.setMapOutputCompressorClass(c) - hadoopConf.set("mapred.output.compression.codec", c.getCanonicalName) - hadoopConf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + conf.setCompressMapOutput(true) + conf.set("mapred.output.compress", "true") + conf.setMapOutputCompressorClass(c) + conf.set("mapred.output.compression.codec", c.getCanonicalName) + conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) } - hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath(hadoopConf, - SparkHadoopWriter.createPathFromString(path, hadoopConf)) - saveAsHadoopDataset(hadoopConf) + conf.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath(conf, SparkHadoopWriter.createPathFromString(path, conf)) + saveAsHadoopDataset(conf) } /** @@ -779,8 +778,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. 
*/ - def saveAsNewAPIHadoopDataset(hadoopConf: Configuration) { - val job = new NewAPIHadoopJob(hadoopConf) + def saveAsNewAPIHadoopDataset(conf: Configuration) { + val job = new NewAPIHadoopJob(conf) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id @@ -836,10 +835,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop * MapReduce job. */ - def saveAsHadoopDataset(hadoopConf: JobConf) { - val outputFormatInstance = hadoopConf.getOutputFormat - val keyClass = hadoopConf.getOutputKeyClass - val valueClass = hadoopConf.getOutputValueClass + def saveAsHadoopDataset(conf: JobConf) { + val outputFormatInstance = conf.getOutputFormat + val keyClass = conf.getOutputKeyClass + val valueClass = conf.getOutputValueClass if (outputFormatInstance == null) { throw new SparkException("Output format class not set") } @@ -849,18 +848,18 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) if (valueClass == null) { throw new SparkException("Output value class not set") } - SparkHadoopUtil.get.addCredentials(hadoopConf) + SparkHadoopUtil.get.addCredentials(conf) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) { // FileOutputFormat ignores the filesystem parameter - val ignoredFs = FileSystem.get(hadoopConf) - hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) + val ignoredFs = FileSystem.get(conf) + conf.getOutputFormat.checkOutputSpecs(ignoredFs, conf) } - val writer = new SparkHadoopWriter(hadoopConf) + val writer = new SparkHadoopWriter(conf) writer.preSetup() def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { From 889f7b7624689444ecdb4f0ca16ef78f9bfc8430 Mon Sep 17 00:00:00 2001 From: Vadim Chekan Date: Tue, 17 Jun 2014 22:03:50 -0700 Subject: [PATCH 30/57] [STREAMING] SPARK-2009 Key not found exception when slow receiver starts I got "java.util.NoSuchElementException: key not found: 1401756085000 ms" exception when using kafka stream and 1 sec batchPeriod. Investigation showed that the reason is that ReceiverLauncher.startReceivers is asynchronous (started in a thread). https://github.com/vchekan/spark/blob/master/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala#L206 In case of slow starting receiver, such as Kafka, it easily takes more than 2sec to start. In result, no single "compute" will be called on ReceiverInputDStream before first batch job is executed and receivedBlockInfo remains empty (obviously). Batch job will cause ReceiverInputDStream.getReceivedBlockInfo call and "key not found" exception. The patch makes getReceivedBlockInfo more robust by tolerating missing values. 
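The failure mode described above boils down to indexing a mutable map with a key that has not been inserted yet. The sketch below is not part of the patch; it uses made-up stand-in names (`ReceivedBlockInfo`, `receivedBlockInfo`) to show, under those assumptions, why switching from `map(key)` to `map.get(key).getOrElse(...)` tolerates a batch that fires before a slow receiver has reported any blocks.

```
import scala.collection.mutable

object MissingKeySketch {
  // Hypothetical stand-in for the per-batch block metadata tracked for a receiver.
  case class ReceivedBlockInfo(blockId: String, numRecords: Long)

  // Keyed by batch time in epoch millis; stays empty until the receiver reports blocks.
  val receivedBlockInfo = new mutable.HashMap[Long, Array[ReceivedBlockInfo]]

  // Pre-patch behavior: apply() throws NoSuchElementException for an unseen batch time.
  def getReceivedBlockInfoStrict(time: Long): Array[ReceivedBlockInfo] =
    receivedBlockInfo(time)

  // Patched behavior: a missing batch time simply yields an empty array.
  def getReceivedBlockInfoLenient(time: Long): Array[ReceivedBlockInfo] =
    receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo])

  def main(args: Array[String]): Unit = {
    val earlyBatchTime = 1401756085000L // a batch scheduled before the receiver finished starting
    // getReceivedBlockInfoStrict(earlyBatchTime)  // java.util.NoSuchElementException: key not found
    println(getReceivedBlockInfoLenient(earlyBatchTime).length) // 0
  }
}
```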
Author: Vadim Chekan Closes #961 from vchekan/branch-1.0 and squashes the following commits: e86f82b [Vadim Chekan] Fixed indentation 4609563 [Vadim Chekan] Key not found exception: if receiver is slow to start, it is possible that getReceivedBlockInfo will be called before compute has been called (cherry picked from commit 26f6b989312a9a48a27a23ecc68702bd14032e55) Signed-off-by: Patrick Wendell --- .../apache/spark/streaming/dstream/ReceiverInputDStream.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 75cabdbf8da26..391e40924f38a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -74,7 +74,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont /** Get information on received blocks. */ private[streaming] def getReceivedBlockInfo(time: Time) = { - receivedBlockInfo(time) + receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo]) } /** From 587d32012ceeec1e80cec1878312f164cdb76ec8 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Jun 2014 10:51:32 -0700 Subject: [PATCH 31/57] [SPARK-2176][SQL] Extra unnecessary exchange operator in the result of an explain command ``` hql("explain select * from src group by key").collect().foreach(println) [ExplainCommand [plan#27:0]] [ Aggregate false, [key#25], [key#25,value#26]] [ Exchange (HashPartitioning [key#25:0], 200)] [ Exchange (HashPartitioning [key#25:0], 200)] [ Aggregate true, [key#25], [key#25]] [ HiveTableScan [key#25,value#26], (MetastoreRelation default, src, None), None] ``` There are two exchange operators. However, if we do not use explain... ``` hql("select * from src group by key") res4: org.apache.spark.sql.SchemaRDD = SchemaRDD[8] at RDD at SchemaRDD.scala:100 == Query Plan == Aggregate false, [key#8], [key#8,value#9] Exchange (HashPartitioning [key#8:0], 200) Aggregate true, [key#8], [key#8] HiveTableScan [key#8,value#9], (MetastoreRelation default, src, None), None ``` The plan is fine. The cause of this bug is explained below. When we create an `execution.ExplainCommand`, we use the `executedPlan` as the child of this `ExplainCommand`. But, this `executedPlan` is prepared for execution again when we generate the `executedPlan` for the `ExplainCommand`. Basically, `prepareForExecution` is called twice on a physical plan. Because after `prepareForExecution` we have already bounded those references (in `BoundReference`s), `AddExchange` cannot figure out we are using the same partitioning (we use `AttributeReference`s to create an `ExchangeOperator` and then those references will be changed to `BoundReference`s after `prepareForExecution` is called). So, an extra `ExchangeOperator` is inserted. I think in `CommandStrategy`, we should just use the `sparkPlan` (`sparkPlan` is the input of `prepareForExecution`) to initialize the `ExplainCommand` instead of using `executedPlan`. The link to JIRA: https://issues.apache.org/jira/browse/SPARK-2176 Author: Yin Huai Closes #1116 from yhuai/SPARK-2176 and squashes the following commits: 197c19c [Yin Huai] Use sparkPlan to initialize a Physical Explain Command instead of using executedPlan. 
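A deliberately simplified model of the mechanism described in the commit message above, using plain strings instead of real SparkPlan nodes and toy `addExchange`/`bindReferences` functions invented for illustration: once references are bound, the existing Exchange no longer looks like the partitioning the parent requires, so preparing the already-prepared plan stacks a second Exchange, while preparing the pre-preparation `sparkPlan` does not.

```
object DoubleExchangeSketch {
  // A toy physical plan: operator descriptions from top to bottom.
  type PhysicalPlan = List[String]

  // Toy AddExchange: put a hash Exchange on top unless an identical one is already there.
  def addExchange(plan: PhysicalPlan, partitioning: String): PhysicalPlan =
    if (plan.headOption.exists(_ == s"Exchange($partitioning)")) plan
    else s"Exchange($partitioning)" :: plan

  // Toy reference binding: attribute names become positional references, so the existing
  // Exchange no longer matches the partitioning the parent asks for.
  def bindReferences(plan: PhysicalPlan): PhysicalPlan =
    plan.map(_.replace("key#25", "input[0]"))

  def main(args: Array[String]): Unit = {
    val sparkPlan: PhysicalPlan = List("Aggregate(key#25)", "HiveTableScan(key#25)")

    // First preparation pass: add one Exchange, then bind references.
    val executedPlan = bindReferences(addExchange(sparkPlan, "key#25"))

    // Building EXPLAIN from executedPlan prepares it a second time; the bound Exchange
    // no longer matches, so a second Exchange is stacked on top.
    println(addExchange(executedPlan, "key#25").mkString("\n"))
    println("---")
    // Building EXPLAIN from sparkPlan keeps preparation to a single pass.
    println(addExchange(sparkPlan, "key#25").mkString("\n"))
  }
}
```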
--- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 ++ .../org/apache/spark/sql/execution/SparkStrategies.scala | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f7e03323bed33..1617ec717b2e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -307,6 +307,8 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... lazy val sparkPlan = planner(optimizedPlan).next() + // executedPlan should not be used to initialize any SparkPlan. It should be + // only used for execution. lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2233216a6ec52..70c1171148ebb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -251,8 +251,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.SetCommand(key, value) => Seq(execution.SetCommand(key, value, plan.output)(context)) case logical.ExplainCommand(child) => - val executedPlan = context.executePlan(child).executedPlan - Seq(execution.ExplainCommand(executedPlan, plan.output)(context)) + val sparkPlan = context.executePlan(child).sparkPlan + Seq(execution.ExplainCommand(sparkPlan, plan.output)(context)) case logical.CacheCommand(tableName, cache) => Seq(execution.CacheCommand(tableName, cache)(context)) case _ => Nil From 5ad5e3486aa4d13b0991de13f5f17d9897dd2753 Mon Sep 17 00:00:00 2001 From: Raymond Liu Date: Wed, 18 Jun 2014 10:57:45 -0700 Subject: [PATCH 32/57] [SPARK-2162] Double check in doGetLocal to avoid read on removed block. other wise, it will either read in vain in memory level case, or throw exception in disk level case when it believe the block is there while actually it had been removed. Author: Raymond Liu Closes #1103 from colorant/bm and squashes the following commits: daac114 [Raymond Liu] Address comments d1ea287 [Raymond Liu] Double check in doGetLocal to avoid read on removed block. --- .../main/scala/org/apache/spark/storage/BlockManager.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index f52bc7075104b..373987c122620 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -363,6 +363,13 @@ private[spark] class BlockManager( val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { + // Double check to make sure the block is still there, since removeBlock + // method also synchronizes on BlockInfo object, so the block might have + // been removed when we actually come here. + if (blockInfo.get(blockId).isEmpty) { + logDebug(s"Block $blockId had been removed") + return None + } // If another thread is writing the block, wait for it to become ready. 
if (!info.waitForReady()) { From dd96fcda0145810785b67f847f98b04ff7f0d7c3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Jun 2014 12:48:58 -0700 Subject: [PATCH 33/57] Updated the comment for SPARK-2162. A follow up on #1103 @andrewor14 Author: Reynold Xin Closes #1117 from rxin/SPARK-2162 and squashes the following commits: a4231de [Reynold Xin] Updated the comment for SPARK-2162. --- .../scala/org/apache/spark/storage/BlockManager.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 373987c122620..d2f7baf928b62 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -363,11 +363,13 @@ private[spark] class BlockManager( val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { - // Double check to make sure the block is still there, since removeBlock - // method also synchronizes on BlockInfo object, so the block might have - // been removed when we actually come here. + // Double check to make sure the block is still there. There is a small chance that the + // block has been removed by removeBlock (which also synchronizes on the blockInfo object). + // Note that this only checks metadata tracking. If user intentionally deleted the block + // on disk or from off heap storage without using removeBlock, this conditional check will + // still pass but eventually we will get an exception because we can't find the block. if (blockInfo.get(blockId).isEmpty) { - logDebug(s"Block $blockId had been removed") + logWarning(s"Block $blockId had been removed") return None } From 3870248740d83b0292ccca88a494ce19783847f0 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 18 Jun 2014 13:16:26 -0700 Subject: [PATCH 34/57] [SPARK-1466] Raise exception if pyspark Gateway process doesn't start. If the gateway process fails to start correctly (e.g., because JAVA_HOME isn't set correctly, there's no Spark jar, etc.), right now pyspark fails because of a very difficult-to-understand error, where we try to parse stdout to get the port where Spark started and there's nothing there. This commit properly catches the error and throws an exception that includes the stderr output for much easier debugging. Thanks to @shivaram and @stogers for helping to fix this issue! Author: Kay Ousterhout Closes #383 from kayousterhout/pyspark and squashes the following commits: 36dd54b [Kay Ousterhout] [SPARK-1466] Raise exception if Gateway process doesn't start. 
--- python/pyspark/java_gateway.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 91ae8263f66b8..19235d5f79f85 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -43,12 +43,19 @@ def launch_gateway(): # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) + proc = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE, preexec_fn=preexec_func) else: # preexec_fn not supported on Windows - proc = Popen(command, stdout=PIPE, stdin=PIPE) - # Determine which ephemeral port the server started on: - gateway_port = int(proc.stdout.readline()) + proc = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE) + + try: + # Determine which ephemeral port the server started on: + gateway_port = int(proc.stdout.readline()) + except: + error_code = proc.poll() + raise Exception("Launching GatewayServer failed with exit code %d: %s" % + (error_code, "".join(proc.stderr.readlines()))) + # Create a thread to echo output from the GatewayServer, which is required # for Java log output to show up: class EchoOutputThread(Thread): From 4cbeea83e086bbbb1898bf796a5e5b789bc4cc06 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Wed, 18 Jun 2014 14:56:41 -0700 Subject: [PATCH 35/57] SPARK-2158 Clean up core/stdout file from FileAppenderSuite @tdas Author: Mark Hamstra Closes #1100 from markhamstra/SPARK-2158 and squashes the following commits: ae8e069 [Mark Hamstra] Response to TD's review 2f1e201 [Mark Hamstra] Cleanup 'stdout' file within FileAppenderSuite --- .../org/apache/spark/util/FileAppenderSuite.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 53d7f5c6072e6..02e228945bbd9 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -120,7 +120,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { // on SparkConf settings. 
def testAppenderSelection[ExpectedAppender: ClassTag, ExpectedRollingPolicy]( - properties: Seq[(String, String)], expectedRollingPolicyParam: Long = -1): FileAppender = { + properties: Seq[(String, String)], expectedRollingPolicyParam: Long = -1): Unit = { // Set spark conf properties val conf = new SparkConf @@ -129,8 +129,9 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { } // Create and test file appender - val inputStream = new PipedInputStream(new PipedOutputStream()) - val appender = FileAppender(inputStream, new File("stdout"), conf) + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream) + val appender = FileAppender(testInputStream, testFile, conf) assert(appender.isInstanceOf[ExpectedAppender]) assert(appender.getClass.getSimpleName === classTag[ExpectedAppender].runtimeClass.getSimpleName) @@ -144,7 +145,8 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { } assert(policyParam === expectedRollingPolicyParam) } - appender + testOutputStream.close() + appender.awaitTermination() } import RollingFileAppender._ From 45a95f82caea55a8616141444285faf58fef128b Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Wed, 18 Jun 2014 15:01:29 -0700 Subject: [PATCH 36/57] Remove unicode operator from RDD.scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Some IDEs don’t support unicode characters in source code. Check if this breaks binary compatibility. Author: Doris Xin Closes #1119 from dorx/unicode and squashes the following commits: 05618c3 [Doris Xin] Remove unicode operator from RDD.scala --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 1633b185861b9..cebfd109d825f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -446,7 +446,7 @@ abstract class RDD[T: ClassTag]( * Return this RDD sorted by the given key function. */ def sortBy[K]( - f: (T) ⇒ K, + f: (T) => K, ascending: Boolean = true, numPartitions: Int = this.partitions.size) (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = From 5ff75c748a27bcfae71759d0e509218f0c5d0200 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 18 Jun 2014 17:52:42 -0700 Subject: [PATCH 37/57] [SPARK-2184][SQL] AddExchange isn't idempotent ...redPartitioning. Author: Michael Armbrust Closes #1122 from marmbrus/fixAddExchange and squashes the following commits: 3417537 [Michael Armbrust] Don't bind partitioning expressions as that breaks comparison with requiredPartitioning. 
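The `BindReferences.bindReference` change in the diff below widens the method with a type parameter `A <: Expression`, so callers keep the static type of what they pass in (for example, a bound `SortOrder` stays a `SortOrder`). A stripped-down sketch of why that matters, with toy expression classes rather than Catalyst's real ones:

```
object BindReferenceSketch {
  // Toy expression hierarchy standing in for Catalyst's.
  sealed trait Expression
  case class AttributeReference(name: String) extends Expression
  case class BoundReference(ordinal: Int, name: String) extends Expression
  case class SortOrder(child: Expression, ascending: Boolean) extends Expression

  // Returning plain Expression here would lose the caller's type: a bound SortOrder
  // could no longer be used where a SortOrder is required (e.g. to build a RowOrdering).
  // Parameterizing on A and casting at the end preserves it.
  def bindReference[A <: Expression](expression: A, input: Seq[AttributeReference]): A = {
    def transform(e: Expression): Expression = e match {
      case a: AttributeReference =>
        BoundReference(input.indexWhere(_.name == a.name), a.name)
      case SortOrder(child, asc) => SortOrder(transform(child), asc)
      case other => other
    }
    transform(expression).asInstanceOf[A] // same "kind of a hack, but safe" cast as the patch
  }

  def main(args: Array[String]): Unit = {
    val input = Seq(AttributeReference("key"), AttributeReference("value"))
    val bound: SortOrder =
      bindReference(SortOrder(AttributeReference("value"), ascending = true), input)
    println(bound) // SortOrder(BoundReference(1,value),true)
  }
}
```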
--- .../spark/sql/catalyst/expressions/BoundAttribute.scala | 4 ++-- .../org/apache/spark/sql/catalyst/expressions/Row.scala | 3 +++ .../scala/org/apache/spark/sql/execution/Exchange.scala | 8 ++++---- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4ebf6c4584b94..655d4a08fe93b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -68,7 +68,7 @@ class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { } object BindReferences extends Logging { - def bindReference(expression: Expression, input: Seq[Attribute]): Expression = { + def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) @@ -83,6 +83,6 @@ object BindReferences extends Logging { BoundReference(ordinal, a) } } - } + }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 77b5429bad432..74ae723686cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -208,6 +208,9 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { + def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = + this(ordering.map(BindReferences.bindReference(_, inputSchema))) + def compare(a: Row, b: Row): Int = { var i = 0 while (i < ordering.size) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index cef294167f146..05dfb85b38b02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.{SQLConf, SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{MutableProjection, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair * :: DeveloperApi :: */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind { override def outputPartitioning = newPartitioning @@ -42,7 +42,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. 
val rdd = child.execute().mapPartitions { iter => - val hashExpressions = new MutableProjection(expressions) + val hashExpressions = new MutableProjection(expressions, child.output) val mutablePair = new MutablePair[Row, Row]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } @@ -53,7 +53,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case RangePartitioning(sortingExpressions, numPartitions) => // TODO: RangePartitioner should take an Ordering. - implicit val ordering = new RowOrdering(sortingExpressions) + implicit val ordering = new RowOrdering(sortingExpressions, child.output) val rdd = child.execute().mapPartitions { iter => val mutablePair = new MutablePair[Row, Null](null, null) From 566f70f2140c1d243fe2368af60ecb390ac8ab3e Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Wed, 18 Jun 2014 22:19:06 -0700 Subject: [PATCH 38/57] Squishing a typo bug before it causes real harm in updateNumRows method in RowMatrix Author: Doris Xin Closes #1125 from dorx/updateNumRows and squashes the following commits: 8564aef [Doris Xin] Squishing a typo bug before it causes real harm --- .../org/apache/spark/mllib/linalg/distributed/RowMatrix.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 00d0b18c27a8d..1a0073c9d487e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -419,7 +419,7 @@ class RowMatrix( /** Updates or verifies the number of rows. */ private def updateNumRows(m: Long) { if (nRows <= 0) { - nRows == m + nRows = m } else { require(nRows == m, s"The number of rows $m is different from what specified or previously computed: ${nRows}.") From 640c294369f49a7602c33c7c389088aec8a316d3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Jun 2014 22:44:12 -0700 Subject: [PATCH 39/57] [SPARK-2187] Explain should not run the optimizer twice. @yhuai @marmbrus @concretevitamin Author: Reynold Xin Closes #1123 from rxin/explain and squashes the following commits: def83b0 [Reynold Xin] Update unit tests for explain. a9d3ba8 [Reynold Xin] [SPARK-2187] Explain should not run the optimizer twice. 
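A minimal, self-contained sketch of the idiom the change below adopts (ExplainSketch is a hypothetical class, not the patch's code): the expensive planning step sits behind a lazy val, so the explanation is computed at most once no matter how many times the command's result is requested.

class ExplainSketch(logicalPlan: String) {
  var timesPlanned = 0                            // counts how often planning actually runs
  lazy val sideEffectResult: Seq[String] = {
    timesPlanned += 1                             // the optimizer/planner would run here
    "Physical execution plan:" +: Seq(s"<physical plan for: $logicalPlan>")
  }
  def execute(): Seq[String] = sideEffectResult   // safe to call repeatedly
}

val cmd = new ExplainSketch("SELECT key FROM src")
cmd.execute(); cmd.execute()
assert(cmd.timesPlanned == 1)                     // planned once, not twice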
--- .../spark/sql/execution/SparkStrategies.scala | 5 ++--- .../apache/spark/sql/execution/commands.scala | 16 ++++++++++++---- .../sql/hive/execution/HiveQuerySuite.scala | 5 +---- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 70c1171148ebb..feb280d1d1411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -250,9 +250,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.SetCommand(key, value) => Seq(execution.SetCommand(key, value, plan.output)(context)) - case logical.ExplainCommand(child) => - val sparkPlan = context.executePlan(child).sparkPlan - Seq(execution.ExplainCommand(sparkPlan, plan.output)(context)) + case logical.ExplainCommand(logicalPlan) => + Seq(execution.ExplainCommand(logicalPlan, plan.output)(context)) case logical.CacheCommand(tableName, cache) => Seq(execution.CacheCommand(tableName, cache)(context)) case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 39b3246c875df..f5d0834a4993d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan trait Command { /** @@ -71,16 +72,23 @@ case class SetCommand( } /** + * An explain command for users to see how a command will be executed. + * + * Note that this command takes in a logical plan, runs the optimizer on the logical plan + * (but do NOT actually execute it). + * * :: DeveloperApi :: */ @DeveloperApi case class ExplainCommand( - child: SparkPlan, output: Seq[Attribute])( + logicalPlan: LogicalPlan, output: Seq[Attribute])( @transient context: SQLContext) - extends UnaryNode with Command { + extends LeafNode with Command { - // Actually "EXPLAIN" command doesn't cause any side effect. - override protected[sql] lazy val sideEffectResult: Seq[String] = this.toString.split("\n") + // Run through the optimizer to generate the physical plan. 
+ override protected[sql] lazy val sideEffectResult: Seq[String] = { + "Physical execution plan:" +: context.executePlan(logicalPlan).executedPlan.toString.split("\n") + } def execute(): RDD[Row] = { val explanation = sideEffectResult.map(row => new GenericRow(Array[Any](row))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fe698f0fc57b8..8b2bdd513b71f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -202,12 +202,9 @@ class HiveQuerySuite extends HiveComparisonTest { } } - private val explainCommandClassName = - classOf[execution.ExplainCommand].getSimpleName.stripSuffix("$") - def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.size > 1 && explanation.head.startsWith(explainCommandClassName) + explanation.size > 1 && explanation.head.startsWith("Physical execution plan") } test("SPARK-1704: Explain commands as a SchemaRDD") { From 67fca189c944b8f8ba222bb471e343893031bd7b Mon Sep 17 00:00:00 2001 From: WangTao Date: Wed, 18 Jun 2014 23:24:57 -0700 Subject: [PATCH 40/57] Minor fix The value "env" is never used in SparkContext.scala. Add detailed comment for method setDelaySeconds in MetadataCleaner.scala instead of the unsure one. Author: WangTao Closes #1105 from WangTaoTheTonic/master and squashes the following commits: 688358e [WangTao] Minor fix --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 - .../main/scala/org/apache/spark/util/MetadataCleaner.scala | 7 ++++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0678bdd02110e..f9476ff826a62 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -224,7 +224,6 @@ class SparkContext(config: SparkConf) extends Logging { /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration: Configuration = { - val env = SparkEnv.get val hadoopConf = SparkHadoopUtil.get.newConfiguration() // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 7ebed5105b9fd..2889e171f627e 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -91,8 +91,13 @@ private[spark] object MetadataCleaner { conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } + /** + * Set the default delay time (in seconds). + * @param conf SparkConf instance + * @param delay default delay time to set + * @param resetAll whether to reset all to default + */ def setDelaySeconds(conf: SparkConf, delay: Int, resetAll: Boolean = true) { - // override for all ? 
conf.set("spark.cleaner.ttl", delay.toString) if (resetAll) { for (cleanerType <- MetadataCleanerType.values) { From bce0897bc6b0fc9bca5444dbe3a9e75523ad7481 Mon Sep 17 00:00:00 2001 From: witgo Date: Thu, 19 Jun 2014 12:11:26 -0500 Subject: [PATCH 41/57] [SPARK-2051]In yarn.ClientBase spark.yarn.dist.* do not work Author: witgo Closes #969 from witgo/yarn_ClientBase and squashes the following commits: 8117765 [witgo] review commit 3bdbc52 [witgo] Merge branch 'master' of https://github.com/apache/spark into yarn_ClientBase 5261b6c [witgo] fix sys.props.get("SPARK_YARN_DIST_FILES") e3c1107 [witgo] update docs b6a9aa1 [witgo] merge master c8b4554 [witgo] review commit 2f48789 [witgo] Merge branch 'master' of https://github.com/apache/spark into yarn_ClientBase 8d7b82f [witgo] Merge branch 'master' of https://github.com/apache/spark into yarn_ClientBase 1048549 [witgo] remove Utils.resolveURIs 871f1db [witgo] add spark.yarn.dist.* documentation 41bce59 [witgo] review commit 35d6fa0 [witgo] move to ClientArguments 55d72fc [witgo] Merge branch 'master' of https://github.com/apache/spark into yarn_ClientBase 9cdff16 [witgo] review commit 8bc2f4b [witgo] review commit 20e667c [witgo] Merge branch 'master' into yarn_ClientBase 0961151 [witgo] merge master ce609fc [witgo] Merge branch 'master' into yarn_ClientBase 8362489 [witgo] yarn.ClientBase spark.yarn.dist.* do not work --- docs/running-on-yarn.md | 20 ++++++++++++++++--- .../spark/deploy/yarn/ClientArguments.scala | 15 ++++++++++++-- .../apache/spark/deploy/yarn/ClientBase.scala | 3 ++- .../cluster/YarnClientSchedulerBackend.scala | 4 +--- 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4243ef480ba39..fecd8f2cc2d48 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -68,15 +68,29 @@ Most of the configs are the same for Spark on YARN as for other deployment modes - spark.yarn.executor.memoryOverhead - 384 + spark.yarn.dist.archives + (none) + + Comma separated list of archives to be extracted into the working directory of each executor. + + + + spark.yarn.dist.files + (none) + + Comma-separated list of files to be placed in the working directory of each executor. + + + + spark.yarn.executor.memoryOverhead + 384 The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. spark.yarn.driver.memoryOverhead - 384 + 384 The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index fd3ef9e1fa2de..62f9b3cf5ab88 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -21,8 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.SparkConf import org.apache.spark.scheduler.InputFormatInfo -import org.apache.spark.util.IntParam -import org.apache.spark.util.MemoryParam +import org.apache.spark.util.{Utils, IntParam, MemoryParam} // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! 
@@ -45,6 +44,18 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { parseArgs(args.toList) + // env variable SPARK_YARN_DIST_ARCHIVES/SPARK_YARN_DIST_FILES set in yarn-client then + // it should default to hdfs:// + files = Option(files).getOrElse(sys.env.get("SPARK_YARN_DIST_FILES").orNull) + archives = Option(archives).getOrElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES").orNull) + + // spark.yarn.dist.archives/spark.yarn.dist.files defaults to use file:// if not specified, + // for both yarn-client and yarn-cluster + files = Option(files).getOrElse(sparkConf.getOption("spark.yarn.dist.files"). + map(p => Utils.resolveURIs(p)).orNull) + archives = Option(archives).getOrElse(sparkConf.getOption("spark.yarn.dist.archives"). + map(p => Utils.resolveURIs(p)).orNull) + private def parseArgs(inputArgs: List[String]): Unit = { val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]() val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]() diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 858bcaa95b409..8f2267599914c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -162,7 +162,7 @@ trait ClientBase extends Logging { val fs = FileSystem.get(conf) val remoteFs = originalPath.getFileSystem(conf) var newPath = originalPath - if (! compareFs(remoteFs, fs)) { + if (!compareFs(remoteFs, fs)) { newPath = new Path(dstDir, originalPath.getName()) logInfo("Uploading " + originalPath + " to " + newPath) FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf) @@ -250,6 +250,7 @@ trait ClientBase extends Logging { } } } + logInfo("Prepared Local resources " + localResources) sparkConf.set(ClientBase.CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) UserGroupInformation.getCurrentUser().addCredentials(credentials) diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 039cf4f276119..412dfe38d55eb 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -70,9 +70,7 @@ private[spark] class YarnClientSchedulerBackend( ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), - ("--name", "SPARK_YARN_APP_NAME", "spark.app.name"), - ("--files", "SPARK_YARN_DIST_FILES", "spark.yarn.dist.files"), - ("--archives", "SPARK_YARN_DIST_ARCHIVES", "spark.yarn.dist.archives")) + ("--name", "SPARK_YARN_APP_NAME", "spark.app.name")) .foreach { case (optName, envVar, sysProp) => addArg(optName, envVar, sysProp, argsArrayBuf) } logDebug("ClientArguments called with: " + argsArrayBuf) From 777c5958c4088182f9e2daba435ccb413a2f69d7 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 19 Jun 2014 14:14:03 -0700 Subject: [PATCH 42/57] [SPARK-2191][SQL] Make sure InsertIntoHiveTable doesn't execute more than once. Author: Michael Armbrust Closes #1129 from marmbrus/doubleCreateAs and squashes the following commits: 9c6d9e4 [Michael Armbrust] Fix typo. 
5128fe2 [Michael Armbrust] Make sure InsertIntoHiveTable doesn't execute each time you ask for its result. --- .../org/apache/spark/sql/hive/execution/hiveOperators.scala | 6 +++++- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala index a839231449161..240aa0df4935a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala @@ -344,12 +344,16 @@ case class InsertIntoHiveTable( writer.commitJob() } + override def execute() = result + /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. + * + * Note: this is run once and then kept to avoid double insertions. */ - def execute() = { + private lazy val result: RDD[Row] = { val childRdd = child.execute() assert(childRdd != null) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 8b2bdd513b71f..5118f4b3f99fd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -28,6 +28,12 @@ import org.apache.spark.sql.{SchemaRDD, execution, Row} */ class HiveQuerySuite extends HiveComparisonTest { + test("CREATE TABLE AS runs once") { + hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() + assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, + "Incorrect number of rows in created table") + } + createQueryTest("between", "SELECT * FROM src WHERE key Between 1 and 2") From f14b00a9c60863afda15681fbf5682247351fa39 Mon Sep 17 00:00:00 2001 From: nravi Date: Thu, 19 Jun 2014 17:11:06 -0700 Subject: [PATCH 43/57] [SPARK-2151] Recognize memory format for spark-submit int format expected for input memory parameter when spark-submit is invoked in standalone cluster mode. Make it consistent with rest of Spark. 
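As an illustration only, an extractor of roughly this shape lets the argument parser accept the human-readable sizes used elsewhere in Spark (MemoryArg below is a hypothetical stand-in, not Spark's actual MemoryParam, and the exact unit handling is an assumption):

object MemoryArg {
  // Accepts strings such as "512m", "512MB" or "30g" and yields megabytes.
  private val Pattern = """(?i)(\d+)\s*([kmgt])b?""".r
  def unapply(s: String): Option[Int] = s.trim match {
    case Pattern(num, unit) =>
      val mb = unit.toLowerCase match {
        case "k" => num.toLong / 1024
        case "m" => num.toLong
        case "g" => num.toLong * 1024
        case "t" => num.toLong * 1024 * 1024
      }
      Some(mb.toInt)
    case _ => None
  }
}

// Used the same way as in the diff below: pattern match directly in the arg parser.
def parseMemory(args: List[String]): Option[Int] = args match {
  case ("--memory" | "-m") :: MemoryArg(mb) :: _ => Some(mb)
  case _ => None
}

assert(parseMemory(List("--memory", "30g")) == Some(30 * 1024))
assert(parseMemory(List("-m", "512M")) == Some(512))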
Author: nravi Closes #1095 from nishkamravi2/master and squashes the following commits: 2b630f9 [nravi] Accept memory input as "30g", "512M" instead of an int value, to be consistent with rest of Spark 3bf8fad [nravi] Merge branch 'master' of https://github.com/apache/spark 5423a03 [nravi] Merge branch 'master' of https://github.com/apache/spark eb663ca [nravi] Merge branch 'master' of https://github.com/apache/spark df2aeb1 [nravi] Improved fix for ConcurrentModificationIssue (Spark-1097, Hadoop-10456) 6b840f0 [nravi] Undo the fix for SPARK-1758 (the problem is fixed) 5108700 [nravi] Fix in Spark for the Concurrent thread modification issue (SPARK-1097, HADOOP-10456) 681b36f [nravi] Fix for SPARK-1758: failing test org.apache.spark.JavaAPISuite.wholeTextFiles --- .../scala/org/apache/spark/deploy/ClientArguments.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 5da9615c9e9af..39150deab863c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -21,6 +21,8 @@ import scala.collection.mutable.ListBuffer import org.apache.log4j.Level +import org.apache.spark.util.MemoryParam + /** * Command-line parser for the driver client. */ @@ -51,8 +53,8 @@ private[spark] class ClientArguments(args: Array[String]) { cores = value.toInt parse(tail) - case ("--memory" | "-m") :: value :: tail => - memory = value.toInt + case ("--memory" | "-m") :: MemoryParam(value) :: tail => + memory = value parse(tail) case ("--supervise" | "-s") :: tail => From 5464e79175e2fc85e2cadf0dd7c9a45dad028326 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Jun 2014 18:24:05 -0700 Subject: [PATCH 44/57] A few minor Spark SQL Scaladoc fixes. Author: Reynold Xin Closes #1139 from rxin/sparksqldoc and squashes the following commits: c3049d8 [Reynold Xin] Fixed line length. 66dc72c [Reynold Xin] A few minor Spark SQL Scaladoc fixes. --- .../sql/catalyst/expressions/Expression.scala | 15 ++--- .../spark/sql/catalyst/plans/QueryPlan.scala | 1 - .../catalyst/plans/logical/LogicalPlan.scala | 12 ++-- .../plans/logical/basicOperators.scala | 65 ++++++++++--------- .../plans/physical/partitioning.scala | 16 ++--- .../apache/spark/sql/execution/Exchange.scala | 9 +-- 6 files changed, 57 insertions(+), 61 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3912f5f4375fd..0411ce3aefda1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -33,14 +33,11 @@ abstract class Expression extends TreeNode[Expression] { * executed. * * The following conditions are used to determine suitability for constant folding: - * - A [[expressions.Coalesce Coalesce]] is foldable if all of its children are foldable - * - A [[expressions.BinaryExpression BinaryExpression]] is foldable if its both left and right - * child are foldable - * - A [[expressions.Not Not]], [[expressions.IsNull IsNull]], or - * [[expressions.IsNotNull IsNotNull]] is foldable if its child is foldable. - * - A [[expressions.Literal]] is foldable. 
- * - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its - * child is foldable. + * - A [[Coalesce]] is foldable if all of its children are foldable + * - A [[BinaryExpression]] is foldable if its both left and right child are foldable + * - A [[Not]], [[IsNull]], or [[IsNotNull]] is foldable if its child is foldable + * - A [[Literal]] is foldable + * - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable */ def foldable: Boolean = false def nullable: Boolean @@ -58,7 +55,7 @@ abstract class Expression extends TreeNode[Expression] { lazy val resolved: Boolean = childrenResolved /** - * Returns the [[types.DataType DataType]] of the result of evaluating this expression. It is + * Returns the [[DataType]] of the result of evaluating this expression. It is * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false). */ def dataType: DataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 00e2d3bc24be9..7b82e19b2e714 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0933a31c362d8..edc37e3877c0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -41,19 +41,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it is still contains any unresolved placeholders. Implementations of LogicalPlan - * can override this (e.g. [[catalyst.analysis.UnresolvedRelation UnresolvedRelation]] should - * return `false`). + * can override this (e.g. + * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]] + * should return `false`). */ lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved /** * Returns true if all its children of this query plan have been resolved. */ - def childrenResolved = !children.exists(!_.resolved) + def childrenResolved: Boolean = !children.exists(!_.resolved) /** - * Optionally resolves the given string to a - * [[catalyst.expressions.NamedExpression NamedExpression]]. The attribute is expressed as + * Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ def resolve(name: String): Option[NamedExpression] = { @@ -93,7 +93,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { self: Product => // Leaf nodes by definition cannot reference any input attributes. 
- def references = Set.empty + override def references = Set.empty } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index b777cf4249196..3e0639867b278 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -27,7 +27,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend } /** - * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the + * Applies a [[Generator]] to a stream of input rows, combining the * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. @@ -46,32 +46,32 @@ case class Generate( child: LogicalPlan) extends UnaryNode { - protected def generatorOutput = + protected def generatorOutput: Seq[Attribute] = alias .map(a => generator.output.map(_.withQualifiers(a :: Nil))) .getOrElse(generator.output) - def output = + override def output = if (join) child.output ++ generatorOutput else generatorOutput - def references = + override def references = if (join) child.outputSet else generator.references } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = condition.references + override def output = child.output + override def references = condition.references } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - def output = left.output + override def output = left.output override lazy val resolved = childrenResolved && !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } - def references = Set.empty + override def references = Set.empty } case class Join( @@ -80,8 +80,8 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - def references = condition.map(_.references).getOrElse(Set.empty) - def output = joinType match { + override def references = condition.map(_.references).getOrElse(Set.empty) + override def output = joinType match { case LeftSemi => left.output case _ => @@ -96,9 +96,9 @@ case class InsertIntoTable( overwrite: Boolean) extends LogicalPlan { // The table being inserted into is a child for the purposes of transformations. 
- def children = table :: child :: Nil - def references = Set.empty - def output = child.output + override def children = table :: child :: Nil + override def references = Set.empty + override def output = child.output override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType @@ -109,20 +109,20 @@ case class InsertIntoCreatedTable( databaseName: Option[String], tableName: String, child: LogicalPlan) extends UnaryNode { - def references = Set.empty - def output = child.output + override def references = Set.empty + override def output = child.output } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - def references = Set.empty - def output = child.output + override def references = Set.empty + override def output = child.output } case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = order.flatMap(_.references).toSet + override def output = child.output + override def references = order.flatMap(_.references).toSet } case class Aggregate( @@ -131,18 +131,19 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - def output = aggregateExpressions.map(_.toAttribute) - def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet + override def output = aggregateExpressions.map(_.toAttribute) + override def references = + (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = limitExpr.references + override def output = child.output + override def references = limitExpr.references } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { - def output = child.output.map(_.withQualifiers(alias :: Nil)) - def references = Set.empty + override def output = child.output.map(_.withQualifiers(alias :: Nil)) + override def references = Set.empty } /** @@ -159,7 +160,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { case otherType => otherType } - val output = child.output.map { + override val output = child.output.map { case a: AttributeReference => AttributeReference( a.name.toLowerCase, @@ -170,21 +171,21 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { case other => other } - def references = Set.empty + override def references = Set.empty } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = Set.empty + override def output = child.output + override def references = Set.empty } case class Distinct(child: LogicalPlan) extends UnaryNode { - def output = child.output - def references = child.outputSet + override def output = child.output + override def references = child.outputSet } case object NoRelation extends LeafNode { - def output = Nil + override def output = Nil } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index ffb3a92f8f340..4bb022cf238af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -46,7 +46,7 @@ case object AllTuples extends 
Distribution /** * Represents data where tuples that share the same values for the `clustering` - * [[catalyst.expressions.Expression Expressions]] will be co-located. Based on the context, this + * [[Expression Expressions]] will be co-located. Based on the context, this * can mean such tuples are either co-located in the same partition or they will be contiguous * within a single partition. */ @@ -60,7 +60,7 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi /** * Represents data where tuples have been ordered according to the `ordering` - * [[catalyst.expressions.Expression Expressions]]. This is a strictly stronger guarantee than + * [[Expression Expressions]]. This is a strictly stronger guarantee than * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for * the ordering expressions are contiguous and will never be split across partitions. */ @@ -79,19 +79,17 @@ sealed trait Partitioning { val numPartitions: Int /** - * Returns true iff the guarantees made by this - * [[catalyst.plans.physical.Partitioning Partitioning]] are sufficient to satisfy - * the partitioning scheme mandated by the `required` - * [[catalyst.plans.physical.Distribution Distribution]], i.e. the current dataset does not - * need to be re-partitioned for the `required` Distribution (it is possible that tuples within - * a partition need to be reorganized). + * Returns true iff the guarantees made by this [[Partitioning]] are sufficient + * to satisfy the partitioning scheme mandated by the `required` [[Distribution]], + * i.e. the current dataset does not need to be re-partitioned for the `required` + * Distribution (it is possible that tuples within a partition need to be reorganized). */ def satisfies(required: Distribution): Boolean /** * Returns true iff all distribution guarantees made by this partitioning can also be made * for the `other` specified partitioning. - * For example, two [[catalyst.plans.physical.HashPartitioning HashPartitioning]]s are + * For example, two [[HashPartitioning HashPartitioning]]s are * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 05dfb85b38b02..f46fa0516566f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.sql.{SQLConf, SQLContext, Row} +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering} import org.apache.spark.sql.catalyst.plans.physical._ @@ -82,9 +82,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } /** - * Ensures that the [[catalyst.plans.physical.Partitioning Partitioning]] of input data meets the - * [[catalyst.plans.physical.Distribution Distribution]] requirements for each operator by inserting - * [[Exchange]] Operators where required. 
+ * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] + * of input data meets the + * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for + * each operator by inserting [[Exchange]] Operators where required. */ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. From e5514790d70b35422dba2773e43e2e382548fa56 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 19 Jun 2014 21:06:28 -0700 Subject: [PATCH 45/57] HOTFIX: SPARK-2208 local metrics tests can fail on fast machines Author: Patrick Wendell Closes #1141 from pwendell/hotfix and squashes the following commits: 83e4c79 [Patrick Wendell] HOTFIX: SPARK-2208 local metrics tests can fail on fast machines --- .../scala/org/apache/spark/scheduler/SparkListenerSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index be506e0287a16..abd7b22310f1a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -239,11 +239,14 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers checkNonZeroAvg( taskInfoMetrics.map(_._2.executorDeserializeTime), stageInfo + " executorDeserializeTime") + + /* Test is disabled (SEE SPARK-2208) if (stageInfo.rddInfos.exists(_.name == d4.name)) { checkNonZeroAvg( taskInfoMetrics.map(_._2.shuffleReadMetrics.get.fetchWaitTime), stageInfo + " fetchWaitTime") } + */ taskInfoMetrics.foreach { case (taskInfo, taskMetrics) => taskMetrics.resultSize should be > (0l) From 278ec8a203c7f1de2716d8284f9bdafa54eee1cb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Jun 2014 22:34:21 -0700 Subject: [PATCH 46/57] More minor scaladoc cleanup for Spark SQL. Author: Reynold Xin Closes #1142 from rxin/sqlclean and squashes the following commits: 67a789e [Reynold Xin] More minor scaladoc cleanup for Spark SQL. --- .../catalyst/analysis/HiveTypeCoercion.scala | 8 ++--- .../expressions/namedExpressions.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 34 +++++++++---------- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 66bff660cadc2..6d331fb501d08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -33,7 +33,7 @@ object HiveTypeCoercion { } /** - * A collection of [[catalyst.rules.Rule Rules]] that can be used to coerce differing types that + * A collection of [[Rule Rules]] that can be used to coerce differing types that * participate in operations into compatible ones. Most of these rules are based on Hive semantics, * but they do not introduce any dependencies on the hive codebase. For this reason they remain in * Catalyst until we have a more standard set of coercions. @@ -53,8 +53,8 @@ trait HiveTypeCoercion { Nil /** - * Applies any changes to [[catalyst.expressions.AttributeReference AttributeReference]] data - * types that are made by other rules to instances higher in the query tree. 
+ * Applies any changes to [[AttributeReference]] data types that are made by other rules to + * instances higher in the query tree. */ object PropagateTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -244,7 +244,7 @@ trait HiveTypeCoercion { } /** - * Casts to/from [[catalyst.types.BooleanType BooleanType]] are transformed into comparisons since + * Casts to/from [[BooleanType]] are transformed into comparisons since * the JVM does not consider Booleans to be numeric types. */ object BooleanCasts extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index a8145c37c20fa..66ae22e95b60e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -103,7 +103,7 @@ case class Alias(child: Expression, name: String) * A reference to an attribute produced by another operator in the tree. * * @param name The name of this attribute, should only be used during analysis or for debugging. - * @param dataType The [[types.DataType DataType]] of this attribute. + * @param dataType The [[DataType]] of this attribute. * @param nullable True if null is a valid value for this attribute. * @param exprId A globally unique id used to check if different AttributeReferences refer to the * same attribute. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 25a347bec0e4c..b20b5de8c46eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -95,13 +95,13 @@ object ColumnPruning extends Rule[LogicalPlan] { Project(substitutedProjection, child) // Eliminate no-op Projects - case Project(projectList, child) if(child.output == projectList) => child + case Project(projectList, child) if child.output == projectList => child } } /** - * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with - * equivalent [[catalyst.expressions.Literal Literal]] values. This rule is more specific with + * Replaces [[Expression Expressions]] that can be statically evaluated with + * equivalent [[Literal]] values. This rule is more specific with * Null value propagation from bottom to top of the expression tree. 
*/ object NullPropagation extends Rule[LogicalPlan] { @@ -110,8 +110,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) case e @ Sum(Literal(c, _)) if c == 0 => Cast(Literal(0L), e.dataType) case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType) - case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType) - case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType) + case e @ IsNull(c) if !c.nullable => Literal(false, BooleanType) + case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) @@ -154,8 +154,8 @@ object NullPropagation extends Rule[LogicalPlan] { } /** - * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with - * equivalent [[catalyst.expressions.Literal Literal]] values. + * Replaces [[Expression Expressions]] that can be statically evaluated with + * equivalent [[Literal]] values. */ object ConstantFolding extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -197,7 +197,7 @@ object BooleanSimplification extends Rule[LogicalPlan] { } /** - * Combines two adjacent [[catalyst.plans.logical.Filter Filter]] operators into one, merging the + * Combines two adjacent [[Filter]] operators into one, merging the * conditions into one conjunctive predicate. */ object CombineFilters extends Rule[LogicalPlan] { @@ -223,9 +223,8 @@ object SimplifyFilters extends Rule[LogicalPlan] { } /** - * Pushes [[catalyst.plans.logical.Filter Filter]] operators through - * [[catalyst.plans.logical.Project Project]] operators, in-lining any - * [[catalyst.expressions.Alias Aliases]] that were defined in the projection. + * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]] + * that were defined in the projection. * * This heuristic is valid assuming the expression evaluation cost is minimal. */ @@ -248,10 +247,10 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { } /** - * Pushes down [[catalyst.plans.logical.Filter Filter]] operators where the `condition` can be + * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other - * [[catalyst.plans.logical.Filter Filter]] conditions are moved into the `condition` of the - * [[catalyst.plans.logical.Join Join]]. + * [[Filter]] conditions are moved into the `condition` of the [[Join]]. + * * And also Pushes down the join filter, where the `condition` can be evaluated using only the * attributes of the left or right side of sub query when applicable. * @@ -345,8 +344,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } /** - * Removes [[catalyst.expressions.Cast Casts]] that are unnecessary because the input is already - * the correct type. + * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type. 
*/ object SimplifyCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { @@ -355,7 +353,7 @@ object SimplifyCasts extends Rule[LogicalPlan] { } /** - * Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging the + * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. */ object CombineLimits extends Rule[LogicalPlan] { @@ -366,7 +364,7 @@ object CombineLimits extends Rule[LogicalPlan] { } /** - * Removes the inner [[catalyst.expressions.CaseConversionExpression]] that are unnecessary because + * Removes the inner [[CaseConversionExpression]] that are unnecessary because * the inner conversion is overwritten by the outer one. */ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { From d3b7671c1f9c1eca956fda15fa7573649fd284b3 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 19 Jun 2014 23:39:03 -0700 Subject: [PATCH 47/57] [SQL] Improve Speed of InsertIntoHiveTable Author: Michael Armbrust Closes #1130 from marmbrus/noFunctional and squashes the following commits: ccdb68c [Michael Armbrust] Remove functional programming and Array allocations from fast path in InsertIntoHiveTable. --- .../spark/sql/hive/execution/hiveOperators.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala index 240aa0df4935a..b19579331fea0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala @@ -371,12 +371,18 @@ case class InsertIntoHiveTable( ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val outputData = new Array[Any](fieldOIs.length) iter.map { row => - // Casts Strings to HiveVarchars when necessary. - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector) - val mappedRow = row.zip(fieldOIs).map(wrap) + var i = 0 + while (i < row.length) { + // Casts Strings to HiveVarchars when necessary. + outputData(i) = wrap(row(i), fieldOIs(i)) + i += 1 + } - serializer.serialize(mappedRow.toArray, standardOI) + serializer.serialize(outputData, standardOI) } } From f397e92eb2986f4436fb9e66777fc652f91d8494 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 19 Jun 2014 23:41:38 -0700 Subject: [PATCH 48/57] [SPARK-2177][SQL] describe table result contains only one column ``` scala> hql("describe src").collect().foreach(println) [key string None ] [value string None ] ``` The result should contain 3 columns instead of one. This screws up JDBC or even the downstream consumer of the Scala/Java/Python APIs. I am providing a workaround. We handle a subset of describe commands in Spark SQL, which are defined by ... ``` DESCRIBE [EXTENDED] [db_name.]table_name ``` All other cases are treated as Hive native commands. Also, if we upgrade Hive to 0.13, we need to check the results of context.sessionState.isHiveServerQuery() to determine how to split the result. This method is introduced by https://issues.apache.org/jira/browse/HIVE-4545. We may want to set Hive to use JsonMetaDataFormatter for the output of a DDL statement (`set hive.ddl.output.format=json` introduced by https://issues.apache.org/jira/browse/HIVE-2822). 
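To make the intended shape concrete, here is a toy helper (hypothetical, not the code in this patch) that splits one line of Hive's tab-delimited describe output into the three columns a JDBC or downstream consumer would expect:

def splitDescribeLine(line: String): (String, String, String) = {
  // Hive separates the fields with tabs; padTo keeps an empty comment column if absent.
  val parts = line.split("\t", 3).map(_.trim).padTo(3, "")
  (parts(0), parts(1), parts(2))
}

assert(splitDescribeLine("key\tstring\tdefault") == ("key", "string", "default"))
assert(splitDescribeLine("value\tstring") == ("value", "string", ""))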
The link to JIRA: https://issues.apache.org/jira/browse/SPARK-2177 Author: Yin Huai Closes #1118 from yhuai/SPARK-2177 and squashes the following commits: fd2534c [Yin Huai] Merge remote-tracking branch 'upstream/master' into SPARK-2177 b9b9aa5 [Yin Huai] rxin's comments. e7c4e72 [Yin Huai] Fix unit test. 656b068 [Yin Huai] 100 characters. 6387217 [Yin Huai] Merge remote-tracking branch 'upstream/master' into SPARK-2177 8003cf3 [Yin Huai] Generate strings with the format like Hive for unit tests. 9787fff [Yin Huai] Merge remote-tracking branch 'upstream/master' into SPARK-2177 440c5af [Yin Huai] rxin's comments. f1a417e [Yin Huai] Update doc. 83adb2f [Yin Huai] Merge remote-tracking branch 'upstream/master' into SPARK-2177 366f891 [Yin Huai] Add describe command. 74bd1d4 [Yin Huai] Merge remote-tracking branch 'upstream/master' into SPARK-2177 342fdf7 [Yin Huai] Split to up to 3 parts. 725e88c [Yin Huai] Merge remote-tracking branch 'upstream/master' into SPARK-2177 bb8bbef [Yin Huai] Split every string in the result of a describe command. --- .../sql/catalyst/plans/logical/commands.scala | 16 +++ .../apache/spark/sql/execution/commands.scala | 21 ++++ .../apache/spark/sql/hive/HiveContext.scala | 5 + .../org/apache/spark/sql/hive/HiveQl.scala | 70 +++++++++--- .../spark/sql/hive/HiveStrategies.scala | 10 ++ .../sql/hive/execution/hiveOperators.scala | 60 +++++++++++ .../hive/execution/HiveComparisonTest.scala | 24 ++++- .../execution/HiveCompatibilitySuite.scala | 17 +-- .../sql/hive/execution/HiveQuerySuite.scala | 102 ++++++++++++++++-- 9 files changed, 294 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 3299e86b85941..1d5f033f0d274 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -60,3 +60,19 @@ case class ExplainCommand(plan: LogicalPlan) extends Command { * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command. */ case class CacheCommand(tableName: String, doCache: Boolean) extends Command + +/** + * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. + * @param table The table to be described. + * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. + * It is effective only when the table is a Hive table. + */ +case class DescribeCommand( + table: LogicalPlan, + isExtended: Boolean) extends Command { + override def output = Seq( + // Column names are based on Hive. 
+ BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()), + BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()), + BoundReference(2, AttributeReference("comment", StringType, nullable = false)())) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index f5d0834a4993d..acb1b0f4dc229 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -121,3 +121,24 @@ case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: override def output: Seq[Attribute] = Seq.empty } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( + @transient context: SQLContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { + Seq(("# Registered as a temporary table", null, null)) ++ + child.output.map(field => (field.name, field.dataType.toString, null)) + } + + override def execute(): RDD[Row] = { + val rows = sideEffectResult.map { + case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) + } + context.sparkContext.parallelize(rows, 1) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index cc95b7af0abf6..7695242a81601 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} +import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand /** * Starts up an instance of hive where metadata is stored locally. An in-process metadata data is @@ -291,6 +292,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * execution is simply passed back to Hive. */ def stringResult(): Seq[String] = executedPlan match { + case describeHiveTableCommand: DescribeHiveTableCommand => + // If it is a describe command for a Hive table, we want to have the output format + // be similar with Hive. + describeHiveTableCommand.hiveString case command: PhysicalCommand => command.sideEffectResult.map(_.toString) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 844673f66d103..df761b073a75a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -52,7 +52,6 @@ private[hive] case class AddFile(filePath: String) extends Command private[hive] object HiveQl { protected val nativeCommands = Seq( "TOK_DESCFUNCTION", - "TOK_DESCTABLE", "TOK_DESCDATABASE", "TOK_SHOW_TABLESTATUS", "TOK_SHOWDATABASES", @@ -120,6 +119,12 @@ private[hive] object HiveQl { "TOK_SWITCHDATABASE" ) + // Commands that we do not need to explain. 
+ protected val noExplainCommands = Seq( + "TOK_CREATETABLE", + "TOK_DESCTABLE" + ) ++ nativeCommands + /** * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations * similar to [[catalyst.trees.TreeNode]]. @@ -362,13 +367,20 @@ private[hive] object HiveQl { } } + protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { + val (db, tableName) = + tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + case Seq(tableOnly) => (None, tableOnly) + case Seq(databaseName, table) => (Some(databaseName), table) + } + + (db, tableName) + } + protected def nodeToPlan(node: Node): LogicalPlan = node match { // Just fake explain for any of the native commands. - case Token("TOK_EXPLAIN", explainArgs) if nativeCommands contains explainArgs.head.getText => - ExplainCommand(NoRelation) - // Create tables aren't native commands due to CTAS queries, but we still don't need to - // explain them. - case Token("TOK_EXPLAIN", explainArgs) if explainArgs.head.getText == "TOK_CREATETABLE" => + case Token("TOK_EXPLAIN", explainArgs) + if noExplainCommands.contains(explainArgs.head.getText) => ExplainCommand(NoRelation) case Token("TOK_EXPLAIN", explainArgs) => // Ignore FORMATTED if present. @@ -377,6 +389,39 @@ private[hive] object HiveQl { // TODO: support EXTENDED? ExplainCommand(nodeToPlan(query)) + case Token("TOK_DESCTABLE", describeArgs) => + // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + val Some(tableType) :: formatted :: extended :: pretty :: Nil = + getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) + if (formatted.isDefined || pretty.isDefined) { + // FORMATTED and PRETTY are not supported and this statement will be treated as + // a Hive native command. + NativePlaceholder + } else { + tableType match { + case Token("TOK_TABTYPE", nameParts) if nameParts.size == 1 => { + nameParts.head match { + case Token(".", dbName :: tableName :: Nil) => + // It is describing a table with the format like "describe db.table". + // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. + val (db, tableName) = extractDbNameTableName(nameParts.head) + DescribeCommand( + UnresolvedRelation(db, tableName, None), extended.isDefined) + case Token(".", dbName :: tableName :: colName :: Nil) => + // It is describing a column with the format like "describe db.table column". + NativePlaceholder + case tableName => + // It is describing a table with the format like "describe table". + DescribeCommand( + UnresolvedRelation(None, tableName.getText, None), + extended.isDefined) + } + } + // All other cases. + case _ => NativePlaceholder + } + } + case Token("TOK_CREATETABLE", children) if children.collect { case t@Token("TOK_QUERY", _) => t }.nonEmpty => // TODO: Parse other clauses. @@ -414,11 +459,8 @@ private[hive] object HiveQl { s"Unhandled clauses: ${notImplemented.flatten.map(dumpTree(_)).mkString("\n")}") } - val (db, tableName) = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { - case Seq(tableOnly) => (None, tableOnly) - case Seq(databaseName, table) => (Some(databaseName), table) - } + val (db, tableName) = extractDbNameTableName(tableNameParts) + InsertIntoCreatedTable(db, tableName, nodeToPlan(query)) // If its not a "CREATE TABLE AS" like above then just pass it back to hive as a native command. 
@@ -736,11 +778,7 @@ private[hive] object HiveQl { val Some(tableNameParts) :: partitionClause :: Nil = getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - val (db, tableName) = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { - case Seq(tableOnly) => (None, tableOnly) - case Seq(databaseName, table) => (Some(databaseName), table) - } + val (db, tableName) = extractDbNameTableName(tableNameParts) val partitionKeys = partitionClause.map(_.getChildren.map { // Parse partitions. We also make keys case insensitive. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 0ac0ee9071f36..af7687b40429b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -81,6 +81,16 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.NativeCommand(sql) => NativeCommand(sql, plan.output)(context) :: Nil + case describe: logical.DescribeCommand => { + val resolvedTable = context.executePlan(describe.table).analyzed + resolvedTable match { + case t: MetastoreRelation => + Seq(DescribeHiveTableCommand( + t, describe.output, describe.isExtended)(context)) + case o: LogicalPlan => + Seq(DescribeCommand(planLater(o), describe.output)(context)) + } + } case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala index b19579331fea0..2de2db28a7e04 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.hive.execution import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive} +import org.apache.hadoop.hive.ql.metadata.formatting.MetaDataFormatUtils import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption @@ -462,3 +464,61 @@ case class NativeCommand( override def otherCopyArgs = context :: Nil } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class DescribeHiveTableCommand( + table: MetastoreRelation, + output: Seq[Attribute], + isExtended: Boolean)( + @transient context: HiveContext) + extends LeafNode with Command { + + // Strings with the format like Hive. It is used for result comparison in our unit tests. + lazy val hiveString: Seq[String] = { + val alignment = 20 + val delim = "\t" + + sideEffectResult.map { + case (name, dataType, comment) => + String.format("%-" + alignment + "s", name) + delim + + String.format("%-" + alignment + "s", dataType) + delim + + String.format("%-" + alignment + "s", Option(comment).getOrElse("None")) + } + } + + override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { + // Trying to mimic the format of Hive's output. But not exactly the same. 
+ var results: Seq[(String, String, String)] = Nil + + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols + results ++= columns.map(field => (field.getName, field.getType, field.getComment)) + if (!partitionColumns.isEmpty) { + val partColumnInfo = + partitionColumns.map(field => (field.getName, field.getType, field.getComment)) + results ++= + partColumnInfo ++ + Seq(("# Partition Information", "", "")) ++ + Seq((s"# ${output.get(0).name}", output.get(1).name, output.get(2).name)) ++ + partColumnInfo + } + + if (isExtended) { + results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) + } + + results + } + + override def execute(): RDD[Row] = { + val rows = sideEffectResult.map { + case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) + } + context.sparkContext.parallelize(rows, 1) + } + + override def otherCopyArgs = context :: Nil +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 24c929ff7430d..08ef4d9b6bb93 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -144,6 +144,12 @@ abstract class HiveComparisonTest case _: SetCommand => Seq("0") case _: LogicalNativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") case _: ExplainCommand => answer + case _: DescribeCommand => + // Filter out non-deterministic lines and lines which do not have actual results but + // can introduce problems because of the way Hive formats these lines. + // Then, remove empty lines. Do not sort the results. + answer.filterNot( + r => nonDeterministicLine(r) || ignoredLine(r)).map(_.trim).filterNot(_ == "") case plan => if (isSorted(plan)) answer else answer.sorted } orderedAnswer.map(cleanPaths) @@ -169,6 +175,16 @@ abstract class HiveComparisonTest protected def nonDeterministicLine(line: String) = nonDeterministicLineIndicators.exists(line contains _) + // This list contains indicators for those lines which do not have actual results and we + // want to ignore. + lazy val ignoredLineIndicators = Seq( + "# Partition Information", + "# col_name" + ) + + protected def ignoredLine(line: String) = + ignoredLineIndicators.exists(line contains _) + /** * Removes non-deterministic paths from `str` so cached answers will compare correctly. 
*/ @@ -329,11 +345,17 @@ abstract class HiveComparisonTest if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && preparedHive != catalyst) { - val hivePrintOut = s"== HIVE - ${hive.size} row(s) ==" +: preparedHive + val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive val catalystPrintOut = s"== CATALYST - ${catalyst.size} row(s) ==" +: catalyst val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") + println("hive output") + hive.foreach(println) + + println("catalyst printout") + catalyst.foreach(println) + if (recomputeCache) { logger.warn(s"Clearing cache files for failed test $testCaseName") hiveCacheFiles.foreach(_.delete()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ee194dbcb77b2..cdfc2d0c17384 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -78,7 +78,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_merge", "alter_concatenate_indexed_table", "protectmode2", - "describe_table", + //"describe_table", "describe_comment_nonascii", "udf5", "udf_java_method", @@ -177,7 +177,16 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // After stop taking the `stringOrError` route, exceptions are thrown from these cases. // See SPARK-2129 for details. "join_view", - "mergejoins_mixed" + "mergejoins_mixed", + + // Returning the result of a describe state as a JSON object is not supported. + "describe_table_json", + "describe_database_json", + "describe_formatted_view_partitioned_json", + + // Hive returns the results of describe as plain text. Comments with multiple lines + // introduce extra lines in the Hive results, which make the result comparison fail. + "describe_comment_indent" ) /** @@ -292,11 +301,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "default_partition_name", "delimiter", "desc_non_existent_tbl", - "describe_comment_indent", - "describe_database_json", "describe_formatted_view_partitioned", - "describe_formatted_view_partitioned_json", - "describe_table_json", "diff_part_input_formats", "disable_file_format_check", "drop_function", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 5118f4b3f99fd..9f5cf282f7c48 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -21,7 +21,9 @@ import scala.util.Try import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SchemaRDD, execution, Row} +import org.apache.spark.sql.{SchemaRDD, Row} + +case class TestData(a: Int, b: String) /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. 
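The `DescribeCommand` branch added to `HiveComparisonTest` above boils down to dropping indicator lines, trimming what remains, and removing blanks while preserving order. A minimal sketch of that filtering, with hypothetical sample lines (the real code additionally applies the existing `nonDeterministicLine` filter):

```scala
// Sketch of the DESCRIBE answer normalization: drop "# ..." indicator lines,
// trim the rest, drop empty lines, and keep the original (unsorted) order.
object DescribeAnswerFilterSketch {
  val ignoredLineIndicators = Seq("# Partition Information", "# col_name")

  def prepare(answer: Seq[String]): Seq[String] =
    answer
      .filterNot(line => ignoredLineIndicators.exists(ind => line.contains(ind)))
      .map(_.trim)
      .filterNot(_ == "")

  def main(args: Array[String]): Unit = {
    val raw = Seq("key\tint\tNone", "# col_name\tdata_type\tcomment", "", "dt\tstring\tNone")
    prepare(raw).foreach(println) // prints only the key and dt rows
  }
}
```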
@@ -240,13 +242,6 @@ class HiveQuerySuite extends HiveComparisonTest { .map(_.getString(0)) .contains(tableName)) - assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - hql(s"DESCRIBE $tableName") - .select('result) - .collect() - .map(_.getString(0).split("\t").map(_.trim)) - } - assert(isExplanation(hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) TestHive.reset() @@ -263,6 +258,97 @@ class HiveQuerySuite extends HiveComparisonTest { assert(Try(q0.count()).isSuccess) } + test("DESCRIBE commands") { + hql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") + + hql( + """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') + |SELECT key, value + """.stripMargin) + + // Describe a table + assertResult( + Array( + Array("key", "int", null), + Array("value", "string", null), + Array("dt", "string", null), + Array("# Partition Information", "", ""), + Array("# col_name", "data_type", "comment"), + Array("dt", "string", null)) + ) { + hql("DESCRIBE test_describe_commands1") + .select('col_name, 'data_type, 'comment) + .collect() + } + + // Describe a table with a fully qualified table name + assertResult( + Array( + Array("key", "int", null), + Array("value", "string", null), + Array("dt", "string", null), + Array("# Partition Information", "", ""), + Array("# col_name", "data_type", "comment"), + Array("dt", "string", null)) + ) { + hql("DESCRIBE default.test_describe_commands1") + .select('col_name, 'data_type, 'comment) + .collect() + } + + // Describe a column is a native command + assertResult(Array(Array("value", "string", "from deserializer"))) { + hql("DESCRIBE test_describe_commands1 value") + .select('result) + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + } + + // Describe a column is a native command + assertResult(Array(Array("value", "string", "from deserializer"))) { + hql("DESCRIBE default.test_describe_commands1 value") + .select('result) + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + } + + // Describe a partition is a native command + assertResult( + Array( + Array("key", "int", "None"), + Array("value", "string", "None"), + Array("dt", "string", "None"), + Array("", "", ""), + Array("# Partition Information", "", ""), + Array("# col_name", "data_type", "comment"), + Array("", "", ""), + Array("dt", "string", "None")) + ) { + hql("DESCRIBE test_describe_commands1 PARTITION (dt='2008-06-08')") + .select('result) + .collect() + .map(_.getString(0).split("\t").map(_.trim)) + } + + // Describe a registered temporary table. + val testData: SchemaRDD = + TestHive.sparkContext.parallelize( + TestData(1, "str1") :: + TestData(1, "str2") :: Nil) + testData.registerAsTable("test_describe_commands2") + + assertResult( + Array( + Array("# Registered as a temporary table", null, null), + Array("a", "IntegerType", null), + Array("b", "StringType", null)) + ) { + hql("DESCRIBE test_describe_commands2") + .select('col_name, 'data_type, 'comment) + .collect() + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" From f479cf3743e416ee08e62806e1b34aff5998ac22 Mon Sep 17 00:00:00 2001 From: Andre Schumacher Date: Thu, 19 Jun 2014 23:47:45 -0700 Subject: [PATCH 49/57] SPARK-1293 [SQL] Parquet support for nested types It should be possible to import and export data stored in Parquet's columnar format that contains nested types. 
For example:

```java
message AddressBook {
  required binary owner;
  optional group ownerPhoneNumbers {
    repeated binary array;
  }
  optional group contacts {
    repeated group array {
      required binary name;
      optional binary phoneNumber;
    }
  }
  optional group nameToApartmentNumber {
    repeated group map {
      required binary key;
      required int32 value;
    }
  }
}
```

The example could model a type (AddressBook) that contains records made of strings (owner), lists (ownerPhoneNumbers), and a table of contacts (e.g., a list of pairs, or a map that can contain null values but whose keys must not be null). The list of tasks is as follows:
Implement support for converting nested Parquet types to Spark/Catalyst types:
- [x] Structs
- [x] Lists
- [x] Maps (note: currently keys need to be Strings)
Implement import (via ``parquetFile``) of nested Parquet types (first version in this PR)
- [x] Initial version
Implement export (via ``saveAsParquetFile``)
- [x] Initial version
Test support for AvroParquet, etc.
- [x] Initial testing of import of avro-generated Parquet data (simple + nested) Example: ```scala val data = TestSQLContext .parquetFile("input.dir") .toSchemaRDD data.registerAsTable("data") sql("SELECT owner, contacts[1].name, nameToApartmentNumber['John'] FROM data").collect() ``` Author: Andre Schumacher Author: Michael Armbrust Closes #360 from AndreSchumacher/nested_parquet and squashes the following commits: 30708c8 [Andre Schumacher] Taking out AvroParquet test for now to remove Avro dependency 95c1367 [Andre Schumacher] Changes to ParquetRelation and its metadata 7eceb67 [Andre Schumacher] Review feedback 94eea3a [Andre Schumacher] Scalastyle 403061f [Andre Schumacher] Fixing some issues with tests and schema metadata b8a8b9a [Andre Schumacher] More fixes to short and byte conversion 63d1b57 [Andre Schumacher] Cleaning up and Scalastyle 88e6bdb [Andre Schumacher] Attempting to fix loss of schema 37e0a0a [Andre Schumacher] Cleaning up 14c3fd8 [Andre Schumacher] Attempting to fix Spark-Parquet schema conversion 3e1456c [Michael Armbrust] WIP: Directly serialize catalyst attributes. f7aeba3 [Michael Armbrust] [SPARK-1982] Support for ByteType and ShortType. 3104886 [Michael Armbrust] Nested Rows should be Rows, not Seqs. 3c6b25f [Andre Schumacher] Trying to reduce no-op changes wrt master 31465d6 [Andre Schumacher] Scalastyle: fixing commented out bottom de02538 [Andre Schumacher] Cleaning up ParquetTestData 2f5a805 [Andre Schumacher] Removing stripMargin from test schemas 191bc0d [Andre Schumacher] Changing to Seq for ArrayType, refactoring SQLParser for nested field extension cbb5793 [Andre Schumacher] Code review feedback 32229c7 [Andre Schumacher] Removing Row nested values and placing by generic types 0ae9376 [Andre Schumacher] Doc strings and simplifying ParquetConverter.scala a6b4f05 [Andre Schumacher] Cleaning up ArrayConverter, moving classTag to NativeType, adding NativeRow 431f00f [Andre Schumacher] Fixing problems introduced during rebase c52ff2c [Andre Schumacher] Adding native-array converter 619c397 [Andre Schumacher] Completing Map testcase 79d81d5 [Andre Schumacher] Replacing field names for array and map in WriteSupport f466ff0 [Andre Schumacher] Added ParquetAvro tests and revised Array conversion adc1258 [Andre Schumacher] Optimizing imports e99cc51 [Andre Schumacher] Fixing nested WriteSupport and adding tests 1dc5ac9 [Andre Schumacher] First version of WriteSupport for nested types d1911dc [Andre Schumacher] Simplifying ArrayType conversion f777b4b [Andre Schumacher] Scalastyle 824500c [Andre Schumacher] Adding attribute resolution for MapType b539fde [Andre Schumacher] First commit for MapType a594aed [Andre Schumacher] Scalastyle 4e25fcb [Andre Schumacher] Adding resolution of complex ArrayTypes f8f8911 [Andre Schumacher] For primitive rows fall back to more efficient converter, code reorg 6dbc9b7 [Andre Schumacher] Fixing some problems intruduced during rebase b7fcc35 [Andre Schumacher] Documenting conversions, bugfix, wrappers of Rows ee70125 [Andre Schumacher] fixing one problem with arrayconverter 98219cf [Andre Schumacher] added struct converter 5d80461 [Andre Schumacher] fixing one problem with nested structs and breaking up files 1b1b3d6 [Andre Schumacher] Fixing one problem with nested arrays ddb40d2 [Andre Schumacher] Extending tests for nested Parquet data 745a42b [Andre Schumacher] Completing testcase for nested data (Addressbook( 6125c75 [Andre Schumacher] First working nested Parquet record input 4d4892a [Andre Schumacher] First commit nested 
Parquet read converters aa688fe [Andre Schumacher] Adding conversion of nested Parquet schemas --- .../apache/spark/sql/catalyst/SqlParser.scala | 111 +-- .../catalyst/expressions/complexTypes.scala | 2 + .../spark/sql/catalyst/types/dataTypes.scala | 98 ++- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/api/java/JavaSQLContext.scala | 4 +- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../spark/sql/parquet/ParquetConverter.scala | 667 ++++++++++++++++++ .../spark/sql/parquet/ParquetRelation.scala | 182 +---- .../sql/parquet/ParquetTableOperations.scala | 25 +- .../sql/parquet/ParquetTableSupport.scala | 326 ++++++--- .../spark/sql/parquet/ParquetTestData.scala | 298 +++++++- .../spark/sql/parquet/ParquetTypes.scala | 408 +++++++++++ .../spark/sql/parquet/ParquetQuerySuite.scala | 356 +++++++++- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +- 14 files changed, 2102 insertions(+), 384 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 46fcfbb9e26ba..2ad2d04af5704 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -66,43 +66,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected case class Keyword(str: String) protected implicit def asParser(k: Keyword): Parser[String] = - allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - - protected class SqlLexical extends StdLexical { - case class FloatLit(chars: String) extends Token { - override def toString = chars - } - override lazy val token: Parser[Token] = ( - identChar ~ rep( identChar | digit ) ^^ - { case first ~ rest => processIdent(first :: rest mkString "") } - | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { - case i ~ None => NumericLit(i mkString "") - case i ~ Some(d) => FloatLit(i.mkString("") + "." 
+ d.mkString("")) - } - | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ - { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } - | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ - { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } - | EofCh ^^^ EOF - | '\'' ~> failure("unclosed string literal") - | '\"' ~> failure("unclosed string literal") - | delim - | failure("illegal character") - ) - - override def identChar = letter | elem('.') | elem('_') - - override def whitespace: Parser[Any] = rep( - whitespaceChar - | '/' ~ '*' ~ comment - | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) - | '#' ~ rep( chrExcept(EofCh, '\n') ) - | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) - | '/' ~ '*' ~ failure("unclosed comment") - ) - } - - override val lexical = new SqlLexical + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") @@ -161,24 +125,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers { this.getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword]) - - /** Generate all variations of upper and lower case of a given string */ - private def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { - if (s == "") { - Stream(prefix) - } else { - allCaseVersions(s.tail, prefix + s.head.toLower) ++ - allCaseVersions(s.tail, prefix + s.head.toUpper) - } - } + .map(_.invoke(this).asInstanceOf[Keyword].str) - lexical.reserved ++= reservedWords.flatMap(w => allCaseVersions(w.str)) - - lexical.delimiters += ( - "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]" - ) + override val lexical = new SqlLexical(reservedWords) protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { @@ -383,7 +332,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) protected lazy val baseExpression: PackratParser[Expression] = - expression ~ "[" ~ expression <~ "]" ^^ { + expression ~ "[" ~ expression <~ "]" ^^ { case base ~ _ ~ ordinal => GetItem(base, ordinal) } | TRUE ^^^ Literal(true, BooleanType) | @@ -399,3 +348,55 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected lazy val dataType: Parser[DataType] = STRING ^^^ StringType } + +class SqlLexical(val keywords: Seq[String]) extends StdLexical { + case class FloatLit(chars: String) extends Token { + override def toString = chars + } + + reserved ++= keywords.flatMap(w => allCaseVersions(w)) + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]" + ) + + override lazy val token: Parser[Token] = ( + identChar ~ rep( identChar | digit ) ^^ + { case first ~ rest => processIdent(first :: rest mkString "") } + | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { + case i ~ None => NumericLit(i mkString "") + case i ~ Some(d) => FloatLit(i.mkString("") + "." 
+ d.mkString("")) + } + | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ + { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } + | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ + { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '\"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar = letter | elem('_') | elem('.') + + override def whitespace: Parser[Any] = rep( + whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) + | '#' ~ rep( chrExcept(EofCh, '\n') ) + | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) + | '/' ~ '*' ~ failure("unclosed comment") + ) + + /** Generate all variations of upper and lower case of a given string */ + def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { + if (s == "") { + Stream(prefix) + } else { + allCaseVersions(s.tail, prefix + s.head.toLower) ++ + allCaseVersions(s.tail, prefix + s.head.toUpper) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index b6aeae92f8bec..5d3bb25ad568c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -50,6 +50,8 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { null } else { if (child.dataType.isInstanceOf[ArrayType]) { + // TODO: consider using Array[_] for ArrayType child to avoid + // boxing of primitives val baseValue = value.asInstanceOf[Seq[_]] val o = key.asInstanceOf[Int] if (o >= baseValue.size || o < 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index da34bd3a21503..bb77bccf86176 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,9 +19,71 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp -import scala.reflect.runtime.universe.{typeTag, TypeTag} +import scala.util.parsing.combinator.RegexParsers -import org.apache.spark.sql.catalyst.expressions.Expression +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.util.Utils + +/** + * + */ +object DataType extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + "StringType" ^^^ StringType | + "FloatType" ^^^ FloatType | + "IntegerType" ^^^ IntegerType | + "ByteType" ^^^ ByteType | + "ShortType" ^^^ ShortType | + "DoubleType" ^^^ DoubleType | + "LongType" ^^^ LongType | + "BinaryType" ^^^ BinaryType | + "BooleanType" ^^^ BooleanType | + "DecimalType" ^^^ DecimalType | + "TimestampType" ^^^ TimestampType + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType <~ ")" ^^ ArrayType + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ { + case t1 ~ _ ~ t2 => MapType(t1, t2) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," 
~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => + StructField(name, tpe, nullable = nullable) + } + + protected lazy val boolVal: Parser[Boolean] = + "true" ^^^ true | + "false" ^^^ false + + + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => new StructType(fields) + } + + protected lazy val dataType: Parser[DataType] = + arrayType | + mapType | + structType | + primitiveType + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") + } +} abstract class DataType { /** Matches any expression that evaluates to this DataType */ @@ -29,25 +91,36 @@ abstract class DataType { case e: Expression if e.dataType == this => true case _ => false } + + def isPrimitive: Boolean = false } case object NullType extends DataType +trait PrimitiveType extends DataType { + override def isPrimitive = true +} + abstract class NativeType extends DataType { type JvmType @transient val tag: TypeTag[JvmType] val ordering: Ordering[JvmType] + + @transient val classTag = { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) + } } -case object StringType extends NativeType { +case object StringType extends NativeType with PrimitiveType { type JvmType = String @transient lazy val tag = typeTag[JvmType] val ordering = implicitly[Ordering[JvmType]] } -case object BinaryType extends DataType { +case object BinaryType extends DataType with PrimitiveType { type JvmType = Array[Byte] } -case object BooleanType extends NativeType { +case object BooleanType extends NativeType with PrimitiveType { type JvmType = Boolean @transient lazy val tag = typeTag[JvmType] val ordering = implicitly[Ordering[JvmType]] @@ -63,7 +136,7 @@ case object TimestampType extends NativeType { } } -abstract class NumericType extends NativeType { +abstract class NumericType extends NativeType with PrimitiveType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). 
This gets @@ -154,6 +227,17 @@ case object FloatType extends FractionalType { case class ArrayType(elementType: DataType) extends DataType case class StructField(name: String, dataType: DataType, nullable: Boolean) -case class StructType(fields: Seq[StructField]) extends DataType + +object StructType { + def fromAttributes(attributes: Seq[Attribute]): StructType = { + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) + } + + // def apply(fields: Seq[StructField]) = new StructType(fields.toIndexedSeq) +} + +case class StructType(fields: Seq[StructField]) extends DataType { + def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) +} case class MapType(keyType: DataType, valueType: DataType) extends DataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1617ec717b2e0..ab376e5504d35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -94,7 +94,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path)) + new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration))) /** * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index ff9842267ffe0..ff6deeda2394d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -99,7 +99,9 @@ class JavaSQLContext(val sqlContext: SQLContext) { * Loads a parquet file, returning the result as a [[JavaSchemaRDD]]. */ def parquetFile(path: String): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, ParquetRelation(path)) + new JavaSchemaRDD( + sqlContext, + ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration))) /** * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. 
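The `DataType` combinator parser added to dataTypes.scala above parses the string form of a Catalyst type (the same form its `toString` produces), which the Parquet code later stores as file metadata. A small usage sketch, assuming the patched catalyst types are on the classpath; note the struct string must follow the `StructType(List(StructField(name,type,nullable), ...))` shape expected by the grammar:

```scala
import org.apache.spark.sql.catalyst.types._

// Round-trip a few type strings through the combinator grammar shown above.
object DataTypeParserSketch {
  def main(args: Array[String]): Unit = {
    println(DataType("IntegerType"))                     // IntegerType
    println(DataType("ArrayType(StringType)"))           // ArrayType(StringType)
    println(DataType("MapType(StringType,IntegerType)")) // MapType(StringType,IntegerType)
    println(DataType(
      "StructType(List(StructField(a,IntegerType,true),StructField(b,StringType,false)))"))
  }
}
```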
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index feb280d1d1411..4694f25d6d630 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -154,7 +154,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.WriteToFile(path, child) => val relation = ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) - InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil + // Note: overwrite=false because otherwise the metadata we just created will be deleted + InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala new file mode 100644 index 0000000000000..889a408e3c393 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -0,0 +1,667 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} + +import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} +import parquet.schema.MessageType + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row, Attribute} +import org.apache.spark.sql.parquet.CatalystConverter.FieldType + +/** + * Collection of converters of Parquet types (group and primitive types) that + * model arrays and maps. The conversions are partly based on the AvroParquet + * converters that are part of Parquet in order to be able to process these + * types. + * + * There are several types of converters: + *
    + *
+ * <ul>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystPrimitiveConverter]] for primitive
+ *   (numeric, boolean and String) types</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystNativeArrayConverter]] for arrays
+ *   of native JVM element types; note: currently null values are not supported!</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystArrayConverter]] for arrays of
+ *   arbitrary element types (including nested element types); note: currently
+ *   null values are not supported!</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystStructConverter]] for structs</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystMapConverter]] for maps; note:
+ *   currently null values are not supported!</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystPrimitiveRowConverter]] for rows
+ *   of only primitive element types</li>
+ *   <li>[[org.apache.spark.sql.parquet.CatalystGroupConverter]] for other nested
+ *   records, including the top-level row record</li>
+ * </ul>
+ */
+ */ + +private[sql] object CatalystConverter { + // The type internally used for fields + type FieldType = StructField + + // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). + // Note that "array" for the array elements is chosen by ParquetAvro. + // Using a different value will result in Parquet silently dropping columns. + val ARRAY_ELEMENTS_SCHEMA_NAME = "array" + val MAP_KEY_SCHEMA_NAME = "key" + val MAP_VALUE_SCHEMA_NAME = "value" + val MAP_SCHEMA_NAME = "map" + + // TODO: consider using Array[T] for arrays to avoid boxing of primitive types + type ArrayScalaType[T] = Seq[T] + type StructScalaType[T] = Seq[T] + type MapScalaType[K, V] = Map[K, V] + + protected[parquet] def createConverter( + field: FieldType, + fieldIndex: Int, + parent: CatalystConverter): Converter = { + val fieldType: DataType = field.dataType + fieldType match { + // For native JVM types we use a converter with native arrays + case ArrayType(elementType: NativeType) => { + new CatalystNativeArrayConverter(elementType, fieldIndex, parent) + } + // This is for other types of arrays, including those with nested fields + case ArrayType(elementType: DataType) => { + new CatalystArrayConverter(elementType, fieldIndex, parent) + } + case StructType(fields: Seq[StructField]) => { + new CatalystStructConverter(fields.toArray, fieldIndex, parent) + } + case MapType(keyType: DataType, valueType: DataType) => { + new CatalystMapConverter( + Array( + new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), + new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, true)), + fieldIndex, + parent) + } + // Strings, Shorts and Bytes do not have a corresponding type in Parquet + // so we need to treat them separately + case StringType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.updateString(fieldIndex, value) + } + } + case ShortType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addInt(value: Int): Unit = + parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.JvmType]) + } + } + case ByteType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addInt(value: Int): Unit = + parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) + } + } + // All other primitive types use the default converter + case ctype: NativeType => { // note: need the type tag here! + new CatalystPrimitiveConverter(parent, fieldIndex) + } + case _ => throw new RuntimeException( + s"unable to convert datatype ${field.dataType.toString} in CatalystConverter") + } + } + + protected[parquet] def createRootConverter( + parquetSchema: MessageType, + attributes: Seq[Attribute]): CatalystConverter = { + // For non-nested types we use the optimized Row converter + if (attributes.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))) { + new CatalystPrimitiveRowConverter(attributes.toArray) + } else { + new CatalystGroupConverter(attributes.toArray) + } + } +} + +private[parquet] abstract class CatalystConverter extends GroupConverter { + /** + * The number of fields this group has + */ + protected[parquet] val size: Int + + /** + * The index of this converter in the parent + */ + protected[parquet] val index: Int + + /** + * The parent converter + */ + protected[parquet] val parent: CatalystConverter + + /** + * Called by child converters to update their value in its parent (this). + * Note that if possible the more specific update methods below should be used + * to avoid auto-boxing of native JVM types. 
+ * + * @param fieldIndex + * @param value + */ + protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit + + protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = + updateField(fieldIndex, value) + + protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = + updateField(fieldIndex, value.getBytes) + + protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = + updateField(fieldIndex, value.toStringUsingUTF8) + + protected[parquet] def isRootConverter: Boolean = parent == null + + protected[parquet] def clearBuffer(): Unit + + /** + * Should only be called in the root (group) converter! + * + * @return + */ + def getCurrentRecord: Row = throw new UnsupportedOperationException +} + +/** + * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record + * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. + * + * @param schema The corresponding Catalyst schema in the form of a list of attributes. + */ +private[parquet] class CatalystGroupConverter( + protected[parquet] val schema: Array[FieldType], + protected[parquet] val index: Int, + protected[parquet] val parent: CatalystConverter, + protected[parquet] var current: ArrayBuffer[Any], + protected[parquet] var buffer: ArrayBuffer[Row]) + extends CatalystConverter { + + def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) = + this( + schema, + index, + parent, + current=null, + buffer=new ArrayBuffer[Row]( + CatalystArrayConverter.INITIAL_ARRAY_SIZE)) + + /** + * This constructor is used for the root converter only! + */ + def this(attributes: Array[Attribute]) = + this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null) + + protected [parquet] val converters: Array[Converter] = + schema.map(field => + CatalystConverter.createConverter(field, schema.indexOf(field), this)) + .toArray + + override val size = schema.size + + override def getCurrentRecord: Row = { + assert(isRootConverter, "getCurrentRecord should only be called in root group converter!") + // TODO: use iterators if possible + // Note: this will ever only be called in the root converter when the record has been + // fully processed. Therefore it will be difficult to use mutable rows instead, since + // any non-root converter never would be sure when it would be safe to re-use the buffer. 
+ new GenericRow(current.toArray) + } + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + // for child converters to update upstream values + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + current.update(fieldIndex, value) + } + + override protected[parquet] def clearBuffer(): Unit = buffer.clear() + + override def start(): Unit = { + current = ArrayBuffer.fill(size)(null) + converters.foreach { + converter => if (!converter.isPrimitive) { + converter.asInstanceOf[CatalystConverter].clearBuffer + } + } + } + + override def end(): Unit = { + if (!isRootConverter) { + assert(current!=null) // there should be no empty groups + buffer.append(new GenericRow(current.toArray)) + parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) + } + } +} + +/** + * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record + * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. Note that his + * converter is optimized for rows of primitive types (non-nested records). + */ +private[parquet] class CatalystPrimitiveRowConverter( + protected[parquet] val schema: Array[FieldType], + protected[parquet] var current: ParquetRelation.RowType) + extends CatalystConverter { + + // This constructor is used for the root converter only + def this(attributes: Array[Attribute]) = + this( + attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), + new ParquetRelation.RowType(attributes.length)) + + protected [parquet] val converters: Array[Converter] = + schema.map(field => + CatalystConverter.createConverter(field, schema.indexOf(field), this)) + .toArray + + override val size = schema.size + + override val index = 0 + + override val parent = null + + // Should be only called in root group converter! 
+ override def getCurrentRecord: ParquetRelation.RowType = current + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + // for child converters to update upstream values + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + throw new UnsupportedOperationException // child converters should use the + // specific update methods below + } + + override protected[parquet] def clearBuffer(): Unit = {} + + override def start(): Unit = { + var i = 0 + while (i < size) { + current.setNullAt(i) + i = i + 1 + } + } + + override def end(): Unit = {} + + // Overriden here to avoid auto-boxing for primitive types + override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = + current.setBoolean(fieldIndex, value) + + override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = + current.setInt(fieldIndex, value) + + override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = + current.setLong(fieldIndex, value) + + override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = + current.setShort(fieldIndex, value) + + override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = + current.setByte(fieldIndex, value) + + override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = + current.setDouble(fieldIndex, value) + + override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = + current.setFloat(fieldIndex, value) + + override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = + current.update(fieldIndex, value.getBytes) + + override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = + current.setString(fieldIndex, value.toStringUsingUTF8) +} + +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. + */ +private[parquet] class CatalystPrimitiveConverter( + parent: CatalystConverter, + fieldIndex: Int) extends PrimitiveConverter { + override def addBinary(value: Binary): Unit = + parent.updateBinary(fieldIndex, value) + + override def addBoolean(value: Boolean): Unit = + parent.updateBoolean(fieldIndex, value) + + override def addDouble(value: Double): Unit = + parent.updateDouble(fieldIndex, value) + + override def addFloat(value: Float): Unit = + parent.updateFloat(fieldIndex, value) + + override def addInt(value: Int): Unit = + parent.updateInt(fieldIndex, value) + + override def addLong(value: Long): Unit = + parent.updateLong(fieldIndex, value) +} + +object CatalystArrayConverter { + val INITIAL_ARRAY_SIZE = 20 +} + +/** + * A `parquet.io.api.GroupConverter` that converts a single-element groups that + * match the characteristics of an array (see + * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an + * [[org.apache.spark.sql.catalyst.types.ArrayType]]. 
+ * + * @param elementType The type of the array elements (complex or primitive) + * @param index The position of this (array) field inside its parent converter + * @param parent The parent converter + * @param buffer A data buffer + */ +private[parquet] class CatalystArrayConverter( + val elementType: DataType, + val index: Int, + protected[parquet] val parent: CatalystConverter, + protected[parquet] var buffer: Buffer[Any]) + extends CatalystConverter { + + def this(elementType: DataType, index: Int, parent: CatalystConverter) = + this( + elementType, + index, + parent, + new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)) + + protected[parquet] val converter: Converter = CatalystConverter.createConverter( + new CatalystConverter.FieldType( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + elementType, + false), + fieldIndex=0, + parent=this) + + override def getConverter(fieldIndex: Int): Converter = converter + + // arrays have only one (repeated) field, which is its elements + override val size = 1 + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + // fieldIndex is ignored (assumed to be zero but not checked) + if(value == null) { + throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") + } + buffer += value + } + + override protected[parquet] def clearBuffer(): Unit = { + buffer.clear() + } + + override def start(): Unit = { + if (!converter.isPrimitive) { + converter.asInstanceOf[CatalystConverter].clearBuffer + } + } + + override def end(): Unit = { + assert(parent != null) + // here we need to make sure to use ArrayScalaType + parent.updateField(index, buffer.toArray.toSeq) + clearBuffer() + } +} + +/** + * A `parquet.io.api.GroupConverter` that converts a single-element groups that + * match the characteristics of an array (see + * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an + * [[org.apache.spark.sql.catalyst.types.ArrayType]]. 
+ * + * @param elementType The type of the array elements (native) + * @param index The position of this (array) field inside its parent converter + * @param parent The parent converter + * @param capacity The (initial) capacity of the buffer + */ +private[parquet] class CatalystNativeArrayConverter( + val elementType: NativeType, + val index: Int, + protected[parquet] val parent: CatalystConverter, + protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE) + extends CatalystConverter { + + type NativeType = elementType.JvmType + + private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity) + + private var elements: Int = 0 + + protected[parquet] val converter: Converter = CatalystConverter.createConverter( + new CatalystConverter.FieldType( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + elementType, + false), + fieldIndex=0, + parent=this) + + override def getConverter(fieldIndex: Int): Converter = converter + + // arrays have only one (repeated) field, which is its elements + override val size = 1 + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = + throw new UnsupportedOperationException + + // Overriden here to avoid auto-boxing for primitive types + override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateShort(fieldIndex: Int, value: Short): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateByte(fieldIndex: Int, value: Byte): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = { + checkGrowBuffer() + buffer(elements) = value.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = { + checkGrowBuffer() + buffer(elements) = value.getBytes.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = { + checkGrowBuffer() + buffer(elements) = value.toStringUsingUTF8.asInstanceOf[NativeType] + elements += 1 + } + + override protected[parquet] def clearBuffer(): Unit = { + elements = 0 + } + + override def start(): Unit = {} + + override def end(): Unit = { + assert(parent != null) + // here we need to make sure to use ArrayScalaType + parent.updateField( + index, + buffer.slice(0, elements).toSeq) + clearBuffer() + } + + private def checkGrowBuffer(): Unit = { + if (elements >= capacity) { + val newCapacity = 2 * capacity + val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity) + Array.copy(buffer, 0, tmp, 0, capacity) + buffer = tmp + capacity = newCapacity + } + } +} + +/** + * This converter is for multi-element groups of 
primitive or complex types + * that have repetition level optional or required (so struct fields). + * + * @param schema The corresponding Catalyst schema in the form of a list of + * attributes. + * @param index + * @param parent + */ +private[parquet] class CatalystStructConverter( + override protected[parquet] val schema: Array[FieldType], + override protected[parquet] val index: Int, + override protected[parquet] val parent: CatalystConverter) + extends CatalystGroupConverter(schema, index, parent) { + + override protected[parquet] def clearBuffer(): Unit = {} + + // TODO: think about reusing the buffer + override def end(): Unit = { + assert(!isRootConverter) + // here we need to make sure to use StructScalaType + // Note: we need to actually make a copy of the array since we + // may be in a nested field + parent.updateField(index, new GenericRow(current.toArray)) + } +} + +/** + * A `parquet.io.api.GroupConverter` that converts two-element groups that + * match the characteristics of a map (see + * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an + * [[org.apache.spark.sql.catalyst.types.MapType]]. + * + * @param schema + * @param index + * @param parent + */ +private[parquet] class CatalystMapConverter( + protected[parquet] val schema: Array[FieldType], + override protected[parquet] val index: Int, + override protected[parquet] val parent: CatalystConverter) + extends CatalystConverter { + + private val map = new HashMap[Any, Any]() + + private val keyValueConverter = new CatalystConverter { + private var currentKey: Any = null + private var currentValue: Any = null + val keyConverter = CatalystConverter.createConverter(schema(0), 0, this) + val valueConverter = CatalystConverter.createConverter(schema(1), 1, this) + + override def getConverter(fieldIndex: Int): Converter = { + if (fieldIndex == 0) keyConverter else valueConverter + } + + override def end(): Unit = CatalystMapConverter.this.map += currentKey -> currentValue + + override def start(): Unit = { + currentKey = null + currentValue = null + } + + override protected[parquet] val size: Int = 2 + override protected[parquet] val index: Int = 0 + override protected[parquet] val parent: CatalystConverter = CatalystMapConverter.this + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { + fieldIndex match { + case 0 => + currentKey = value + case 1 => + currentValue = value + case _ => + new RuntimePermission(s"trying to update Map with fieldIndex $fieldIndex") + } + } + + override protected[parquet] def clearBuffer(): Unit = {} + } + + override protected[parquet] val size: Int = 1 + + override protected[parquet] def clearBuffer(): Unit = {} + + override def start(): Unit = { + map.clear() + } + + override def end(): Unit = { + // here we need to make sure to use MapScalaType + parent.updateField(index, map.toMap) + } + + override def getConverter(fieldIndex: Int): Converter = keyValueConverter + + override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = + throw new UnsupportedOperationException +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 32813a66de3c3..96c131a7f8af1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -20,25 +20,16 @@ package org.apache.spark.sql.parquet import java.io.IOException import 
org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction -import org.apache.hadoop.mapreduce.Job -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetOutputFormat, Footer, ParquetFileWriter, ParquetFileReader} -import parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} -import parquet.io.api.{Binary, RecordConsumer} -import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType, MessageTypeParser} -import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import parquet.schema.Type.Repetition +import parquet.hadoop.ParquetOutputFormat +import parquet.hadoop.metadata.CompressionCodecName +import parquet.schema.MessageType import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} -import org.apache.spark.sql.catalyst.types._ - -// Implicits -import scala.collection.JavaConversions._ /** * Relation that consists of data stored in a Parquet columnar format. @@ -52,21 +43,20 @@ import scala.collection.JavaConversions._ * * @param path The path to the Parquet file. */ -private[sql] case class ParquetRelation(val path: String) - extends LeafNode with MultiInstanceRelation { +private[sql] case class ParquetRelation( + val path: String, + @transient val conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation { self: Product => /** Schema derived from ParquetFile */ def parquetSchema: MessageType = ParquetTypesConverter - .readMetaData(new Path(path)) + .readMetaData(new Path(path), conf) .getFileMetaData .getSchema /** Attributes */ - override val output = - ParquetTypesConverter - .convertToAttributes(parquetSchema) + override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) override def newInstance = ParquetRelation(path).asInstanceOf[this.type] @@ -141,7 +131,9 @@ private[sql] object ParquetRelation { } ParquetRelation.enableLogForwarding() ParquetTypesConverter.writeMetaData(attributes, path, conf) - new ParquetRelation(path.toString) + new ParquetRelation(path.toString, Some(conf)) { + override val output = attributes + } } private def checkPath(pathStr: String, allowExisting: Boolean, conf: Configuration): Path = { @@ -170,151 +162,3 @@ private[sql] object ParquetRelation { path } } - -private[parquet] object ParquetTypesConverter { - def toDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { - // for now map binary to string type - // TODO: figure out how Parquet uses strings or why we can't use them in a MessageType schema - case ParquetPrimitiveTypeName.BINARY => StringType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType) - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
- sys.error("Warning: potential loss of precision: converting INT96 to long") - LongType - case _ => sys.error( - s"Unsupported parquet datatype $parquetType") - } - - def fromDataType(ctype: DataType): ParquetPrimitiveTypeName = ctype match { - case StringType => ParquetPrimitiveTypeName.BINARY - case BooleanType => ParquetPrimitiveTypeName.BOOLEAN - case DoubleType => ParquetPrimitiveTypeName.DOUBLE - case ArrayType(ByteType) => ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY - case FloatType => ParquetPrimitiveTypeName.FLOAT - case IntegerType => ParquetPrimitiveTypeName.INT32 - case LongType => ParquetPrimitiveTypeName.INT64 - case _ => sys.error(s"Unsupported datatype $ctype") - } - - def consumeType(consumer: RecordConsumer, ctype: DataType, record: Row, index: Int): Unit = { - ctype match { - case StringType => consumer.addBinary( - Binary.fromByteArray( - record(index).asInstanceOf[String].getBytes("utf-8") - ) - ) - case IntegerType => consumer.addInteger(record.getInt(index)) - case LongType => consumer.addLong(record.getLong(index)) - case DoubleType => consumer.addDouble(record.getDouble(index)) - case FloatType => consumer.addFloat(record.getFloat(index)) - case BooleanType => consumer.addBoolean(record.getBoolean(index)) - case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") - } - } - - def getSchema(schemaString : String) : MessageType = - MessageTypeParser.parseMessageType(schemaString) - - def convertToAttributes(parquetSchema: MessageType) : Seq[Attribute] = { - parquetSchema.getColumns.map { - case (desc) => - val ctype = toDataType(desc.getType) - val name: String = desc.getPath.mkString(".") - new AttributeReference(name, ctype, false)() - } - } - - // TODO: allow nesting? - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { - val fields: Seq[ParquetType] = attributes.map { - a => new ParquetPrimitiveType(Repetition.OPTIONAL, fromDataType(a.dataType), a.name) - } - new MessageType("root", fields) - } - - def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration) { - if (origPath == null) { - throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") - } - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (fs.exists(path) && !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException(s"Expected to write to directory $path but found file") - } - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fs.exists(metadataPath)) { - try { - fs.delete(metadataPath, true) - } catch { - case e: IOException => - throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") - } - } - val extraMetadata = new java.util.HashMap[String, String]() - extraMetadata.put("path", path.toString) - // TODO: add extra data, e.g., table name, date, etc.? - - val parquetSchema: MessageType = - ParquetTypesConverter.convertFromAttributes(attributes) - val metaData: FileMetaData = new FileMetaData( - parquetSchema, - extraMetadata, - "Spark") - - ParquetRelation.enableLogForwarding() - ParquetFileWriter.writeMetadataFile( - conf, - path, - new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) - } - - /** - * Try to read Parquet metadata at the given Path. We first see if there is a summary file - * in the parent directory. If so, this is used. 
Else we read the actual footer at the given - * location. - * @param origPath The path at which we expect one (or more) Parquet files. - * @return The `ParquetMetadata` containing among other things the schema. - */ - def readMetaData(origPath: Path): ParquetMetadata = { - if (origPath == null) { - throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") - } - val job = new Job() - // TODO: since this is called from ParquetRelation (LogicalPlan) we don't have access - // to SparkContext's hadoopConfig; in principle the default FileSystem may be different(?!) - val conf = ContextUtil.getConfiguration(job) - val fs: FileSystem = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") - } - val path = origPath.makeQualified(fs) - if (!fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException( - s"Expected $path for be a directory with Parquet files/metadata") - } - ParquetRelation.enableLogForwarding() - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - // if this is a new table that was just created we will find only the metadata file - if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { - ParquetFileReader.readFooter(conf, metadataPath) - } else { - // there may be one or more Parquet files in the given directory - val footers = ParquetFileReader.readFooters(conf, fs.getFileStatus(path)) - // TODO: for now we assume that all footers (if there is more than one) have identical - // metadata; we may want to add a check here at some point - if (footers.size() == 0) { - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path") - } - footers(0).getParquetMetadata - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 65ba1246fbf9a..624f2e2fa13f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -36,6 +36,7 @@ import parquet.schema.MessageType import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} +import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} /** @@ -64,10 +65,13 @@ case class ParquetTableScan( NewFileInputFormat.addInputPath(job, path) } - // Store Parquet schema in `Configuration` + // Store both requested and original schema in `Configuration` conf.set( - RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertFromAttributes(output).toString) + RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + ParquetTypesConverter.convertToString(output)) + conf.set( + RowWriteSupport.SPARK_ROW_SCHEMA, + ParquetTypesConverter.convertToString(relation.output)) // Store record filtering predicate in `Configuration` // Note 1: the input format ignores all predicates that cannot be expressed @@ -166,13 +170,18 @@ case class InsertIntoParquetTable( val job = new Job(sc.hadoopConfiguration) - ParquetOutputFormat.setWriteSupportClass( - job, - classOf[org.apache.spark.sql.parquet.RowWriteSupport]) + val writeSupport = + if (child.output.map(_.dataType).forall(_.isPrimitive)) { + logger.debug("Initializing MutableRowWriteSupport") + 
classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] + } else { + classOf[org.apache.spark.sql.parquet.RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupport) - // TODO: move that to function in object val conf = ContextUtil.getConfiguration(job) - conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA, relation.parquetSchema.toString) + RowWriteSupport.setSchema(relation.output, conf) val fspath = new Path(relation.path) val fs = fspath.getFileSystem(conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 71ba0fecce47a..bfcbdeb34a92f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -29,21 +29,23 @@ import parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.SparkSqlSerializer +import com.google.common.io.BaseEncoding /** * A `parquet.io.api.RecordMaterializer` for Rows. * *@param root The root group converter for the record. */ -private[parquet] class RowRecordMaterializer(root: CatalystGroupConverter) +private[parquet] class RowRecordMaterializer(root: CatalystConverter) extends RecordMaterializer[Row] { - def this(parquetSchema: MessageType) = - this(new CatalystGroupConverter(ParquetTypesConverter.convertToAttributes(parquetSchema))) + def this(parquetSchema: MessageType, attributes: Seq[Attribute]) = + this(CatalystConverter.createRootConverter(parquetSchema, attributes)) override def getCurrentRecord: Row = root.getCurrentRecord - override def getRootConverter: GroupConverter = root + override def getRootConverter: GroupConverter = root.asInstanceOf[GroupConverter] } /** @@ -56,68 +58,94 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { stringMap: java.util.Map[String, String], fileSchema: MessageType, readContext: ReadContext): RecordMaterializer[Row] = { - log.debug(s"preparing for read with file schema $fileSchema") - new RowRecordMaterializer(readContext.getRequestedSchema) + log.debug(s"preparing for read with Parquet file schema $fileSchema") + // Note: this very much imitates AvroParquet + val parquetSchema = readContext.getRequestedSchema + var schema: Seq[Attribute] = null + + if (readContext.getReadSupportMetadata != null) { + // first try to find the read schema inside the metadata (can result from projections) + if ( + readContext + .getReadSupportMetadata + .get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) != null) { + schema = ParquetTypesConverter.convertFromString( + readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) + } else { + // if unavailable, try the schema that was read originally from the file or provided + // during the creation of the Parquet relation + if (readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { + schema = ParquetTypesConverter.convertFromString( + readContext.getReadSupportMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) + } + } + } + // if both unavailable, fall back to deducing the schema from the given Parquet schema + if (schema == null) { + log.debug("falling back to Parquet read schema") + schema = ParquetTypesConverter.convertToAttributes(parquetSchema) + } + log.debug(s"list of attributes 
that will be read: $schema") + new RowRecordMaterializer(parquetSchema, schema) } override def init( configuration: Configuration, keyValueMetaData: java.util.Map[String, String], fileSchema: MessageType): ReadContext = { - val requested_schema_string = - configuration.get(RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, fileSchema.toString) - val requested_schema = - MessageTypeParser.parseMessageType(requested_schema_string) - log.debug(s"read support initialized for requested schema $requested_schema") - ParquetRelation.enableLogForwarding() - new ReadContext(requested_schema, keyValueMetaData) + var parquetSchema: MessageType = fileSchema + var metadata: java.util.Map[String, String] = new java.util.HashMap[String, String]() + val requestedAttributes = RowReadSupport.getRequestedSchema(configuration) + + if (requestedAttributes != null) { + parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) + metadata.put( + RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + ParquetTypesConverter.convertToString(requestedAttributes)) + } + + val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + if (origAttributesStr != null) { + metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) + } + + return new ReadSupport.ReadContext(parquetSchema, metadata) } } private[parquet] object RowReadSupport { - val PARQUET_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" + + private def getRequestedSchema(configuration: Configuration): Seq[Attribute] = { + val schemaString = configuration.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + if (schemaString == null) null else ParquetTypesConverter.convertFromString(schemaString) + } } /** * A `parquet.hadoop.api.WriteSupport` for Row ojects. */ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { - def setSchema(schema: MessageType, configuration: Configuration) { - // for testing - this.schema = schema - // TODO: could use Attributes themselves instead of Parquet schema? 
- configuration.set( - RowWriteSupport.PARQUET_ROW_SCHEMA, - schema.toString) - configuration.set( - ParquetOutputFormat.WRITER_VERSION, - ParquetProperties.WriterVersion.PARQUET_1_0.toString) - } - - def getSchema(configuration: Configuration): MessageType = { - MessageTypeParser.parseMessageType(configuration.get(RowWriteSupport.PARQUET_ROW_SCHEMA)) - } - private var schema: MessageType = null - private var writer: RecordConsumer = null - private var attributes: Seq[Attribute] = null + private[parquet] var writer: RecordConsumer = null + private[parquet] var attributes: Seq[Attribute] = null override def init(configuration: Configuration): WriteSupport.WriteContext = { - schema = if (schema == null) getSchema(configuration) else schema - attributes = ParquetTypesConverter.convertToAttributes(schema) - log.debug(s"write support initialized for requested schema $schema") + attributes = if (attributes == null) RowWriteSupport.getSchema(configuration) else attributes + + log.debug(s"write support initialized for requested schema $attributes") ParquetRelation.enableLogForwarding() new WriteSupport.WriteContext( - schema, + ParquetTypesConverter.convertFromAttributes(attributes), new java.util.HashMap[java.lang.String, java.lang.String]()) } override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { writer = recordConsumer - log.debug(s"preparing for write with schema $schema") + log.debug(s"preparing for write with schema $attributes") } - // TODO: add groups (nested fields) override def write(record: Row): Unit = { if (attributes.size > record.size) { throw new IndexOutOfBoundsException( @@ -130,98 +158,176 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { // null values indicate optional fields but we do not check currently if (record(index) != null && record(index) != Nil) { writer.startField(attributes(index).name, index) - ParquetTypesConverter.consumeType(writer, attributes(index).dataType, record, index) + writeValue(attributes(index).dataType, record(index)) writer.endField(attributes(index).name, index) } index = index + 1 } writer.endMessage() } -} -private[parquet] object RowWriteSupport { - val PARQUET_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.schema" -} - -/** - * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record to a `Row` object. - * - * @param schema The corresponding Catalyst schema in the form of a list of attributes. 
- */ -private[parquet] class CatalystGroupConverter( - schema: Seq[Attribute], - protected[parquet] val current: ParquetRelation.RowType) extends GroupConverter { - - def this(schema: Seq[Attribute]) = this(schema, new ParquetRelation.RowType(schema.length)) - - val converters: Array[Converter] = schema.map { - a => a.dataType match { - case ctype: NativeType => - // note: for some reason matching for StringType fails so use this ugly if instead - if (ctype == StringType) { - new CatalystPrimitiveStringConverter(this, schema.indexOf(a)) - } else { - new CatalystPrimitiveConverter(this, schema.indexOf(a)) - } - case _ => throw new RuntimeException( - s"unable to convert datatype ${a.dataType.toString} in CatalystGroupConverter") + private[parquet] def writeValue(schema: DataType, value: Any): Unit = { + if (value != null && value != Nil) { + schema match { + case t @ ArrayType(_) => writeArray( + t, + value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) + case t @ MapType(_, _) => writeMap( + t, + value.asInstanceOf[CatalystConverter.MapScalaType[_, _]]) + case t @ StructType(_) => writeStruct( + t, + value.asInstanceOf[CatalystConverter.StructScalaType[_]]) + case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value) + } } - }.toArray + } - override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = { + if (value != null && value != Nil) { + schema match { + case StringType => writer.addBinary( + Binary.fromByteArray( + value.asInstanceOf[String].getBytes("utf-8") + ) + ) + case IntegerType => writer.addInteger(value.asInstanceOf[Int]) + case ShortType => writer.addInteger(value.asInstanceOf[Int]) + case LongType => writer.addLong(value.asInstanceOf[Long]) + case ByteType => writer.addInteger(value.asInstanceOf[Int]) + case DoubleType => writer.addDouble(value.asInstanceOf[Double]) + case FloatType => writer.addFloat(value.asInstanceOf[Float]) + case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case _ => sys.error(s"Do not know how to writer $schema to consumer") + } + } + } - private[parquet] def getCurrentRecord: ParquetRelation.RowType = current + private[parquet] def writeStruct( + schema: StructType, + struct: CatalystConverter.StructScalaType[_]): Unit = { + if (struct != null && struct != Nil) { + val fields = schema.fields.toArray + writer.startGroup() + var i = 0 + while(i < fields.size) { + if (struct(i) != null && struct(i) != Nil) { + writer.startField(fields(i).name, i) + writeValue(fields(i).dataType, struct(i)) + writer.endField(fields(i).name, i) + } + i = i + 1 + } + writer.endGroup() + } + } - override def start(): Unit = { - var i = 0 - while (i < schema.length) { - current.setNullAt(i) - i = i + 1 + // TODO: support null values, see + // https://issues.apache.org/jira/browse/SPARK-1649 + private[parquet] def writeArray( + schema: ArrayType, + array: CatalystConverter.ArrayScalaType[_]): Unit = { + val elementType = schema.elementType + writer.startGroup() + if (array.size > 0) { + writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) + var i = 0 + while(i < array.size) { + writeValue(elementType, array(i)) + i = i + 1 + } + writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) } + writer.endGroup() } - override def end(): Unit = {} + // TODO: support null values, see + // https://issues.apache.org/jira/browse/SPARK-1649 + private[parquet] def writeMap( + schema: MapType, + map: CatalystConverter.MapScalaType[_, _]): 
Unit = { + writer.startGroup() + if (map.size > 0) { + writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0) + writer.startGroup() + writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) + for(key <- map.keys) { + writeValue(schema.keyType, key) + } + writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) + writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) + for(value <- map.values) { + writeValue(schema.valueType, value) + } + writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) + writer.endGroup() + writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0) + } + writer.endGroup() + } } -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveConverter( - parent: CatalystGroupConverter, - fieldIndex: Int) extends PrimitiveConverter { - // TODO: consider refactoring these together with ParquetTypesConverter - override def addBinary(value: Binary): Unit = - parent.getCurrentRecord.update(fieldIndex, value.getBytes) +// Optimized for non-nested rows +private[parquet] class MutableRowWriteSupport extends RowWriteSupport { + override def write(record: Row): Unit = { + if (attributes.size > record.size) { + throw new IndexOutOfBoundsException( + s"Trying to write more fields than contained in row (${attributes.size}>${record.size})") + } - override def addBoolean(value: Boolean): Unit = - parent.getCurrentRecord.setBoolean(fieldIndex, value) + var index = 0 + writer.startMessage() + while(index < attributes.size) { + // null values indicate optional fields but we do not check currently + if (record(index) != null && record(index) != Nil) { + writer.startField(attributes(index).name, index) + consumeType(attributes(index).dataType, record, index) + writer.endField(attributes(index).name, index) + } + index = index + 1 + } + writer.endMessage() + } - override def addDouble(value: Double): Unit = - parent.getCurrentRecord.setDouble(fieldIndex, value) + private def consumeType( + ctype: DataType, + record: Row, + index: Int): Unit = { + ctype match { + case StringType => writer.addBinary( + Binary.fromByteArray( + record(index).asInstanceOf[String].getBytes("utf-8") + ) + ) + case IntegerType => writer.addInteger(record.getInt(index)) + case ShortType => writer.addInteger(record.getShort(index)) + case LongType => writer.addLong(record.getLong(index)) + case ByteType => writer.addInteger(record.getByte(index)) + case DoubleType => writer.addDouble(record.getDouble(index)) + case FloatType => writer.addFloat(record.getFloat(index)) + case BooleanType => writer.addBoolean(record.getBoolean(index)) + case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") + } + } +} - override def addFloat(value: Float): Unit = - parent.getCurrentRecord.setFloat(fieldIndex, value) +private[parquet] object RowWriteSupport { + val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" - override def addInt(value: Int): Unit = - parent.getCurrentRecord.setInt(fieldIndex, value) + def getSchema(configuration: Configuration): Seq[Attribute] = { + val schemaString = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + if (schemaString == null) { + throw new RuntimeException("Missing schema!") + } + ParquetTypesConverter.convertFromString(schemaString) + } - override def addLong(value: Long): Unit = - parent.getCurrentRecord.setLong(fieldIndex, value) + def setSchema(schema: 
Seq[Attribute], configuration: Configuration) { + val encoded = ParquetTypesConverter.convertToString(schema) + configuration.set(SPARK_ROW_SCHEMA, encoded) + configuration.set( + ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_1_0.toString) + } } -/** - * A `parquet.io.api.PrimitiveConverter` that converts Parquet strings (fixed-length byte arrays) - * into Catalyst Strings. - * - * @param parent The parent group converter. - * @param fieldIndex The index inside the record. - */ -private[parquet] class CatalystPrimitiveStringConverter( - parent: CatalystGroupConverter, - fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.getCurrentRecord.setString(fieldIndex, value.toStringUsingUTF8) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index 46c7172985642..1dc58633a2a68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -17,14 +17,19 @@ package org.apache.spark.sql.parquet +import java.io.File + import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.mapreduce.Job import parquet.example.data.{GroupWriter, Group} import parquet.example.data.simple.SimpleGroup -import parquet.hadoop.ParquetWriter +import parquet.hadoop.{ParquetReader, ParquetFileReader, ParquetWriter} import parquet.hadoop.api.WriteSupport import parquet.hadoop.api.WriteSupport.WriteContext +import parquet.hadoop.example.GroupReadSupport +import parquet.hadoop.util.ContextUtil import parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} @@ -51,13 +56,13 @@ private[sql] object ParquetTestData { val testSchema = """message myrecord { - |optional boolean myboolean; - |optional int32 myint; - |optional binary mystring; - |optional int64 mylong; - |optional float myfloat; - |optional double mydouble; - |}""".stripMargin + optional boolean myboolean; + optional int32 myint; + optional binary mystring; + optional int64 mylong; + optional float myfloat; + optional double mydouble; + }""" // field names for test assertion error messages val testSchemaFieldNames = Seq( @@ -71,23 +76,23 @@ private[sql] object ParquetTestData { val subTestSchema = """ - |message myrecord { - |optional boolean myboolean; - |optional int64 mylong; - |} - """.stripMargin + message myrecord { + optional boolean myboolean; + optional int64 mylong; + } + """ val testFilterSchema = """ - |message myrecord { - |required boolean myboolean; - |required int32 myint; - |required binary mystring; - |required int64 mylong; - |required float myfloat; - |required double mydouble; - |} - """.stripMargin + message myrecord { + required boolean myboolean; + required int32 myint; + required binary mystring; + required int64 mylong; + required float myfloat; + required double mydouble; + } + """ // field names for test assertion error messages val subTestSchemaFieldNames = Seq( @@ -100,9 +105,110 @@ private[sql] object ParquetTestData { lazy val testData = new ParquetRelation(testDir.toURI.toString) + val testNestedSchema1 = + // based on blogpost example, source: + // https://blog.twitter.com/2013/dremel-made-simple-with-parquet + // note: instead of string we have to use binary (?) 
otherwise + // Parquet gives us: + // IllegalArgumentException: expected one of [INT64, INT32, BOOLEAN, + // BINARY, FLOAT, DOUBLE, INT96, FIXED_LEN_BYTE_ARRAY] + // Also repeated primitives seem tricky to convert (AvroParquet + // only uses them in arrays?) so only use at most one in each group + // and nothing else in that group (-> is mapped to array)! + // The "values" inside ownerPhoneNumbers is a keyword currently + // so that array types can be translated correctly. + """ + message AddressBook { + required binary owner; + optional group ownerPhoneNumbers { + repeated binary array; + } + optional group contacts { + repeated group array { + required binary name; + optional binary phoneNumber; + } + } + } + """ + + + val testNestedSchema2 = + """ + message TestNested2 { + required int32 firstInt; + optional int32 secondInt; + optional group longs { + repeated int64 array; + } + required group entries { + repeated group array { + required double value; + optional boolean truth; + } + } + optional group outerouter { + repeated group array { + repeated group array { + repeated int32 array; + } + } + } + } + """ + + val testNestedSchema3 = + """ + message TestNested3 { + required int32 x; + optional group booleanNumberPairs { + repeated group array { + required int32 key; + optional group value { + repeated group array { + required double nestedValue; + optional boolean truth; + } + } + } + } + } + """ + + val testNestedSchema4 = + """ + message TestNested4 { + required int32 x; + optional group data1 { + repeated group map { + required binary key; + required int32 value; + } + } + required group data2 { + repeated group map { + required binary key; + required group value { + required int64 payload1; + optional binary payload2; + } + } + } + } + """ + + val testNestedDir1 = Utils.createTempDir() + val testNestedDir2 = Utils.createTempDir() + val testNestedDir3 = Utils.createTempDir() + val testNestedDir4 = Utils.createTempDir() + + lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString) + lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString) + def writeFile() = { - testDir.delete + testDir.delete() val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet")) + val job = new Job() val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) val writeSupport = new TestGroupWriteSupport(schema) val writer = new ParquetWriter[Group](path, writeSupport) @@ -150,5 +256,149 @@ private[sql] object ParquetTestData { } writer.close() } + + def writeNestedFile1() { + // example data from https://blog.twitter.com/2013/dremel-made-simple-with-parquet + testNestedDir1.delete() + val path: Path = new Path(new Path(testNestedDir1.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema1) + + val r1 = new SimpleGroup(schema) + r1.add(0, "Julien Le Dem") + r1.addGroup(1) + .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "555 123 4567") + .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "555 666 1337") + .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "XXX XXX XXXX") + val contacts = r1.addGroup(2) + contacts.addGroup(0) + .append("name", "Dmitriy Ryaboy") + .append("phoneNumber", "555 987 6543") + contacts.addGroup(0) + .append("name", "Chris Aniszczyk") + + val r2 = new SimpleGroup(schema) + r2.add(0, "A. 
Nonymous") + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.write(r2) + writer.close() + } + + def writeNestedFile2() { + testNestedDir2.delete() + val path: Path = new Path(new Path(testNestedDir2.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema2) + + val r1 = new SimpleGroup(schema) + r1.add(0, 1) + r1.add(1, 7) + val longs = r1.addGroup(2) + longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME , 1.toLong << 32) + longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 1.toLong << 33) + longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 1.toLong << 34) + val booleanNumberPair = r1.addGroup(3).addGroup(0) + booleanNumberPair.add("value", 2.5) + booleanNumberPair.add("truth", false) + val top_level = r1.addGroup(4) + val second_level_a = top_level.addGroup(0) + val second_level_b = top_level.addGroup(0) + val third_level_aa = second_level_a.addGroup(0) + val third_level_ab = second_level_a.addGroup(0) + val third_level_c = second_level_b.addGroup(0) + third_level_aa.add( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + 7) + third_level_ab.add( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + 8) + third_level_c.add( + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + 9) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.close() + } + + def writeNestedFile3() { + testNestedDir3.delete() + val path: Path = new Path(new Path(testNestedDir3.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema3) + + val r1 = new SimpleGroup(schema) + r1.add(0, 1) + val booleanNumberPairs = r1.addGroup(1) + val g1 = booleanNumberPairs.addGroup(0) + g1.add(0, 1) + val nested1 = g1.addGroup(1) + val ng1 = nested1.addGroup(0) + ng1.add(0, 1.5) + ng1.add(1, false) + val ng2 = nested1.addGroup(0) + ng2.add(0, 2.5) + ng2.add(1, true) + val g2 = booleanNumberPairs.addGroup(0) + g2.add(0, 2) + val ng3 = g2.addGroup(1) + .addGroup(0) + ng3.add(0, 3.5) + ng3.add(1, false) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.close() + } + + def writeNestedFile4() { + testNestedDir4.delete() + val path: Path = new Path(new Path(testNestedDir4.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema4) + + val r1 = new SimpleGroup(schema) + r1.add(0, 7) + val map1 = r1.addGroup(1) + val keyValue1 = map1.addGroup(0) + keyValue1.add(0, "key1") + keyValue1.add(1, 1) + val keyValue2 = map1.addGroup(0) + keyValue2.add(0, "key2") + keyValue2.add(1, 2) + val map2 = r1.addGroup(2) + val keyValue3 = map2.addGroup(0) + // TODO: currently only string key type supported + keyValue3.add(0, "seven") + val valueGroup1 = keyValue3.addGroup(1) + valueGroup1.add(0, 42.toLong) + valueGroup1.add(1, "the answer") + val keyValue4 = map2.addGroup(0) + // TODO: currently only string key type supported + keyValue4.add(0, "eight") + val valueGroup2 = keyValue4.addGroup(1) + valueGroup2.add(0, 49.toLong) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + writer.write(r1) + writer.close() + } + + // TODO: this is not actually used anywhere but useful for debugging + /* def readNestedFile(file: File, schemaString: 
String): Unit = { + val configuration = new Configuration() + val path = new Path(new Path(file.toURI), new Path("part-r-0.parquet")) + val fs: FileSystem = path.getFileSystem(configuration) + val schema: MessageType = MessageTypeParser.parseMessageType(schemaString) + assert(schema != null) + val outputStatus: FileStatus = fs.getFileStatus(new Path(path.toString)) + val footers = ParquetFileReader.readFooter(configuration, outputStatus) + assert(footers != null) + val reader = new ParquetReader(new Path(path.toString), new GroupReadSupport()) + val first = reader.read() + assert(first != null) + } */ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala new file mode 100644 index 0000000000000..f9046368e7ced --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.io.IOException + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.Job + +import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} +import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} +import parquet.hadoop.util.ContextUtil +import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType} +import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns} +import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} +import parquet.schema.Type.Repetition + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.types._ + +// Implicits +import scala.collection.JavaConversions._ + +private[parquet] object ParquetTypesConverter extends Logging { + def isPrimitiveType(ctype: DataType): Boolean = + classOf[PrimitiveType] isAssignableFrom ctype.getClass + + def toPrimitiveDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { + case ParquetPrimitiveTypeName.BINARY => StringType + case ParquetPrimitiveTypeName.BOOLEAN => BooleanType + case ParquetPrimitiveTypeName.DOUBLE => DoubleType + case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType) + case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 => IntegerType + case ParquetPrimitiveTypeName.INT64 => LongType + case ParquetPrimitiveTypeName.INT96 => + // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
+ sys.error("Potential loss of precision: cannot convert INT96") + case _ => sys.error( + s"Unsupported parquet datatype $parquetType") + } + + /** + * Converts a given Parquet `Type` into the corresponding + * [[org.apache.spark.sql.catalyst.types.DataType]]. + * + * We apply the following conversion rules: + *
+ *   - Primitive types are converted to the corresponding primitive type.
+ *   - Group types that have a single field that is itself a group, which has repetition
+ *     level `REPEATED`, are treated as follows:
+ *     - If the nested group has name `values`, the surrounding group is converted
+ *       into an [[ArrayType]] with the corresponding field type (primitive or
+ *       complex) as element type.
+ *     - If the nested group has name `map` and two fields (named `key` and `value`),
+ *       the surrounding group is converted into a [[MapType]]
+ *       with the corresponding key and value (value possibly complex) types.
+ *       Note that we currently assume map values are not nullable.
+ *     - Other group types are converted into a [[StructType]] with the corresponding
+ *       field types (see the example below).
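+ *
+ * For example (a sketch with illustrative field names), a group of the form
+ * {{{
+ *   optional group person {
+ *     required binary name;
+ *     optional int32 age;
+ *   }
+ * }}}
+ * matches neither the array nor the map pattern above and is therefore converted into a
+ * [[StructType]] with a non-nullable string field `name` and a nullable integer field `age`.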
+ * Note that fields are determined to be `nullable` if and only if their Parquet repetition + * level is not `REQUIRED`. + * + * @param parquetType The type to convert. + * @return The corresponding Catalyst type. + */ + def toDataType(parquetType: ParquetType): DataType = { + def correspondsToMap(groupType: ParquetGroupType): Boolean = { + if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { + false + } else { + // This mostly follows the convention in ``parquet.schema.ConversionPatterns`` + val keyValueGroup = groupType.getFields.apply(0).asGroupType() + keyValueGroup.getRepetition == Repetition.REPEATED && + keyValueGroup.getName == CatalystConverter.MAP_SCHEMA_NAME && + keyValueGroup.getFieldCount == 2 && + keyValueGroup.getFields.apply(0).getName == CatalystConverter.MAP_KEY_SCHEMA_NAME && + keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME + } + } + + def correspondsToArray(groupType: ParquetGroupType): Boolean = { + groupType.getFieldCount == 1 && + groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME && + groupType.getFields.apply(0).getRepetition == Repetition.REPEATED + } + + if (parquetType.isPrimitive) { + toPrimitiveDataType(parquetType.asPrimitiveType.getPrimitiveTypeName) + } else { + val groupType = parquetType.asGroupType() + parquetType.getOriginalType match { + // if the schema was constructed programmatically there may be hints how to convert + // it inside the metadata via the OriginalType field + case ParquetOriginalType.LIST => { // TODO: check enums! + assert(groupType.getFieldCount == 1) + val field = groupType.getFields.apply(0) + new ArrayType(toDataType(field)) + } + case ParquetOriginalType.MAP => { + assert( + !groupType.getFields.apply(0).isPrimitive, + "Parquet Map type malformatted: expected nested group for map!") + val keyValueGroup = groupType.getFields.apply(0).asGroupType() + assert( + keyValueGroup.getFieldCount == 2, + "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") + val keyType = toDataType(keyValueGroup.getFields.apply(0)) + assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) + val valueType = toDataType(keyValueGroup.getFields.apply(1)) + assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) + new MapType(keyType, valueType) + } + case _ => { + // Note: the order of these checks is important! + if (correspondsToMap(groupType)) { // MapType + val keyValueGroup = groupType.getFields.apply(0).asGroupType() + val keyType = toDataType(keyValueGroup.getFields.apply(0)) + assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) + val valueType = toDataType(keyValueGroup.getFields.apply(1)) + assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) + new MapType(keyType, valueType) + } else if (correspondsToArray(groupType)) { // ArrayType + val elementType = toDataType(groupType.getFields.apply(0)) + new ArrayType(elementType) + } else { // everything else: StructType + val fields = groupType + .getFields + .map(ptype => new StructField( + ptype.getName, + toDataType(ptype), + ptype.getRepetition != Repetition.REQUIRED)) + new StructType(fields) + } + } + } + } + } + + /** + * For a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] return + * the name of the corresponding Parquet primitive type or None if the given type + * is not primitive. 
+ * + * @param ctype The type to convert + * @return The name of the corresponding Parquet primitive type + */ + def fromPrimitiveDataType(ctype: DataType): + Option[ParquetPrimitiveTypeName] = ctype match { + case StringType => Some(ParquetPrimitiveTypeName.BINARY) + case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN) + case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE) + case ArrayType(ByteType) => + Some(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + case FloatType => Some(ParquetPrimitiveTypeName.FLOAT) + case IntegerType => Some(ParquetPrimitiveTypeName.INT32) + // There is no type for Byte or Short so we promote them to INT32. + case ShortType => Some(ParquetPrimitiveTypeName.INT32) + case ByteType => Some(ParquetPrimitiveTypeName.INT32) + case LongType => Some(ParquetPrimitiveTypeName.INT64) + case _ => None + } + + /** + * Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into + * the corresponding Parquet `Type`. + * + * The conversion follows the rules below: + *
+ *   - Primitive types are converted into Parquet's primitive types.
+ *   - [[org.apache.spark.sql.catalyst.types.StructType]]s are converted
+ *     into Parquet's `GroupType` with the corresponding field types.
+ *   - [[org.apache.spark.sql.catalyst.types.ArrayType]]s are converted
+ *     into a 2-level nested group, where the outer group has the inner
+ *     group as sole field. The inner group has name `values` and
+ *     repetition level `REPEATED` and has the element type of
+ *     the array as schema. We use Parquet's `ConversionPatterns` for this
+ *     purpose.
+ *   - [[org.apache.spark.sql.catalyst.types.MapType]]s are converted
+ *     into a nested (2-level) Parquet `GroupType` with two fields: a key
+ *     type and a value type. The nested group has repetition level
+ *     `REPEATED` and name `map`. We use Parquet's `ConversionPatterns`
+ *     for this purpose.
+ *
+ * Parquet's repetition level is generally set according to the following rule:
+ *   - If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or
+ *     `MapType`, then the repetition level is set to `REPEATED` (see the example below).
+ *   - Otherwise, if the attribute whose type is converted is `nullable`, the Parquet
+ *     type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.
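+ *
+ * As an illustration of the repetition rule above: a nullable top-level
+ * [[org.apache.spark.sql.catalyst.types.IntegerType]] attribute is written as an
+ * `optional int32` field, while an [[org.apache.spark.sql.catalyst.types.IntegerType]]
+ * element of an [[org.apache.spark.sql.catalyst.types.ArrayType]] is emitted as a
+ * `repeated int32` field inside the group generated for the array.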
+ * + *@param ctype The type to convert + * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]] + * whose type is converted + * @param nullable When true indicates that the attribute is nullable + * @param inArray When true indicates that this is a nested attribute inside an array. + * @return The corresponding Parquet type. + */ + def fromDataType( + ctype: DataType, + name: String, + nullable: Boolean = true, + inArray: Boolean = false): ParquetType = { + val repetition = + if (inArray) { + Repetition.REPEATED + } else { + if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED + } + val primitiveType = fromPrimitiveDataType(ctype) + if (primitiveType.isDefined) { + new ParquetPrimitiveType(repetition, primitiveType.get, name) + } else { + ctype match { + case ArrayType(elementType) => { + val parquetElementType = fromDataType( + elementType, + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + nullable = false, + inArray = true) + ConversionPatterns.listType(repetition, name, parquetElementType) + } + case StructType(structFields) => { + val fields = structFields.map { + field => fromDataType(field.dataType, field.name, field.nullable, inArray = false) + } + new ParquetGroupType(repetition, name, fields) + } + case MapType(keyType, valueType) => { + val parquetKeyType = + fromDataType( + keyType, + CatalystConverter.MAP_KEY_SCHEMA_NAME, + nullable = false, + inArray = false) + val parquetValueType = + fromDataType( + valueType, + CatalystConverter.MAP_VALUE_SCHEMA_NAME, + nullable = false, + inArray = false) + ConversionPatterns.mapType( + repetition, + name, + parquetKeyType, + parquetValueType) + } + case _ => sys.error(s"Unsupported datatype $ctype") + } + } + } + + def convertToAttributes(parquetSchema: ParquetType): Seq[Attribute] = { + parquetSchema + .asGroupType() + .getFields + .map( + field => + new AttributeReference( + field.getName, + toDataType(field), + field.getRepetition != Repetition.REQUIRED)()) + } + + def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + val fields = attributes.map( + attribute => + fromDataType(attribute.dataType, attribute.name, attribute.nullable)) + new MessageType("root", fields) + } + + def convertFromString(string: String): Seq[Attribute] = { + DataType(string) match { + case s: StructType => s.toAttributes + case other => sys.error(s"Can convert $string to row") + } + } + + def convertToString(schema: Seq[Attribute]): String = { + StructType.fromAttributes(schema).toString + } + + def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { + if (origPath == null) { + throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") + } + val fs = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException( + s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") + } + val path = origPath.makeQualified(fs) + if (fs.exists(path) && !fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException(s"Expected to write to directory $path but found file") + } + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fs.exists(metadataPath)) { + try { + fs.delete(metadataPath, true) + } catch { + case e: IOException => + throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") + } + } + val extraMetadata = new java.util.HashMap[String, String]() + extraMetadata.put( + RowReadSupport.SPARK_METADATA_KEY, + 
ParquetTypesConverter.convertToString(attributes)) + // TODO: add extra data, e.g., table name, date, etc.? + + val parquetSchema: MessageType = + ParquetTypesConverter.convertFromAttributes(attributes) + val metaData: FileMetaData = new FileMetaData( + parquetSchema, + extraMetadata, + "Spark") + + ParquetRelation.enableLogForwarding() + ParquetFileWriter.writeMetadataFile( + conf, + path, + new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) + } + + /** + * Try to read Parquet metadata at the given Path. We first see if there is a summary file + * in the parent directory. If so, this is used. Else we read the actual footer at the given + * location. + * @param origPath The path at which we expect one (or more) Parquet files. + * @param configuration The Hadoop configuration to use. + * @return The `ParquetMetadata` containing among other things the schema. + */ + def readMetaData(origPath: Path, configuration: Option[Configuration]): ParquetMetadata = { + if (origPath == null) { + throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") + } + val job = new Job() + val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) + val fs: FileSystem = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") + } + val path = origPath.makeQualified(fs) + if (!fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException( + s"Expected $path for be a directory with Parquet files/metadata") + } + ParquetRelation.enableLogForwarding() + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + // if this is a new table that was just created we will find only the metadata file + if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { + ParquetFileReader.readFooter(conf, metadataPath) + } else { + // there may be one or more Parquet files in the given directory + val footers = ParquetFileReader.readFooters(conf, fs.getFileStatus(path)) + // TODO: for now we assume that all footers (if there is more than one) have identical + // metadata; we may want to add a check here at some point + if (footers.size() == 0) { + throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path") + } + footers(0).getParquetMetadata + } + } + + /** + * Reads in Parquet Metadata from the given path and tries to extract the schema + * (Catalyst attributes) from the application-specific key-value map. If this + * is empty it falls back to converting from the Parquet file schema which + * may lead to an upcast of types (e.g., {byte, short} to int). + * + * @param origPath The path at which we expect one (or more) Parquet files. + * @param conf The Hadoop configuration to use. + * @return A list of attributes that make up the schema. 
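+ *
+ * A minimal usage sketch (`path` and `conf` here are placeholder names for a
+ * [[org.apache.hadoop.fs.Path]] and a Hadoop `Configuration`):
+ * {{{
+ *   val schema: Seq[Attribute] = ParquetTypesConverter.readSchemaFromFile(path, Some(conf))
+ * }}}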
+ */ + def readSchemaFromFile(origPath: Path, conf: Option[Configuration]): Seq[Attribute] = { + val keyValueMetadata: java.util.Map[String, String] = + readMetaData(origPath, conf) + .getFileMetaData + .getKeyValueMetaData + if (keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { + convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) + } else { + val attributes = convertToAttributes( + readMetaData(origPath, conf).getFileMetaData.getSchema) + log.warn(s"Falling back to schema conversion from Parquet types; result: $attributes") + attributes + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 9810520bb9ae6..0c239d00b199b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -26,15 +26,16 @@ import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil import parquet.schema.MessageTypeParser +import org.apache.spark.SparkContext import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.util.getTempFilePath import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.TestData import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.expressions.Equals -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star} import org.apache.spark.util.Utils // Implicits @@ -56,15 +57,37 @@ case class OptionalReflectData( doubleField: Option[Double], booleanField: Option[Boolean]) +case class Nested(i: Int, s: String) + +case class Data(array: Seq[Int], nested: Nested) + +case class AllDataTypes( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { import TestData._ TestData // Load test data tables. 
var testRDD: SchemaRDD = null + // TODO: remove this once SqlParser can parse nested select statements + var nestedParserSqlContext: NestedParserSQLContext = null + override def beforeAll() { + nestedParserSqlContext = new NestedParserSQLContext(TestSQLContext.sparkContext) ParquetTestData.writeFile() ParquetTestData.writeFilterFile() + ParquetTestData.writeNestedFile1() + ParquetTestData.writeNestedFile2() + ParquetTestData.writeNestedFile3() + ParquetTestData.writeNestedFile4() testRDD = parquetFile(ParquetTestData.testDir.toString) testRDD.registerAsTable("testsource") parquetFile(ParquetTestData.testFilterDir.toString) @@ -74,9 +97,33 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA override def afterAll() { Utils.deleteRecursively(ParquetTestData.testDir) Utils.deleteRecursively(ParquetTestData.testFilterDir) + Utils.deleteRecursively(ParquetTestData.testNestedDir1) + Utils.deleteRecursively(ParquetTestData.testNestedDir2) + Utils.deleteRecursively(ParquetTestData.testNestedDir3) + Utils.deleteRecursively(ParquetTestData.testNestedDir4) // here we should also unregister the table?? } + test("Read/Write All Types") { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val range = (0 to 255) + TestSQLContext.sparkContext.parallelize(range) + .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + .saveAsParquetFile(tempDir) + val result = parquetFile(tempDir).collect() + range.foreach { + i => + assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") + assert(result(i).getInt(1) === i) + assert(result(i).getLong(2) === i.toLong) + assert(result(i).getFloat(3) === i.toFloat) + assert(result(i).getDouble(4) === i.toDouble) + assert(result(i).getShort(5) === i.toShort) + assert(result(i).getByte(6) === i.toByte) + assert(result(i).getBoolean(7) === (i % 2 == 0)) + } + } + test("self-join parquet files") { val x = ParquetTestData.testData.as('x) val y = ParquetTestData.testData.as('y) @@ -154,7 +201,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA path, TestSQLContext.sparkContext.hadoopConfiguration) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path) + val metaData = ParquetTypesConverter.readMetaData(path, Some(ContextUtil.getConfiguration(job))) assert(metaData != null) ParquetTestData .testData @@ -197,10 +244,37 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value in line $i") } Utils.deleteRecursively(file) - assert(true) } - test("insert (appending) to same table via Scala API") { + test("Insert (overwrite) via Scala API") { + val dirname = Utils.createTempDir() + val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + source_rdd.registerAsTable("source") + val dest_rdd = createParquetFile[TestRDDEntry](dirname.toString) + dest_rdd.registerAsTable("dest") + sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() + val rdd_copy1 = sql("SELECT * FROM dest").collect() + assert(rdd_copy1.size === 100) + assert(rdd_copy1(0).apply(0) === 1) + assert(rdd_copy1(0).apply(1) === "val_1") + // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is + // executed twice otherwise?! 
+ sql("INSERT INTO dest SELECT * FROM source") + val rdd_copy2 = sql("SELECT * FROM dest").collect() + assert(rdd_copy2.size === 200) + assert(rdd_copy2(0).apply(0) === 1) + assert(rdd_copy2(0).apply(1) === "val_1") + assert(rdd_copy2(99).apply(0) === 100) + assert(rdd_copy2(99).apply(1) === "val_100") + assert(rdd_copy2(100).apply(0) === 1) + assert(rdd_copy2(100).apply(1) === "val_1") + Utils.deleteRecursively(dirname) + } + + test("Insert (appending) to same table via Scala API") { + // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is + // executed twice otherwise?! sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) @@ -363,4 +437,272 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val query = sql(s"SELECT mystring FROM testfiltersource WHERE myint < 10") assert(query.collect().size === 10) } + + test("Importing nested Parquet file (Addressbook)") { + val result = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + .collect() + assert(result != null) + assert(result.size === 2) + val first_record = result(0) + val second_record = result(1) + assert(first_record != null) + assert(second_record != null) + assert(first_record.size === 3) + assert(second_record(1) === null) + assert(second_record(2) === null) + assert(second_record(0) === "A. Nonymous") + assert(first_record(0) === "Julien Le Dem") + val first_owner_numbers = first_record(1) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + val first_contacts = first_record(2) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(first_owner_numbers != null) + assert(first_owner_numbers(0) === "555 123 4567") + assert(first_owner_numbers(2) === "XXX XXX XXXX") + assert(first_contacts(0) + .asInstanceOf[CatalystConverter.StructScalaType[_]].size === 2) + val first_contacts_entry_one = first_contacts(0) + .asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(first_contacts_entry_one(0) === "Dmitriy Ryaboy") + assert(first_contacts_entry_one(1) === "555 987 6543") + val first_contacts_entry_two = first_contacts(1) + .asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(first_contacts_entry_two(0) === "Chris Aniszczyk") + } + + test("Importing nested Parquet file (nested numbers)") { + val result = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir2.toString) + .toSchemaRDD + .collect() + assert(result.size === 1, "number of top-level rows incorrect") + assert(result(0).size === 5, "number of fields in row incorrect") + assert(result(0)(0) === 1) + assert(result(0)(1) === 7) + val subresult1 = result(0)(2).asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult1.size === 3) + assert(subresult1(0) === (1.toLong << 32)) + assert(subresult1(1) === (1.toLong << 33)) + assert(subresult1(2) === (1.toLong << 34)) + val subresult2 = result(0)(3) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(subresult2.size === 2) + assert(subresult2(0) === 2.5) + assert(subresult2(1) === false) + val subresult3 = result(0)(4) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult3.size === 2) + assert(subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 2) + val subresult4 = subresult3(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult4(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) + 
assert(subresult4(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) + assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]].size === 1) + assert(subresult3(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + } + + test("Simple query on addressbook") { + val data = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + val tmp = data.where('owner === "Julien Le Dem").select('owner as 'a, 'contacts as 'c).collect() + assert(tmp.size === 1) + assert(tmp(0)(0) === "Julien Le Dem") + } + + test("Projection in addressbook") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + data.registerAsTable("data") + val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data") + val tmp = query.collect() + assert(tmp.size === 2) + assert(tmp(0).size === 2) + assert(tmp(0)(0) === "Julien Le Dem") + assert(tmp(0)(1) === "Chris Aniszczyk") + assert(tmp(1)(0) === "A. Nonymous") + assert(tmp(1)(1) === null) + } + + test("Simple query on nested int data") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir2.toString) + .toSchemaRDD + data.registerAsTable("data") + val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect() + assert(result1.size === 1) + assert(result1(0).size === 1) + assert(result1(0)(0) === 2.5) + val result2 = nestedParserSqlContext.sql("SELECT entries[0] FROM data").collect() + assert(result2.size === 1) + val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] + assert(subresult1.size === 2) + assert(subresult1(0) === 2.5) + assert(subresult1(1) === false) + val result3 = nestedParserSqlContext.sql("SELECT outerouter FROM data").collect() + val subresult2 = result3(0)(0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]] + assert(subresult2(0).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 7) + assert(subresult2(1).asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 8) + assert(result3(0)(0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](1) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) + .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) === 9) + } + + test("nested structs") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir3.toString) + .toSchemaRDD + data.registerAsTable("data") + val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() + assert(result1.size === 1) + assert(result1(0).size === 1) + assert(result1(0)(0) === false) + val result2 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() + assert(result2.size === 1) + assert(result2(0).size === 1) + assert(result2(0)(0) === true) + val result3 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() + assert(result3.size === 1) + assert(result3(0).size === 1) + assert(result3(0)(0) === false) + } + + test("simple map") { + val data = TestSQLContext + .parquetFile(ParquetTestData.testNestedDir4.toString) + .toSchemaRDD + data.registerAsTable("mapTable") + val result1 = sql("SELECT data1 FROM mapTable").collect() + assert(result1.size === 1) + assert(result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, _]] + .getOrElse("key1", 0) === 1) + assert(result1(0)(0) + 
.asInstanceOf[CatalystConverter.MapScalaType[String, _]] + .getOrElse("key2", 0) === 2) + val result2 = sql("""SELECT data1["key1"] FROM mapTable""").collect() + assert(result2(0)(0) === 1) + } + + test("map with struct values") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir4.toString) + .toSchemaRDD + data.registerAsTable("mapTable") + val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect() + assert(result1.size === 1) + val entry1 = result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("seven", null) + assert(entry1 != null) + assert(entry1(0) === 42) + assert(entry1(1) === "the answer") + val entry2 = result1(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("eight", null) + assert(entry2 != null) + assert(entry2(0) === 49) + assert(entry2(1) === null) + val result2 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() + assert(result2.size === 1) + assert(result2(0)(0) === 42.toLong) + assert(result2(0)(1) === "the answer") + } + + test("Writing out Addressbook and reading it back in") { + // TODO: find out why CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME + // has no effect in this test case + val tmpdir = Utils.createTempDir() + Utils.deleteRecursively(tmpdir) + val result = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir1.toString) + .toSchemaRDD + result.saveAsParquetFile(tmpdir.toString) + nestedParserSqlContext + .parquetFile(tmpdir.toString) + .toSchemaRDD + .registerAsTable("tmpcopy") + val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() + assert(tmpdata.size === 2) + assert(tmpdata(0).size === 2) + assert(tmpdata(0)(0) === "Julien Le Dem") + assert(tmpdata(0)(1) === "Chris Aniszczyk") + assert(tmpdata(1)(0) === "A. 
Nonymous") + assert(tmpdata(1)(1) === null) + Utils.deleteRecursively(tmpdir) + } + + test("Writing out Map and reading it back in") { + val data = nestedParserSqlContext + .parquetFile(ParquetTestData.testNestedDir4.toString) + .toSchemaRDD + val tmpdir = Utils.createTempDir() + Utils.deleteRecursively(tmpdir) + data.saveAsParquetFile(tmpdir.toString) + nestedParserSqlContext + .parquetFile(tmpdir.toString) + .toSchemaRDD + .registerAsTable("tmpmapcopy") + val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() + assert(result1.size === 1) + assert(result1(0)(0) === 2) + val result2 = nestedParserSqlContext.sql("SELECT data2 FROM tmpmapcopy").collect() + assert(result2.size === 1) + val entry1 = result2(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("seven", null) + assert(entry1 != null) + assert(entry1(0) === 42) + assert(entry1(1) === "the answer") + val entry2 = result2(0)(0) + .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] + .getOrElse("eight", null) + assert(entry2 != null) + assert(entry2(0) === 49) + assert(entry2(1) === null) + val result3 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() + assert(result3.size === 1) + assert(result3(0)(0) === 42.toLong) + assert(result3(0)(1) === "the answer") + Utils.deleteRecursively(tmpdir) + } +} + +// TODO: the code below is needed temporarily until the standard parser is able to parse +// nested field expressions correctly +class NestedParserSQLContext(@transient override val sparkContext: SparkContext) extends SQLContext(sparkContext) { + override protected[sql] val parser = new NestedSqlParser() +} + +class NestedSqlLexical(override val keywords: Seq[String]) extends SqlLexical(keywords) { + override def identChar = letter | elem('_') + delimiters += (".") +} + +class NestedSqlParser extends SqlParser { + override val lexical = new NestedSqlLexical(reservedWords) + + override protected lazy val baseExpression: PackratParser[Expression] = + expression ~ "[" ~ expression <~ "]" ^^ { + case base ~ _ ~ ordinal => GetItem(base, ordinal) + } | + expression ~ "." 
~ ident ^^ { + case base ~ _ ~ fieldName => GetField(base, fieldName) + } | + TRUE ^^^ Literal(true, BooleanType) | + FALSE ^^^ Literal(false, BooleanType) | + cast | + "(" ~> expression <~ ")" | + function | + "-" ~> literal ^^ UnaryMinus | + ident ^^ UnresolvedAttribute | + "*" ^^^ Star(None) | + literal } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 68284344afd55..f923d68932f83 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -208,7 +208,9 @@ object HiveMetastoreTypes extends RegexParsers { } protected lazy val structType: Parser[DataType] = - "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ StructType + "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ { + case fields => new StructType(fields) + } protected lazy val dataType: Parser[DataType] = arrayType | From 61756409736a64bd42577782cb7468557fa0b642 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Jun 2014 23:58:23 -0700 Subject: [PATCH 50/57] [SPARK-2210] cast to boolean on boolean value gets turned into NOT((boolean_condition) = 0) ``` explain select cast(cast(key=0 as boolean) as boolean) aaa from src ``` should be ``` [Physical execution plan:] [Project [(key#10:0 = 0) AS aaa#7]] [ HiveTableScan [key#10], (MetastoreRelation default, src, None), None] ``` However, it is currently ``` [Physical execution plan:] [Project [NOT((key#10=0) = 0) AS aaa#7]] [ HiveTableScan [key#10], (MetastoreRelation default, src, None), None] ``` Author: Reynold Xin Closes #1144 from rxin/booleancast and squashes the following commits: c4e543d [Reynold Xin] [SPARK-2210] boolean cast on boolean value should be removed. --- .../catalyst/analysis/HiveTypeCoercion.scala | 4 ++- .../execution/HiveTypeCoercionSuite.scala | 25 ++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 6d331fb501d08..c0714bcdd0afb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -251,7 +251,9 @@ trait HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - + // Skip if the type is boolean type already. Note that this extra cast should be removed + // by optimizer.SimplifyCasts. 
+ case Cast(e, BooleanType) if e.dataType == BooleanType => e case Cast(e, BooleanType) => Not(Equals(e, Literal(0))) case Cast(e, dataType) if e.dataType == BooleanType => Cast(If(e, Literal(1), Literal(0)), dataType) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index e030c8ee3dfc8..cc8744c9668eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.catalyst.expressions.{Cast, Equals} +import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.hive.test.TestHive + /** - * A set of tests that validate type promotion rules. + * A set of tests that validate type promotion and coercion rules. */ class HiveTypeCoercionSuite extends HiveComparisonTest { val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'") @@ -28,4 +32,23 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1") } } + + test("[SPARK-2210] boolean cast on boolean value should be removed") { + val q = "select cast(cast(key=0 as boolean) as boolean) from src" + val project = TestHive.hql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + + // No cast expression introduced + project.transformAllExpressions { case c: Cast => + assert(false, "unexpected cast " + c) + c + } + + // Only one Equals + var numEquals = 0 + project.transformAllExpressions { case e: Equals => + numEquals += 1 + e + } + assert(numEquals === 1) + } } From c55bbb49f7ec653f0ff635015d3bc789ca26c4eb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 20 Jun 2014 00:01:19 -0700 Subject: [PATCH 51/57] [SPARK-2209][SQL] Cast shouldn't do null check twice. Also took the chance to clean up cast a little bit. Too many arrows on each line before! Author: Reynold Xin Closes #1143 from rxin/cast and squashes the following commits: dd006cb [Reynold Xin] Code review feedback. c2b88ae [Reynold Xin] [SPARK-2209][SQL] Cast shouldn't do null check twice. --- .../spark/sql/catalyst/expressions/Cast.scala | 274 ++++++++++-------- 1 file changed, 159 insertions(+), 115 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0b3a4e728ec54..1f9716e385e9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -24,72 +24,87 @@ import org.apache.spark.sql.catalyst.types._ /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def foldable = child.foldable - def nullable = (child.dataType, dataType) match { + + override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true case _ => child.nullable } + override def toString = s"CAST($child, $dataType)" type EvaluatedType = Any - def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) { - null - } else { - func(a.asInstanceOf[T]) - } + // [[func]] assumes the input is no longer null because eval already does the null check. 
+ @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) // UDFToString - def castToString: Any => Any = child.dataType match { - case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8")) - case _ => nullOrCast[Any](_, _.toString) + private[this] def castToString: Any => Any = child.dataType match { + case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) + case _ => buildCast[Any](_, _.toString) } // BinaryConverter - def castToBinary: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, _.getBytes("UTF-8")) + private[this] def castToBinary: Any => Any = child.dataType match { + case StringType => buildCast[String](_, _.getBytes("UTF-8")) } // UDFToBoolean - def castToBoolean: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, _.length() != 0) - case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)}) - case LongType => nullOrCast[Long](_, _ != 0) - case IntegerType => nullOrCast[Int](_, _ != 0) - case ShortType => nullOrCast[Short](_, _ != 0) - case ByteType => nullOrCast[Byte](_, _ != 0) - case DecimalType => nullOrCast[BigDecimal](_, _ != 0) - case DoubleType => nullOrCast[Double](_, _ != 0) - case FloatType => nullOrCast[Float](_, _ != 0) + private[this] def castToBoolean: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, _.length() != 0) + case TimestampType => + buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0) + case LongType => + buildCast[Long](_, _ != 0) + case IntegerType => + buildCast[Int](_, _ != 0) + case ShortType => + buildCast[Short](_, _ != 0) + case ByteType => + buildCast[Byte](_, _ != 0) + case DecimalType => + buildCast[BigDecimal](_, _ != 0) + case DoubleType => + buildCast[Double](_, _ != 0) + case FloatType => + buildCast[Float](_, _ != 0) } // TimestampConverter - def castToTimestamp: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => { - // Throw away extra if more than 9 decimal places - val periodIdx = s.indexOf("."); - var n = s - if (periodIdx != -1) { - if (n.length() - periodIdx > 9) { + private[this] def castToTimestamp: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => { + // Throw away extra if more than 9 decimal places + val periodIdx = s.indexOf(".") + var n = s + if (periodIdx != -1 && n.length() - periodIdx > 9) { n = n.substring(0, periodIdx + 10) } - } - try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null} - }) - case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000)) - case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000)) - case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000)) - case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000)) - case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000)) + try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } + }) + case BooleanType => + buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000)) + case LongType => + buildCast[Long](_, l => new Timestamp(l * 1000)) + case IntegerType => + buildCast[Int](_, i => new Timestamp(i * 1000)) + case ShortType => + buildCast[Short](_, s => new Timestamp(s * 1000)) + case ByteType => + buildCast[Byte](_, b => new Timestamp(b * 1000)) // TimestampWritable.decimalToTimestamp - case DecimalType => nullOrCast[BigDecimal](_, d => 
decimalToTimestamp(d)) + case DecimalType => + buildCast[BigDecimal](_, d => decimalToTimestamp(d)) // TimestampWritable.doubleToTimestamp - case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d)) + case DoubleType => + buildCast[Double](_, d => decimalToTimestamp(d)) // TimestampWritable.floatToTimestamp - case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f)) + case FloatType => + buildCast[Float](_, f => decimalToTimestamp(f)) } - private def decimalToTimestamp(d: BigDecimal) = { + private[this] def decimalToTimestamp(d: BigDecimal) = { val seconds = d.longValue() val bd = (d - seconds) * 1000000000 val nanos = bd.intValue() @@ -104,85 +119,118 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } // Timestamp to long, converting milliseconds to seconds - private def timestampToLong(ts: Timestamp) = ts.getTime / 1000 + private[this] def timestampToLong(ts: Timestamp) = ts.getTime / 1000 - private def timestampToDouble(ts: Timestamp) = { + private[this] def timestampToDouble(ts: Timestamp) = { // First part is the seconds since the beginning of time, followed by nanosecs. ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000 } - def castToLong: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toLong catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t)) - case DecimalType => nullOrCast[BigDecimal](_, _.toLong) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) - } - - def castToInt: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toInt catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt) - case DecimalType => nullOrCast[BigDecimal](_, _.toInt) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) - } - - def castToShort: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toShort catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort) - case DecimalType => nullOrCast[BigDecimal](_, _.toShort) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort - } - - def castToByte: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toByte catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte) - case DecimalType => nullOrCast[BigDecimal](_, _.toByte) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte - } - - def castToDecimal: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0)) + private[this] def castToLong: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toLong catch { + case _: NumberFormatException => null + }) + case BooleanType => + 
buildCast[Boolean](_, b => if (b) 1L else 0L) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t)) + case DecimalType => + buildCast[BigDecimal](_, _.toLong) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) + } + + private[this] def castToInt: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toInt catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1 else 0) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toInt) + case DecimalType => + buildCast[BigDecimal](_, _.toInt) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) + } + + private[this] def castToShort: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toShort catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toShort) + case DecimalType => + buildCast[BigDecimal](_, _.toShort) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort + } + + private[this] def castToByte: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toByte catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) + case TimestampType => + buildCast[Timestamp](_, t => timestampToLong(t).toByte) + case DecimalType => + buildCast[BigDecimal](_, _.toByte) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte + } + + private[this] def castToDecimal: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try BigDecimal(s.toDouble) catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0)) case TimestampType => // Note that we lose precision here. 
- nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) - case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) - } - - def castToDouble: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toDouble catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t)) - case DecimalType => nullOrCast[BigDecimal](_, _.toDouble) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) - } - - def castToFloat: Any => Any = child.dataType match { - case StringType => nullOrCast[String](_, s => try s.toFloat catch { - case _: NumberFormatException => null - }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f) - case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat) - case DecimalType => nullOrCast[BigDecimal](_, _.toFloat) - case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) + buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) + case x: NumericType => + b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) + } + + private[this] def castToDouble: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toDouble catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1d else 0d) + case TimestampType => + buildCast[Timestamp](_, t => timestampToDouble(t)) + case DecimalType => + buildCast[BigDecimal](_, _.toDouble) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) + } + + private[this] def castToFloat: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => try s.toFloat catch { + case _: NumberFormatException => null + }) + case BooleanType => + buildCast[Boolean](_, b => if (b) 1f else 0f) + case TimestampType => + buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) + case DecimalType => + buildCast[BigDecimal](_, _.toFloat) + case x: NumericType => + b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } - private lazy val cast: Any => Any = dataType match { + private[this] lazy val cast: Any => Any = dataType match { case StringType => castToString case BinaryType => castToBinary case DecimalType => castToDecimal @@ -198,10 +246,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def eval(input: Row): Any = { val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - cast(evaluated) - } + if (evaluated == null) null else cast(evaluated) } } From f46e02fcdbb3f86a8761c078708388d18282ee0c Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 20 Jun 2014 00:06:57 -0700 Subject: [PATCH 52/57] SPARK-2203: PySpark defaults to use same num reduce partitions as map side For shuffle-based operators, such as rdd.groupBy() or rdd.sortByKey(), PySpark will always assume that the default parallelism to use for the reduce side is ctx.defaultParallelism, which is a constant typically determined by the number of cores in cluster. In contrast, Spark's Partitioner#defaultPartitioner will use the same number of reduce partitions as map partitions unless the defaultParallelism config is explicitly set. This tends to be a better default in order to avoid OOMs, and should also be the behavior of PySpark. 
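For illustration, the decision this change mirrors can be sketched roughly as follows (a minimal Scala sketch, not part of the patch; `conf` and `mapSide` are assumed stand-ins for the active SparkConf and the map-side RDD):

```
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD

// Honor spark.default.parallelism only when it is set explicitly; otherwise
// fall back to the number of map-side partitions, mirroring the behavior of
// Scala's Partitioner#defaultPartitioner.
def defaultReducePartitions(conf: SparkConf, mapSide: RDD[_]): Int = {
  if (conf.contains("spark.default.parallelism")) {
    conf.getInt("spark.default.parallelism", mapSide.partitions.length)
  } else {
    mapSide.partitions.length
  }
}
```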
JIRA: https://issues.apache.org/jira/browse/SPARK-2203 Author: Aaron Davidson Closes #1138 from aarondav/pyfix and squashes the following commits: 1bd5751 [Aaron Davidson] SPARK-2203: PySpark defaults to use same num reduce partitions as map partitions --- python/pyspark/rdd.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a0b2c744f0e7f..62a95c84675dd 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -512,7 +512,7 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x): [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)] """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() bounds = list() @@ -1154,7 +1154,7 @@ def partitionBy(self, numPartitions, partitionFunc=None): set([]) """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() if partitionFunc is None: partitionFunc = lambda x: 0 if x is None else hash(x) @@ -1212,7 +1212,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, [('a', '11'), ('b', '1')] """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._defaultReducePartitions() def combineLocally(iterator): combiners = {} for x in iterator: @@ -1475,6 +1475,21 @@ def getStorageLevel(self): java_storage_level.replication()) return storage_level + def _defaultReducePartitions(self): + """ + Returns the default number of partitions to use during reduce tasks (e.g., groupBy). + If spark.default.parallelism is set, then we'll use the value from SparkContext + defaultParallelism, otherwise we'll use the number of partitions in this RDD. + + This mirrors the behavior of the Scala Partitioner#defaultPartitioner, intended to reduce + the likelihood of OOMs. Once PySpark adopts Partitioner-based APIs, this behavior will + be inherent. + """ + if self.ctx._conf.contains("spark.default.parallelism"): + return self.ctx.defaultParallelism + else: + return self.getNumPartitions() + # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the # keys in the pairs. This could be an expensive operation, since those From 324952892085d1933bcf392ce8f2ced452fe741e Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 20 Jun 2014 00:12:52 -0700 Subject: [PATCH 53/57] [SPARK-2196] [SQL] Fix nullability of CaseWhen. `CaseWhen` should use `branches.length` to check if `elseValue` is provided or not. Author: Takuya UESHIN Closes #1133 from ueshin/issues/SPARK-2196 and squashes the following commits: 510f12d [Takuya UESHIN] Add some tests. dc25e8d [Takuya UESHIN] Fix nullable of CaseWhen to be nullable if the elseValue is nullable. 4f049cc [Takuya UESHIN] Fix nullability of CaseWhen. 
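As a reading aid for the diff that follows: `branches` holds alternating (condition, value) pairs, optionally followed by a single else value, so an odd `branches.length` means an else value is present. A minimal Scala sketch of the corrected nullability rule (illustrative only, assuming Catalyst's `Expression` is in scope):

```
import org.apache.spark.sql.catalyst.expressions.Expression

def caseWhenNullable(branches: Seq[Expression]): Boolean = {
  // Every second element is a branch value; a trailing unpaired element is the else value.
  val values = branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
  val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
  // Nullable if any branch value is nullable, if no else value is provided
  // (the whole expression then defaults to null), or if the else value is nullable.
  values.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
}
```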
--- .../sql/catalyst/expressions/predicates.scala | 4 +- .../ExpressionEvaluationSuite.scala | 43 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2902906df2844..2718d4364601c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -233,10 +233,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq @transient private[this] lazy val values = branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq + @transient private[this] lazy val elseValue = + if (branches.length % 2 == 0) None else Option(branches.last) override def nullable = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - values.exists(_.nullable) || (values.length % 2 == 0) + values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) } override lazy val resolved = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 8c3b062d0f801..84d72814778ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -333,6 +333,49 @@ class ExpressionEvaluationSuite extends FunSuite { Literal("^Ba*n", StringType) :: c2 :: Nil), true, row) } + test("case when") { + val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c")) + val c1 = 'a.boolean.at(0) + val c2 = 'a.boolean.at(1) + val c3 = 'a.boolean.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row) + + checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) + + assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + + val c4_notNull = 'a.boolean.notNull.at(3) + val c5_notNull = 'a.boolean.notNull.at(4) + val c6_notNull = 'a.boolean.notNull.at(5) + + assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, 
c4_notNull, c3, c5_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + } + test("complex type") { val row = new GenericRow(Array[Any]( "^Ba*n", // 0 From 2f6a835e1a039a0b1ba6e184b3350444b70f91df Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 20 Jun 2014 00:34:59 -0700 Subject: [PATCH 54/57] [SPARK-2218] rename Equals to EqualTo in Spark SQL expressions. Due to the existence of scala.Equals, it is very error prone to name the expression Equals, especially because we use a lot of partial functions and pattern matching in the optimizer. Note that this sits on top of #1144. Author: Reynold Xin Closes #1146 from rxin/equals and squashes the following commits: f8583fd [Reynold Xin] Merge branch 'master' of github.com:apache/spark into equals 326b388 [Reynold Xin] Merge branch 'master' of github.com:apache/spark into equals bd19807 [Reynold Xin] Rename EqualsTo to EqualTo. 81148d1 [Reynold Xin] [SPARK-2218] rename Equals to EqualsTo in Spark SQL expressions. c4e543d [Reynold Xin] [SPARK-2210] boolean cast on boolean value should be removed. --- .../apache/spark/sql/catalyst/SqlParser.scala | 6 +++--- .../catalyst/analysis/HiveTypeCoercion.scala | 9 ++++++--- .../spark/sql/catalyst/dsl/package.scala | 6 +++--- .../sql/catalyst/expressions/package.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 4 ++-- .../sql/catalyst/planning/patterns.scala | 6 +++--- .../optimizer/ConstantFoldingSuite.scala | 4 ++-- .../spark/sql/parquet/ParquetFilters.scala | 4 ++-- .../spark/sql/parquet/ParquetQuerySuite.scala | 19 +++++++------------ .../org/apache/spark/sql/hive/HiveQl.scala | 10 +++++----- .../execution/HiveTypeCoercionSuite.scala | 8 ++++---- 11 files changed, 38 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 2ad2d04af5704..0cc4592047b19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -258,13 +258,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) }) protected lazy val comparisonExpression: Parser[Expression] = - termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Equals(e1, e2) } | + termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => EqualTo(e1, e2) } | termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } | termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } | termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } | termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | - termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | - termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | + termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | + termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => 
RLike(e1, e2) } | termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index c0714bcdd0afb..76ddeba9cb312 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -234,8 +234,8 @@ trait HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // No need to change Equals operators as that actually makes sense for boolean types. - case e: Equals => e + // No need to change EqualTo operators as that actually makes sense for boolean types. + case e: EqualTo => e // Otherwise turn them to Byte types so that there exists and ordering. case p: BinaryComparison if p.left.dataType == BooleanType && p.right.dataType == BooleanType => @@ -254,7 +254,10 @@ trait HiveTypeCoercion { // Skip if the type is boolean type already. Note that this extra cast should be removed // by optimizer.SimplifyCasts. case Cast(e, BooleanType) if e.dataType == BooleanType => e - case Cast(e, BooleanType) => Not(Equals(e, Literal(0))) + // If the data type is not boolean and is being cast boolean, turn it into a comparison + // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. + case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) + // Turn true into 1, and false into 0 if casting boolean into other types. case Cast(e, dataType) if e.dataType == BooleanType => Cast(If(e, Literal(1), Literal(0)), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d177339d40ae5..26ad4837b0b01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.types._ * * // These unresolved attributes can be used to create more complicated expressions. * scala> 'a === 'b - * res2: org.apache.spark.sql.catalyst.expressions.Equals = ('a = 'b) + * res2: org.apache.spark.sql.catalyst.expressions.EqualTo = ('a = 'b) * * // SQL verbs can be used to construct logical query plans. 
* scala> import org.apache.spark.sql.catalyst.plans.logical._ @@ -76,8 +76,8 @@ package object dsl { def <= (other: Expression) = LessThanOrEqual(expr, other) def > (other: Expression) = GreaterThan(expr, other) def >= (other: Expression) = GreaterThanOrEqual(expr, other) - def === (other: Expression) = Equals(expr, other) - def !== (other: Expression) = Not(Equals(expr, other)) + def === (other: Expression) = EqualTo(expr, other) + def !== (other: Expression) = Not(EqualTo(expr, other)) def like(other: Expression) = Like(expr, other) def rlike(other: Expression) = RLike(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 573ec052f4266..b6f2451b52e1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -24,7 +24,7 @@ package org.apache.spark.sql.catalyst * expression, a [[NamedExpression]] in addition to the standard collection of expressions. * * ==Standard Expressions== - * A library of standard expressions (e.g., [[Add]], [[Equals]]), aggregates (e.g., SUM, COUNT), + * A library of standard expressions (e.g., [[Add]], [[EqualTo]]), aggregates (e.g., SUM, COUNT), * and other computations (e.g. UDFs). Each expression type is capable of determining its output * schema as a function of its children's output schema. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2718d4364601c..b63406b94a4a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -52,7 +52,7 @@ trait PredicateHelper { * * For example consider a join between two relations R(a, b) and S(c, d). * - * `canEvaluate(Equals(a,b), R)` returns `true` where as `canEvaluate(Equals(a,c), R)` returns + * `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns * `false`. 
*/ protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean = @@ -140,7 +140,7 @@ abstract class BinaryComparison extends BinaryPredicate { self: Product => } -case class Equals(left: Expression, right: Expression) extends BinaryComparison { +case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { def symbol = "=" override def eval(input: Row): Any = { val l = left.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 820ecfb78b52e..a43bef389c4bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -136,14 +136,14 @@ object HashFilteredJoin extends Logging with PredicateHelper { val Join(left, right, joinType, _) = join val (joinPredicates, otherPredicates) = allPredicates.flatMap(splitConjunctivePredicates).partition { - case Equals(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || + case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || (canEvaluate(l, right) && canEvaluate(r, left)) => true case _ => false } val joinKeys = joinPredicates.map { - case Equals(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) - case Equals(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) + case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) + case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) } // Do not consider this strategy if there are no join keys. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index cea97c584f7e1..0ff82064012a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -195,8 +195,8 @@ class ConstantFoldingSuite extends PlanTest { Add(Literal(null, IntegerType), 1) as 'c9, Add(1, Literal(null, IntegerType)) as 'c10, - Equals(Literal(null, IntegerType), 1) as 'c11, - Equals(1, Literal(null, IntegerType)) as 'c12, + EqualTo(Literal(null, IntegerType), 1) as 'c11, + EqualTo(1, Literal(null, IntegerType)) as 'c12, Like(Literal(null, StringType), "abc") as 'c13, Like("abc", Literal(null, StringType)) as 'c14, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 052b0a9196717..cc575bedd8fcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -205,9 +205,9 @@ object ParquetFilters { Some(new AndFilter(leftFilter.get, rightFilter.get)) } } - case p @ Equals(left: Literal, right: NamedExpression) if !right.nullable => + case p @ EqualTo(left: Literal, right: NamedExpression) if !right.nullable => Some(createEqualityFilter(right.name, left, p)) - case p @ Equals(left: NamedExpression, right: Literal) if !left.nullable => + case p @ EqualTo(left: NamedExpression, right: Literal) if !left.nullable => Some(createEqualityFilter(left.name, right, p)) case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable => 
Some(createLessThanFilter(right.name, left, p)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 0c239d00b199b..7714eb1b5628a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -19,27 +19,23 @@ package org.apache.spark.sql.parquet import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} -import org.apache.hadoop.fs.{Path, FileSystem} -import org.apache.hadoop.mapreduce.Job - import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil import parquet.schema.MessageTypeParser +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.Job import org.apache.spark.SparkContext import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.TestData -import org.apache.spark.sql.SchemaRDD import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils -// Implicits -import org.apache.spark.sql.test.TestSQLContext._ case class TestRDDEntry(key: Int, value: String) @@ -72,7 +68,6 @@ case class AllDataTypes( booleanField: Boolean) class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { - import TestData._ TestData // Load test data tables. 
var testRDD: SchemaRDD = null @@ -319,7 +314,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("create RecordFilter for simple predicates") { val attribute1 = new AttributeReference("first", IntegerType, false)() - val predicate1 = new Equals(attribute1, new Literal(1, IntegerType)) + val predicate1 = new EqualTo(attribute1, new Literal(1, IntegerType)) val filter1 = ParquetFilters.createFilter(predicate1) assert(filter1.isDefined) assert(filter1.get.predicate == predicate1, "predicates do not match") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index df761b073a75a..ec653efcc8c58 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -698,7 +698,7 @@ private[hive] object HiveQl { val joinConditions = joinExpressions.sliding(2).map { case Seq(c1, c2) => - val predicates = (c1, c2).zipped.map { case (e1, e2) => Equals(e1, e2): Expression } + val predicates = (c1, c2).zipped.map { case (e1, e2) => EqualTo(e1, e2): Expression } predicates.reduceLeft(And) }.toBuffer @@ -924,9 +924,9 @@ private[hive] object HiveQl { case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) /* Comparisons */ - case Token("=", left :: right:: Nil) => Equals(nodeToExpr(left), nodeToExpr(right)) - case Token("!=", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right))) - case Token("<>", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right))) + case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) + case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) + case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) @@ -966,7 +966,7 @@ private[hive] object HiveQl { // FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval(). // Hence effectful / non-deterministic key expressions are *not* supported at the moment. // We should consider adding new Expressions to get around this. 
- Seq(Equals(nodeToExpr(branches(0)), nodeToExpr(condVal)), + Seq(EqualTo(nodeToExpr(branches(0)), nodeToExpr(condVal)), nodeToExpr(value)) case Seq(elseVal) => Seq(nodeToExpr(elseVal)) }.toSeq.reduce(_ ++ _) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index cc8744c9668eb..7436de264a1e1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.catalyst.expressions.{Cast, Equals} +import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo} import org.apache.spark.sql.execution.Project import org.apache.spark.sql.hive.test.TestHive @@ -39,13 +39,13 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { // No cast expression introduced project.transformAllExpressions { case c: Cast => - assert(false, "unexpected cast " + c) + fail(s"unexpected cast $c") c } - // Only one Equals + // Only one equality check var numEquals = 0 - project.transformAllExpressions { case e: Equals => + project.transformAllExpressions { case e: EqualTo => numEquals += 1 e } From d484ddeff1440d8e14e05c3cd7e7a18746f1a586 Mon Sep 17 00:00:00 2001 From: Gang Bai Date: Fri, 20 Jun 2014 08:52:20 -0700 Subject: [PATCH 55/57] [SPARK-2163] class LBFGS optimize with Double tolerance instead of Int https://issues.apache.org/jira/browse/SPARK-2163 This pull request includes the change for **[SPARK-2163]**: * Changed the convergence tolerance parameter from type `Int` to type `Double`. * Added types for vars in `class LBFGS`, making the style consistent with `class GradientDescent`. * Added associated test to check that optimizing via `class LBFGS` produces the same results as via calling `runLBFGS` from `object LBFGS`. This is a very minor change but it will solve the problem in my implementation of a regression model for count data, where I make use of LBFGS for parameter estimation. Author: Gang Bai Closes #1104 from BaiGang/fix_int_tol and squashes the following commits: cecf02c [Gang Bai] Changed setConvergenceTol'' to specify tolerance with a parameter of type Double. For the reason and the problem caused by an Int parameter, please check https://issues.apache.org/jira/browse/SPARK-2163. Added a test in LBFGSSuite for validating that optimizing via class LBFGS produces the same results as calling runLBFGS from object LBFGS. Keep the indentations and styles correct. --- .../spark/mllib/optimization/LBFGS.scala | 2 +- .../spark/mllib/optimization/LBFGSSuite.scala | 34 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 8f187c9df5102..7bbed9c8fdbef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -60,7 +60,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. * Smaller value will lead to higher accuracy with the cost of more iterations. 
*/ - def setConvergenceTol(tolerance: Int): this.type = { + def setConvergenceTol(tolerance: Double): this.type = { this.convergenceTol = tolerance this } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 4b1850659a18e..fe7a9033cd5f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -195,4 +195,38 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { assert(lossLBFGS3.length == 6) assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol) } + + test("Optimize via class LBFGS.") { + val regParam = 0.2 + + // Prepare another non-zero weights to compare the loss in the first iteration. + val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) + val convergenceTol = 1e-12 + val maxNumIterations = 10 + + val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater) + .setNumCorrections(numCorrections) + .setConvergenceTol(convergenceTol) + .setMaxNumIterations(maxNumIterations) + .setRegParam(regParam) + + val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept) + + val numGDIterations = 50 + val stepSize = 1.0 + val (weightGD, _) = GradientDescent.runMiniBatchSGD( + dataRDD, + gradient, + squaredL2Updater, + stepSize, + numGDIterations, + regParam, + miniBatchFrac, + initialWeightsWithIntercept) + + // for class LBFGS and the optimize method, we only look at the weights + assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) && + compareDouble(weightLBFGS(1), weightGD(1), 0.02), + "The weight differences between LBFGS and GD should be within 2%.") + } } From 6a224c31e8563156ad5732a23667e73076984ae1 Mon Sep 17 00:00:00 2001 From: "Allan Douglas R. de Oliveira" Date: Fri, 20 Jun 2014 11:03:03 -0700 Subject: [PATCH 56/57] SPARK-1868: Users should be allowed to cogroup at least 4 RDDs Adds cogroup for 4 RDDs. Author: Allan Douglas R. de Oliveira Closes #813 from douglaz/more_cogroups and squashes the following commits: f8d6273 [Allan Douglas R. de Oliveira] Test python groupWith for one more case 0e9009c [Allan Douglas R. de Oliveira] Added scala tests c3ffcdd [Allan Douglas R. de Oliveira] Added java tests 517a67f [Allan Douglas R. de Oliveira] Added tests for python groupWith 2f402d5 [Allan Douglas R. de Oliveira] Removed TODO 17474f4 [Allan Douglas R. de Oliveira] Use new cogroup function 7877a2a [Allan Douglas R. de Oliveira] Fixed code ba02414 [Allan Douglas R. de Oliveira] Added varargs cogroup to pyspark c4a8a51 [Allan Douglas R. de Oliveira] Added java cogroup 4 e94963c [Allan Douglas R. de Oliveira] Fixed spacing f1ee57b [Allan Douglas R. de Oliveira] Fixed scala style issues d7196f1 [Allan Douglas R. 
de Oliveira] Allow the cogroup of 4 RDDs --- .../apache/spark/api/java/JavaPairRDD.scala | 51 +++++++++++++++ .../apache/spark/rdd/PairRDDFunctions.scala | 51 +++++++++++++++ .../java/org/apache/spark/JavaAPISuite.java | 63 +++++++++++++++++++ .../spark/rdd/PairRDDFunctionsSuite.scala | 33 ++++++++++ python/pyspark/join.py | 20 +++--- python/pyspark/rdd.py | 22 ++++--- 6 files changed, 223 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 14fa9d8135afe..4f3081433a542 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -543,6 +543,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) partitioner: Partitioner): JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3], + partitioner: Partitioner) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, partitioner))) + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -558,6 +570,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3]) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3))) + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -574,6 +597,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions))) + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3], + numPartitions: Int) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, numPartitions))) + /** Alias for cogroup. 
*/ def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] = fromRDD(cogroupResultToJava(rdd.groupWith(other))) @@ -583,6 +618,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2))) + /** Alias for cogroup. */ + def groupWith[W1, W2, W3](other1: JavaPairRDD[K, W1], + other2: JavaPairRDD[K, W2], + other3: JavaPairRDD[K, W3]) + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] = + fromRDD(cogroupResult3ToJava(rdd.groupWith(other1, other2, other3))) + /** * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. @@ -786,6 +828,15 @@ object JavaPairRDD { .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3))) } + private[spark] + def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3]( + rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))]) + : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = { + rddToPairRDDFunctions(rdd) + .mapValues(x => + (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4))) + } + def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = { new JavaPairRDD[K, V](rdd) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index fe36c80e0be84..443d1c587c3ee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -567,6 +567,28 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) new FlatMappedValuesRDD(self, cleanF) } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], + other2: RDD[(K, W2)], + other3: RDD[(K, W3)], + partitioner: Partitioner) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner) + cg.mapValues { case Seq(vs, w1s, w2s, w3s) => + (vs.asInstanceOf[Seq[V]], + w1s.asInstanceOf[Seq[W1]], + w2s.asInstanceOf[Seq[W2]], + w3s.asInstanceOf[Seq[W3]]) + } + } + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. @@ -599,6 +621,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) + } + /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. 
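For illustration, a minimal usage sketch of the four-way `cogroup`/`groupWith` introduced by this patch (assuming an active `SparkContext` named `sc`; the sample data loosely follows the Java test further below and is not taken verbatim from the patch):

```
import org.apache.spark.SparkContext._  // pair-RDD functions via implicits

val categories = sc.parallelize(Seq(("Apples", "Fruit"), ("Oranges", "Citrus")))
val prices     = sc.parallelize(Seq(("Apples", 3), ("Oranges", 2)))
val quantities = sc.parallelize(Seq(("Apples", 42), ("Oranges", 21)))
val countries  = sc.parallelize(Seq(("Apples", "US"), ("Oranges", "BR")))

// Each key maps to a tuple of four Iterables, one per input RDD.
val grouped = categories.cogroup(prices, quantities, countries)
grouped.collect().foreach {
  case (key, (cats, ps, qs, cs)) => println(s"$key -> $cats $ps $qs $cs")
}
```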
@@ -633,6 +665,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) cogroup(other1, other2, new HashPartitioner(numPartitions)) } + /** + * For each key k in `this` or `other1` or `other2` or `other3`, + * return a resulting RDD that contains a tuple with the list of values + * for that key in `this`, `other1`, `other2` and `other3`. + */ + def cogroup[W1, W2, W3](other1: RDD[(K, W1)], + other2: RDD[(K, W2)], + other3: RDD[(K, W3)], + numPartitions: Int) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, new HashPartitioner(numPartitions)) + } + /** Alias for cogroup. */ def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = { cogroup(other, defaultPartitioner(self, other)) @@ -644,6 +689,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } + /** Alias for cogroup. */ + def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)]) + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = { + cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3)) + } + /** * Return an RDD with the pairs from `this` whose keys are not in `other`. * diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e46298c6a9e63..761f2d6a77d33 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -21,6 +21,9 @@ import java.util.*; import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; + import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; @@ -304,6 +307,66 @@ public void cogroup() { cogrouped.collect(); } + @SuppressWarnings("unchecked") + @Test + public void cogroup3() { + JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, String>("Apples", "Fruit"), + new Tuple2<String, String>("Oranges", "Fruit"), + new Tuple2<String, String>("Oranges", "Citrus") + )); + JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, Integer>("Oranges", 2), + new Tuple2<String, Integer>("Apples", 3) + )); + JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, Integer>("Oranges", 21), + new Tuple2<String, Integer>("Apples", 42) + )); + + JavaPairRDD<String, Tuple3<Iterable<String>, Iterable<Integer>, Iterable<Integer>>> cogrouped = + categories.cogroup(prices, quantities); + Assert.assertEquals("[Fruit, Citrus]", + Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); + + + cogrouped.collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void cogroup4() { + JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, String>("Apples", "Fruit"), + new Tuple2<String, String>("Oranges", "Fruit"), + new Tuple2<String, String>("Oranges", "Citrus") + )); + JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, Integer>("Oranges", 2), + new Tuple2<String, Integer>("Apples", 3) + )); + JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, Integer>("Oranges", 21), + new Tuple2<String, Integer>("Apples", 42) + )); + JavaPairRDD<String, String> countries = sc.parallelizePairs(Arrays.asList( + new Tuple2<String, String>("Oranges", "BR"), + new Tuple2<String, String>("Apples", "US") + )); + + JavaPairRDD<String, Tuple4<Iterable<String>, Iterable<Integer>, Iterable<Integer>, Iterable<String>>> cogrouped = + categories.cogroup(prices, quantities, countries); + Assert.assertEquals("[Fruit, Citrus]", + Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); +
Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); + Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); + + cogrouped.collect(); + } + @SuppressWarnings("unchecked") @Test public void leftOuterJoin() { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 0b9004448a63e..447e38ec9dbd0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -249,6 +249,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { )) } + test("groupWith3") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd'))) + val joined = rdd1.groupWith(rdd2, rdd3).collect() + assert(joined.size === 4) + val joinedSet = joined.map(x => (x._1, + (x._2._1.toList, x._2._2.toList, x._2._3.toList))).toSet + assert(joinedSet === Set( + (1, (List(1, 2), List('x'), List('a'))), + (2, (List(1), List('y', 'z'), List())), + (3, (List(1), List(), List('b'))), + (4, (List(), List('w'), List('c', 'd'))) + )) + } + + test("groupWith4") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd'))) + val rdd4 = sc.parallelize(Array((2, '@'))) + val joined = rdd1.groupWith(rdd2, rdd3, rdd4).collect() + assert(joined.size === 4) + val joinedSet = joined.map(x => (x._1, + (x._2._1.toList, x._2._2.toList, x._2._3.toList, x._2._4.toList))).toSet + assert(joinedSet === Set( + (1, (List(1, 2), List('x'), List('a'), List())), + (2, (List(1), List('y', 'z'), List(), List('@'))), + (3, (List(1), List(), List('b'), List())), + (4, (List(), List('w'), List('c', 'd'), List())) + )) + } + test("zero-partition RDD") { val emptyDir = Files.createTempDir() emptyDir.deleteOnExit() diff --git a/python/pyspark/join.py b/python/pyspark/join.py index 6f94d26ef86a9..5f3a7e71f7866 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -79,15 +79,15 @@ def dispatch(seq): return _do_python_join(rdd, other, numPartitions, dispatch) -def python_cogroup(rdd, other, numPartitions): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) +def python_cogroup(rdds, numPartitions): + def make_mapper(i): + return lambda (k, v): (k, (i, v)) + vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] + union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds) + rdd_len = len(vrdds) def dispatch(seq): - vbuf, wbuf = [], [] + bufs = [[] for i in range(rdd_len)] for (n, v) in seq: - if n == 1: - vbuf.append(v) - elif n == 2: - wbuf.append(v) - return (ResultIterable(vbuf), ResultIterable(wbuf)) - return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch) + bufs[n].append(v) + return tuple(map(ResultIterable, bufs)) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 62a95c84675dd..1d55c35a8bf48 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1233,7 +1233,7 @@ def _mergeCombiners(iterator): combiners[k] = 
mergeCombiners(combiners[k], v) return combiners.iteritems() return shuffled.mapPartitions(_mergeCombiners) - + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ Aggregate the values of each key, using given combine functions and a neutral "zero value". @@ -1245,7 +1245,7 @@ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ def createZero(): return copy.deepcopy(zeroValue) - + return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) def foldByKey(self, zeroValue, func, numPartitions=None): @@ -1323,12 +1323,20 @@ def mapValues(self, f): map_values_fn = lambda (k, v): (k, f(v)) return self.map(map_values_fn, preservesPartitioning=True) - # TODO: support varargs cogroup of several RDDs. - def groupWith(self, other): + def groupWith(self, other, *others): """ - Alias for cogroup. + Alias for cogroup but with support for multiple RDDs. + + >>> w = sc.parallelize([("a", 5), ("b", 6)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) + >>> z = sc.parallelize([("b", 42)]) + >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \ + sorted(list(w.groupWith(x, y, z).collect()))) + [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))] + """ - return self.cogroup(other) + return python_cogroup((self, other) + others, numPartitions=None) # TODO: add variant with custom parittioner def cogroup(self, other, numPartitions=None): @@ -1342,7 +1350,7 @@ def cogroup(self, other, numPartitions=None): >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect()))) [('a', ([1], [2])), ('b', ([4], []))] """ - return python_cogroup(self, other, numPartitions) + return python_cogroup((self, other), numPartitions) def subtractByKey(self, other, numPartitions=None): """ From 171ebb3a824a577d69443ec68a3543b27914cf6d Mon Sep 17 00:00:00 2001 From: William Benton Date: Fri, 20 Jun 2014 13:41:38 -0700 Subject: [PATCH 57/57] SPARK-2180: support HAVING clauses in Hive queries This PR extends Spark's HiveQL support to handle HAVING clauses in aggregations. The HAVING test from the Hive compatibility suite doesn't appear to be runnable from within Spark, so I added a simple comparable test to `HiveQuerySuite`. Author: William Benton Closes #1136 from willb/SPARK-2180 and squashes the following commits: 3bbaf26 [William Benton] Added casts to HAVING expressions 83f1340 [William Benton] scalastyle fixes 18387f1 [William Benton] Add test for HAVING without GROUP BY b880bef [William Benton] Added semantic error for HAVING without GROUP BY 942428e [William Benton] Added test coverage for SPARK-2180. 56084cc [William Benton] Add support for HAVING clauses in Hive queries. 
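Before the HiveQL diff below, a minimal Scala sketch of how the three- and four-RDD `cogroup`/`groupWith` overloads introduced in the cogroup patch above might be exercised. It is illustrative only and not part of either patch: the object name, the sample data, and the `local` master are assumptions, and it relies on the pair-RDD implicits from `org.apache.spark.SparkContext._`.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._  // rddToPairRDDFunctions for cogroup/groupWith

    object CogroupSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("cogroup-sketch"))

        val categories = sc.parallelize(Seq(("Apples", "Fruit"), ("Oranges", "Fruit"), ("Oranges", "Citrus")))
        val prices     = sc.parallelize(Seq(("Oranges", 2), ("Apples", 3)))
        val quantities = sc.parallelize(Seq(("Oranges", 21), ("Apples", 42)))
        val countries  = sc.parallelize(Seq(("Oranges", "BR"), ("Apples", "US")))

        // Four-way cogroup: for every key, one tuple holding four Iterables, one per input RDD.
        val cogrouped = categories.cogroup(prices, quantities, countries)
        cogrouped.collect().foreach { case (key, (cs, ps, qs, ns)) =>
          println(s"$key -> ${cs.toList}, ${ps.toList}, ${qs.toList}, ${ns.toList}")
        }

        // groupWith is the alias, using the default partitioner of the inputs.
        val grouped = categories.groupWith(prices, quantities, countries)
        println(grouped.collect().length)  // 2 keys: Apples and Oranges

        sc.stop()
      }
    }

The same data driven through the new Java overloads yields `Tuple4` values, as the `cogroup4` test above asserts.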
--- .../org/apache/spark/sql/hive/HiveQl.scala | 30 +++++++++++++++---- .../sql/hive/execution/HiveQuerySuite.scala | 29 ++++++++++++++++++ 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ec653efcc8c58..c69e3dba6b467 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -204,6 +204,9 @@ private[hive] object HiveQl { class ParseException(sql: String, cause: Throwable) extends Exception(s"Failed to parse: $sql", cause) + class SemanticException(msg: String) + extends Exception(s"Error in semantic analysis: $msg") + /** * Returns the AST for the given SQL string. */ @@ -480,6 +483,7 @@ private[hive] object HiveQl { whereClause :: groupByClause :: orderByClause :: + havingClause :: sortByClause :: clusterByClause :: distributeByClause :: @@ -494,6 +498,7 @@ private[hive] object HiveQl { "TOK_WHERE", "TOK_GROUPBY", "TOK_ORDERBY", + "TOK_HAVING", "TOK_SORTBY", "TOK_CLUSTERBY", "TOK_DISTRIBUTEBY", @@ -576,21 +581,34 @@ private[hive] object HiveQl { val withDistinct = if (selectDistinctClause.isDefined) Distinct(withProject) else withProject + val withHaving = havingClause.map { h => + + if (groupByClause == None) { + throw new SemanticException("HAVING specified without GROUP BY") + } + + val havingExpr = h.getChildren.toSeq match { + case Seq(hexpr) => nodeToExpr(hexpr) + } + + Filter(Cast(havingExpr, BooleanType), withDistinct) + }.getOrElse(withDistinct) + val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), withDistinct) + Sort(totalOrdering.getChildren.map(nodeToSortOrder), withHaving) case (None, Some(perPartitionOrdering), None, None) => - SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withDistinct) + SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withHaving) case (None, None, Some(partitionExprs), None) => - Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct) + Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), - Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct)) + Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, Some(clusterExprs)) => SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), - Repartition(clusterExprs.getChildren.map(nodeToExpr), withDistinct)) - case (None, None, None, None) => withDistinct + Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) + case (None, None, None, None) => withHaving case _ => sys.error("Unsupported set of ordering / distribution clauses.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 9f5cf282f7c48..80185098bf24f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -224,6 +224,32 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.reset() } + test("SPARK-2180: HAVING support 
in GROUP BY clauses (positive)") { + val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) + .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} + + TestHive.sparkContext.parallelize(fixture).registerAsTable("having_test") + + val results = + hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") + .collect() + .map(x => Pair(x.getString(0), x.getInt(1))) + + assert(results === Array(Pair("foo", 4))) + + TestHive.reset() + } + + test("SPARK-2180: HAVING without GROUP BY raises exception") { + intercept[Exception] { + hql("SELECT value, attr FROM having_test HAVING attr > 3") + } + } + + test("SPARK-2180: HAVING with non-boolean clause raises no exceptions") { + val results = hql("select key, count(*) c from src group by key having c").collect() + } + test("Query Hive native command execution result") { val tableName = "test_native_commands" @@ -441,3 +467,6 @@ class HiveQuerySuite extends HiveComparisonTest { // since they modify /clear stuff. } + +// for SPARK-2180 test +case class HavingRow(key: Int, value: String, attr: Int)
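To round the section off, a minimal end-to-end sketch of how the new HAVING support might be used outside the test suite. This is illustrative only: the `Sale` case class, the `sales_demo` table name, and the local-mode setup are assumptions, and it presumes a Spark build with Hive support plus a usable local metastore.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext

    // Hypothetical row type for this sketch; the patch's own test uses HavingRow.
    case class Sale(value: String, attr: Int)

    object HavingSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("having-sketch"))
        val hiveContext = new HiveContext(sc)
        import hiveContext._  // implicit RDD-to-SchemaRDD conversion plus hql()

        // Register a small in-memory table, mirroring the positive test above.
        sc.parallelize(Seq(Sale("foo", 2), Sale("bar", 1), Sale("foo", 4), Sale("bar", 3)))
          .registerAsTable("sales_demo")

        // HAVING is accepted only together with GROUP BY; HiveQl now rejects a
        // HAVING clause without GROUP BY, as the negative test above checks.
        val rows = hql(
          "SELECT value, max(attr) AS attr FROM sales_demo GROUP BY value HAVING attr > 3").collect()

        rows.foreach(println)  // expected: only groups whose max(attr) exceeds 3

        sc.stop()
      }
    }

The HAVING expression is cast to boolean and planned as a `Filter` over the aggregation, which is why a non-boolean clause such as `HAVING c` is also accepted, as the last test above exercises.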