From b65606fa6e4e43ad5c4e5af9b6c637c805c1da94 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Wed, 4 Jun 2014 15:56:15 +0200 Subject: [PATCH] Add converter interface --- .../spark/api/python/PythonHadoopUtil.scala | 114 +++++++++++++----- .../apache/spark/api/python/PythonRDD.scala | 21 ++-- .../WriteInputFormatTestDataGenerator.scala | 8 ++ python/pyspark/tests.py | 10 ++ 4 files changed, 118 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 4fe5800df917e..4b532b283a022 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -18,11 +18,84 @@ package org.apache.spark.api.python import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ +import scala.util.{Failure, Success, Try} +trait Converter { + def convert(obj: Any): Any +} + +object DefaultConverter extends Converter { + + /** + * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or + * object representation + */ + private def convertWritable(writable: Writable): Any = { + import collection.JavaConversions._ + writable match { + case iw: IntWritable => SparkContext.intWritableConverter().convert(iw) + case dw: DoubleWritable => SparkContext.doubleWritableConverter().convert(dw) + case lw: LongWritable => SparkContext.longWritableConverter().convert(lw) + case fw: FloatWritable => SparkContext.floatWritableConverter().convert(fw) + case t: Text => SparkContext.stringWritableConverter().convert(t) + case bw: BooleanWritable => SparkContext.booleanWritableConverter().convert(bw) + case byw: BytesWritable => SparkContext.bytesWritableConverter().convert(byw) + case n: NullWritable => null + case aw: ArrayWritable => aw.get().map(convertWritable(_)) + case mw: MapWritable => mapAsJavaMap(mw.map{ case (k, v) => + (convertWritable(k), convertWritable(v)) + }.toMap) + case other => other + } + } + + def convert(obj: Any): Any = { + obj match { + case writable: Writable => + convertWritable(writable) + case _ => + obj + } + } +} + +class ConverterRegistry extends Logging { + + var keyConverter: Converter = DefaultConverter + var valueConverter: Converter = DefaultConverter + + def convertKey(obj: Any): Any = keyConverter.convert(obj) + + def convertValue(obj: Any): Any = valueConverter.convert(obj) + + def registerKeyConverter(converterClass: String) = { + keyConverter = register(converterClass) + logInfo(s"Loaded and registered key converter ($converterClass)") + } + + def registerValueConverter(converterClass: String) = { + valueConverter = register(converterClass) + logInfo(s"Loaded and registered value converter ($converterClass)") + } + + private def register(converterClass: String): Converter = { + Try { + val converter = Class.forName(converterClass).newInstance().asInstanceOf[Converter] + converter + } match { + case Success(s) => s + case Failure(err) => + logError(s"Failed to register converter: $converterClass") + throw err + } + + } +} + /** Utilities for working with Python objects -> Hadoop-related objects */ private[python] object PythonHadoopUtil { @@ -51,33 +124,18 @@ private[python] object PythonHadoopUtil { * Converts an RDD of key-value pairs, where key and/or value could be instances of * 
[[org.apache.hadoop.io.Writable]], into an RDD[(K, V)] */ - def convertRDD[K, V](rdd: RDD[(K, V)]) = { - rdd.map{ - case (k: Writable, v: Writable) => (convert(k).asInstanceOf[K], convert(v).asInstanceOf[V]) - case (k: Writable, v) => (convert(k).asInstanceOf[K], v.asInstanceOf[V]) - case (k, v: Writable) => (k.asInstanceOf[K], convert(v).asInstanceOf[V]) - case (k, v) => (k.asInstanceOf[K], v.asInstanceOf[V]) - } - } - - /** - * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or - * object representation - */ - private def convert(writable: Writable): Any = { - import collection.JavaConversions._ - writable match { - case iw: IntWritable => SparkContext.intWritableConverter().convert(iw) - case dw: DoubleWritable => SparkContext.doubleWritableConverter().convert(dw) - case lw: LongWritable => SparkContext.longWritableConverter().convert(lw) - case fw: FloatWritable => SparkContext.floatWritableConverter().convert(fw) - case t: Text => SparkContext.stringWritableConverter().convert(t) - case bw: BooleanWritable => SparkContext.booleanWritableConverter().convert(bw) - case byw: BytesWritable => SparkContext.bytesWritableConverter().convert(byw) - case n: NullWritable => null - case aw: ArrayWritable => aw.get().map(convert(_)) - case mw: MapWritable => mapAsJavaMap(mw.map{ case (k, v) => (convert(k), convert(v)) }.toMap) - case other => other + def convertRDD[K, V](rdd: RDD[(K, V)], + keyClass: String, + keyConverter: Option[String], + valueClass: String, + valueConverter: Option[String]) = { + rdd.mapPartitions { case iter => + val registry = new ConverterRegistry + keyConverter.foreach(registry.registerKeyConverter(_)) + valueConverter.foreach(registry.registerValueConverter(_)) + iter.map { case (k, v) => + (registry.convertKey(k).asInstanceOf[K], registry.convertValue(v).asInstanceOf[V]) + } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 12140fad24516..eabae80f80689 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -353,17 +353,20 @@ private[spark] object PythonRDD extends Logging { def sequenceFile[K, V]( sc: JavaSparkContext, path: String, - keyClass: String, - valueClass: String, + keyClassMaybeNull: String, + valueClassMaybeNull: String, keyConverter: String, valueConverter: String, minSplits: Int) = { + val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") + val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] val kc = kcm.runtimeClass.asInstanceOf[Class[K]] val vc = vcm.runtimeClass.asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd) + val converted = PythonHadoopUtil.convertRDD[K, V]( + rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter)) JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) } @@ -386,7 +389,8 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd) + val converted = PythonHadoopUtil.convertRDD[K, V]( + rdd, keyClass, Option(keyConverter), 
valueClass, Option(valueConverter)) JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) } @@ -407,7 +411,8 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd) + val converted = PythonHadoopUtil.convertRDD[K, V]( + rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter)) JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) } @@ -451,7 +456,8 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd) + val converted = PythonHadoopUtil.convertRDD[K, V]( + rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter)) JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) } @@ -472,7 +478,8 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val converted = PythonHadoopUtil.convertRDD[K, V](rdd) + val converted = PythonHadoopUtil.convertRDD[K, V]( + rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter)) JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) } diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index 707c01d32ed0f..25acf80e50772 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -54,6 +54,14 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten } } +class TestConverter extends Converter { + import collection.JavaConversions._ + override def convert(obj: Any) = { + val m = obj.asInstanceOf[MapWritable] + seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) + } +} + /** * This object contains method to generate SequenceFile test data and write it to a * given directory (probably a temp directory) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6b5cd96df597c..01929bbe7372f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -335,6 +335,16 @@ def test_bad_inputs(self): "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.Text")) + def test_converter(self): + basepath = self.tempdir.name + maps = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + valueConverter="org.apache.spark.api.python.TestConverter").collect()) + em = [(1, [2.0]), (1, [3.0]), (2, [1.0]), (2, [1.0]), (2, [3.0]), (3, [2.0])] + self.assertEqual(maps, em) + class TestDaemon(unittest.TestCase): def connect(self, port):
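
For reference, a minimal user-defined Converter against the trait introduced in this patch could look like the sketch below. The class name MapToStringMapConverter is hypothetical and used purely for illustration; it follows the same MapWritable handling as DefaultConverter, but renders each key and value via toString so that PySpark receives a plain dict of strings.

    import org.apache.hadoop.io.MapWritable
    import org.apache.spark.api.python.Converter
    import scala.collection.JavaConversions._

    // Hypothetical example converter: flatten a MapWritable into a
    // java.util.Map[String, String] by calling toString on every key and value.
    // Anything that is not a MapWritable is passed through unchanged, matching
    // the fall-through behaviour of DefaultConverter.
    class MapToStringMapConverter extends Converter {
      override def convert(obj: Any): Any = obj match {
        case mw: MapWritable =>
          mapAsJavaMap(mw.map { case (k, v) => (k.toString, v.toString) }.toMap)
        case other => other
      }
    }

Such a converter would be referenced from PySpark by its fully qualified class name through the keyConverter/valueConverter arguments of sequenceFile (and the hadoopFile/newAPIHadoopFile variants), in the same way TestConverter is used in test_converter above. Since the converter is instantiated via Class.forName inside mapPartitions, the class needs to be available on the executor classpath.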