Commit

Add converter interface
MLnick committed Jun 4, 2014
1 parent 5757f6e commit b65606f
Showing 4 changed files with 118 additions and 35 deletions.
114 changes: 86 additions & 28 deletions core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -18,11 +18,84 @@
package org.apache.spark.api.python

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.{Logging, SparkContext}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io._
import scala.util.{Failure, Success, Try}


trait Converter {
  def convert(obj: Any): Any
}

object DefaultConverter extends Converter {

  /**
   * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
   * object representation
   */
  private def convertWritable(writable: Writable): Any = {
    import collection.JavaConversions._
    writable match {
      case iw: IntWritable => SparkContext.intWritableConverter().convert(iw)
      case dw: DoubleWritable => SparkContext.doubleWritableConverter().convert(dw)
      case lw: LongWritable => SparkContext.longWritableConverter().convert(lw)
      case fw: FloatWritable => SparkContext.floatWritableConverter().convert(fw)
      case t: Text => SparkContext.stringWritableConverter().convert(t)
      case bw: BooleanWritable => SparkContext.booleanWritableConverter().convert(bw)
      case byw: BytesWritable => SparkContext.bytesWritableConverter().convert(byw)
      case n: NullWritable => null
      case aw: ArrayWritable => aw.get().map(convertWritable(_))
      case mw: MapWritable => mapAsJavaMap(mw.map{ case (k, v) =>
        (convertWritable(k), convertWritable(v))
      }.toMap)
      case other => other
    }
  }

  def convert(obj: Any): Any = {
    obj match {
      case writable: Writable =>
        convertWritable(writable)
      case _ =>
        obj
    }
  }
}

class ConverterRegistry extends Logging {

  var keyConverter: Converter = DefaultConverter
  var valueConverter: Converter = DefaultConverter

  def convertKey(obj: Any): Any = keyConverter.convert(obj)

  def convertValue(obj: Any): Any = valueConverter.convert(obj)

  def registerKeyConverter(converterClass: String) = {
    keyConverter = register(converterClass)
    logInfo(s"Loaded and registered key converter ($converterClass)")
  }

  def registerValueConverter(converterClass: String) = {
    valueConverter = register(converterClass)
    logInfo(s"Loaded and registered value converter ($converterClass)")
  }

  private def register(converterClass: String): Converter = {
    Try {
      val converter = Class.forName(converterClass).newInstance().asInstanceOf[Converter]
      converter
    } match {
      case Success(s) => s
      case Failure(err) =>
        logError(s"Failed to register converter: $converterClass")
        throw err
    }

  }
}

/** Utilities for working with Python objects -> Hadoop-related objects */
private[python] object PythonHadoopUtil {

@@ -51,33 +124,18 @@ private[python] object PythonHadoopUtil {
   * Converts an RDD of key-value pairs, where key and/or value could be instances of
   * [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)]
   */
  def convertRDD[K, V](rdd: RDD[(K, V)]) = {
    rdd.map{
      case (k: Writable, v: Writable) => (convert(k).asInstanceOf[K], convert(v).asInstanceOf[V])
      case (k: Writable, v) => (convert(k).asInstanceOf[K], v.asInstanceOf[V])
      case (k, v: Writable) => (k.asInstanceOf[K], convert(v).asInstanceOf[V])
      case (k, v) => (k.asInstanceOf[K], v.asInstanceOf[V])
    }
  }

  /**
   * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
   * object representation
   */
  private def convert(writable: Writable): Any = {
    import collection.JavaConversions._
    writable match {
      case iw: IntWritable => SparkContext.intWritableConverter().convert(iw)
      case dw: DoubleWritable => SparkContext.doubleWritableConverter().convert(dw)
      case lw: LongWritable => SparkContext.longWritableConverter().convert(lw)
      case fw: FloatWritable => SparkContext.floatWritableConverter().convert(fw)
      case t: Text => SparkContext.stringWritableConverter().convert(t)
      case bw: BooleanWritable => SparkContext.booleanWritableConverter().convert(bw)
      case byw: BytesWritable => SparkContext.bytesWritableConverter().convert(byw)
      case n: NullWritable => null
      case aw: ArrayWritable => aw.get().map(convert(_))
      case mw: MapWritable => mapAsJavaMap(mw.map{ case (k, v) => (convert(k), convert(v)) }.toMap)
      case other => other
  def convertRDD[K, V](rdd: RDD[(K, V)],
      keyClass: String,
      keyConverter: Option[String],
      valueClass: String,
      valueConverter: Option[String]) = {
    rdd.mapPartitions { case iter =>
      val registry = new ConverterRegistry
      keyConverter.foreach(registry.registerKeyConverter(_))
      valueConverter.foreach(registry.registerValueConverter(_))
      iter.map { case (k, v) =>
        (registry.convertKey(k).asInstanceOf[K], registry.convertValue(v).asInstanceOf[V])
      }
    }
  }

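For reference, a minimal sketch (not part of this commit) of how a custom Converter could be written and wired into the ConverterRegistry above. The UpperCaseTextConverter class is hypothetical; since the registry loads converters reflectively by class name, any registered class needs a public no-arg constructor.

import org.apache.hadoop.io.Text
import org.apache.spark.api.python.{Converter, ConverterRegistry}

// Hypothetical converter: upper-cases Text objects and passes everything else through.
class UpperCaseTextConverter extends Converter {
  override def convert(obj: Any): Any = obj match {
    case t: Text => t.toString.toUpperCase
    case other => other
  }
}

object ConverterSketch {
  def main(args: Array[String]): Unit = {
    val registry = new ConverterRegistry
    // Registration is by fully qualified class name, mirroring what PySpark passes in.
    registry.registerKeyConverter(classOf[UpperCaseTextConverter].getName)
    println(registry.convertKey(new Text("spark")))    // SPARK
    // No value converter was registered, so values fall back to DefaultConverter.
    println(registry.convertValue(new Text("spark")))  // spark
  }
}

In convertRDD above this registration happens inside mapPartitions, so each partition builds its own registry and the converter class only has to be loadable on the executors.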
21 changes: 14 additions & 7 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -353,17 +353,20 @@ private[spark] object PythonRDD extends Logging {
  def sequenceFile[K, V](
      sc: JavaSparkContext,
      path: String,
      keyClass: String,
      valueClass: String,
      keyClassMaybeNull: String,
      valueClassMaybeNull: String,
      keyConverter: String,
      valueConverter: String,
      minSplits: Int) = {
    val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
    val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
    implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]]
    implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]]
    val kc = kcm.runtimeClass.asInstanceOf[Class[K]]
    val vc = vcm.runtimeClass.asInstanceOf[Class[V]]
    val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
    val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
    val converted = PythonHadoopUtil.convertRDD[K, V](
      rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter))
    JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
  }

@@ -386,7 +389,8 @@
    val rdd =
      newAPIHadoopRDDFromClassNames[K, V, F](sc,
        Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
    val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
    val converted = PythonHadoopUtil.convertRDD[K, V](
      rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter))
    JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
  }

@@ -407,7 +411,8 @@
    val rdd =
      newAPIHadoopRDDFromClassNames[K, V, F](sc,
        None, inputFormatClass, keyClass, valueClass, conf)
    val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
    val converted = PythonHadoopUtil.convertRDD[K, V](
      rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter))
    JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
  }

@@ -451,7 +456,8 @@
    val rdd =
      hadoopRDDFromClassNames[K, V, F](sc,
        Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
    val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
    val converted = PythonHadoopUtil.convertRDD[K, V](
      rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter))
    JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
  }

@@ -472,7 +478,8 @@
    val rdd =
      hadoopRDDFromClassNames[K, V, F](sc,
        None, inputFormatClass, keyClass, valueClass, conf)
    val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
    val converted = PythonHadoopUtil.convertRDD[K, V](
      rdd, keyClass, Option(keyConverter), valueClass, Option(valueConverter))
    JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
  }

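A note on the calling convention these signatures assume: the Python side passes null (None) for anything the caller leaves unspecified, and the code above collapses those nulls with Option(...). A small sketch of just that convention:

object NullDefaultsSketch {
  def main(args: Array[String]): Unit = {
    // Simulate a PySpark call that supplies neither a value class nor a value converter.
    val valueClassMaybeNull: String = null
    val valueConverter: String = null

    // A missing class name defaults to Text, as in sequenceFile above.
    val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
    println(valueClass)              // org.apache.hadoop.io.Text

    // A missing converter becomes None, so convertRDD keeps DefaultConverter.
    println(Option(valueConverter))  // None
  }
}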
@@ -54,6 +54,14 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten
  }
}

class TestConverter extends Converter {
  import collection.JavaConversions._
  override def convert(obj: Any) = {
    val m = obj.asInstanceOf[MapWritable]
    seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq)
  }
}

/**
 * This object contains method to generate SequenceFile test data and write it to a
 * given directory (probably a temp directory)
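To connect the TestConverter above with the Python test below: assuming the sfmap test data stores each value as a MapWritable keyed by DoubleWritable (which the expected pairs such as (1, [2.0]) suggest), the converter reduces every value to the list of its double keys. A minimal sketch:

import org.apache.hadoop.io.{DoubleWritable, MapWritable, Text}
import org.apache.spark.api.python.TestConverter

object TestConverterSketch {
  def main(args: Array[String]): Unit = {
    // One record's value from the assumed sfmap test data.
    val value = new MapWritable()
    value.put(new DoubleWritable(2.0), new Text("aa"))

    // TestConverter keeps only the map's keys, as a java.util.List of doubles,
    // which PySpark then receives as [2.0].
    println(new TestConverter().convert(value))  // [2.0]
  }
}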
10 changes: 10 additions & 0 deletions python/pyspark/tests.py
@@ -335,6 +335,16 @@ def test_bad_inputs(self):
"org.apache.hadoop.io.IntWritable",
"org.apache.hadoop.io.Text"))

    def test_converter(self):
        basepath = self.tempdir.name
        maps = sorted(self.sc.sequenceFile(
            basepath + "/sftestdata/sfmap/",
            "org.apache.hadoop.io.IntWritable",
            "org.apache.hadoop.io.MapWritable",
            valueConverter="org.apache.spark.api.python.TestConverter").collect())
        em = [(1, [2.0]), (1, [3.0]), (2, [1.0]), (2, [1.0]), (2, [3.0]), (3, [2.0])]
        self.assertEqual(maps, em)


class TestDaemon(unittest.TestCase):
    def connect(self, port):
