
Commit

Merge pull request apache#501 from JoshRosen/cartesian-rdd-fixes

Fix two bugs in PySpark cartesian(): SPARK-978 and SPARK-1034

This pull request fixes two bugs in PySpark's `cartesian()` method:

- [SPARK-978](https://spark-project.atlassian.net/browse/SPARK-978): PySpark's cartesian method throws ClassCastException exception
- [SPARK-1034](https://spark-project.atlassian.net/browse/SPARK-1034): Py4JException on PySpark Cartesian Result

The JIRAs have more details describing the fixes.
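For context, a minimal sketch of the two failure modes, mirroring the regression tests added to `python/pyspark/tests.py` below. It assumes a live `SparkContext` named `sc`; the text-file path is illustrative:

```python
# Before this patch, both of these collect() calls failed.

# SPARK-1034: transforming the result of cartesian() raised a Py4JException.
rdd1 = sc.parallelize([1, 2])
rdd2 = sc.parallelize([3, 4])
sums = rdd1.cartesian(rdd2).map(lambda pair: pair[0] + pair[1]).collect()

# SPARK-978: cartesian() on RDDs read via textFile() threw a
# ClassCastException (the JavaPairRDD advertised the wrong element ClassTag).
lines = sc.textFile("hello.txt")  # any small text file
pairs = lines.cartesian(lines).collect()
```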
pwendell committed Jan 24, 2014
2 parents fad6aac + 6156990 commit cad3002
Showing 3 changed files with 56 additions and 22 deletions.
3 changes: 1 addition & 2 deletions core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -49,8 +49,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K

   override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
 
-  override val classTag: ClassTag[(K, V)] =
-    implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]]
+  override val classTag: ClassTag[(K, V)] = rdd.elementClassTag
 
   import JavaPairRDD._

59 changes: 39 additions & 20 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -78,9 +78,7 @@ private[spark] class PythonRDD[T: ClassTag](
         dataOut.writeInt(command.length)
         dataOut.write(command)
         // Data values
-        for (elem <- parent.iterator(split, context)) {
-          PythonRDD.writeToStream(elem, dataOut)
-        }
+        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
         dataOut.flush()
         worker.shutdownOutput()
       } catch {
@@ -206,20 +204,43 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeToStream(elem: Any, dataOut: DataOutputStream) {
-    elem match {
-      case bytes: Array[Byte] =>
-        dataOut.writeInt(bytes.length)
-        dataOut.write(bytes)
-      case pair: (Array[Byte], Array[Byte]) =>
-        dataOut.writeInt(pair._1.length)
-        dataOut.write(pair._1)
-        dataOut.writeInt(pair._2.length)
-        dataOut.write(pair._2)
-      case str: String =>
-        dataOut.writeUTF(str)
-      case other =>
-        throw new SparkException("Unexpected element type " + other.getClass)
-    }
-  }
+  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
+    // The right way to implement this would be to use TypeTags to get the full
+    // type of T. Since I don't want to introduce breaking changes throughout the
+    // entire Spark API, I have to use this hacky approach:
+    if (iter.hasNext) {
+      val first = iter.next()
+      val newIter = Seq(first).iterator ++ iter
+      first match {
+        case arr: Array[Byte] =>
+          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
+            dataOut.writeInt(bytes.length)
+            dataOut.write(bytes)
+          }
+        case string: String =>
+          newIter.asInstanceOf[Iterator[String]].foreach { str =>
+            dataOut.writeUTF(str)
+          }
+        case pair: Tuple2[_, _] =>
+          pair._1 match {
+            case bytePair: Array[Byte] =>
+              newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
+                dataOut.writeInt(pair._1.length)
+                dataOut.write(pair._1)
+                dataOut.writeInt(pair._2.length)
+                dataOut.write(pair._2)
+              }
+            case stringPair: String =>
+              newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
+                dataOut.writeUTF(pair._1)
+                dataOut.writeUTF(pair._2)
+              }
+            case other =>
+              throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+          }
+        case other =>
+          throw new SparkException("Unexpected element type " + first.getClass)
+      }
+    }
+  }

@@ -230,9 +251,7 @@ private[spark] object PythonRDD {

   def writeToFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
-    for (item <- items) {
-      writeToStream(item, file)
-    }
+    writeIteratorToStream(items, file)
     file.close()
   }

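A note on the approach: `writeIteratorToStream` above has to pick a wire format from the runtime type of the *first* element, then glue that element back onto the iterator (`Seq(first).iterator ++ iter`) so it still gets written. A rough Python analogue of this peek-and-prepend pattern, purely for illustration (the function name and framing are invented here, not Spark APIs):

```python
import struct
from itertools import chain


def write_iterator_to_stream(it, out):
    # Peek at the first element to choose a serialization format.
    it = iter(it)
    try:
        first = next(it)
    except StopIteration:
        return  # empty iterator: write nothing, as the Scala code does

    # Reattach the peeked element; plays the role of Seq(first).iterator ++ iter.
    stream = chain([first], it)

    if isinstance(first, bytes):
        for b in stream:
            out.write(struct.pack(">i", len(b)))  # length-prefixed frame
            out.write(b)
    elif isinstance(first, tuple):
        for k, v in stream:  # assumes pairs of byte strings
            out.write(struct.pack(">i", len(k)))
            out.write(k)
            out.write(struct.pack(">i", len(v)))
            out.write(v)
    else:
        raise TypeError("Unexpected element type: %r" % type(first))
```

The empty-iterator guard matters: with no first element there is nothing to dispatch on, so the stream is simply left empty.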
16 changes: 16 additions & 0 deletions python/pyspark/tests.py
@@ -152,6 +152,22 @@ def test_save_as_textfile_with_unicode(self):
         raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
         self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
 
+    def test_transforming_cartesian_result(self):
+        # Regression test for SPARK-1034
+        rdd1 = self.sc.parallelize([1, 2])
+        rdd2 = self.sc.parallelize([3, 4])
+        cart = rdd1.cartesian(rdd2)
+        result = cart.map(lambda (x, y): x + y).collect()
+
+    def test_cartesian_on_textfile(self):
+        # Regression test for SPARK-978
+        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        a = self.sc.textFile(path)
+        result = a.cartesian(a).collect()
+        (x, y) = result[0]
+        self.assertEqual("Hello World!", x.strip())
+        self.assertEqual("Hello World!", y.strip())
+
 
 
 class TestIO(PySparkTestCase):

