diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8b4db783979ec..40444c237b738 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -21,9 +21,8 @@ import scala.language.implicitConversions
import java.io._
import java.net.URI
-import java.util.Arrays
+import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
-import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
import scala.collection.generic.Growable
@@ -41,6 +40,7 @@ import akka.actor.Props
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import org.apache.spark.executor.TriggerThreadDump
import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
@@ -51,7 +51,7 @@ import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage._
import org.apache.spark.ui.SparkUI
import org.apache.spark.ui.jobs.JobProgressListener
-import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}
+import org.apache.spark.util._
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -361,6 +361,29 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ /**
+ * Called by the web UI to obtain executor thread dumps. This method may be expensive.
+ * Logs an error and returns None if we fail to obtain a thread dump, which could happen
+ * because the executor is dead or unresponsive, or because of network issues while sending
+ * the thread dump message back to the driver.
+ */
+ private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = {
+ try {
+ if (executorId == SparkContext.DRIVER_IDENTIFIER) {
+ Some(Utils.getThreadDump())
+ } else {
+ val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get
+ val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
+ Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef,
+ AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf)))
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Exception getting thread dump from executor $executorId", e)
+ None
+ }
+ }
+
private[spark] def getLocalProperties: Properties = localProperties.get()
private[spark] def setLocalProperties(props: Properties) {
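A minimal usage sketch for the new API above, not part of the patch itself: `getExecutorThreadDump` is `private[spark]`, so the sketch assumes code living in the `org.apache.spark` package with an active `SparkContext` in scope; the field names come from the `ThreadStackTrace` case class added later in this diff.

```scala
package org.apache.spark

// Hypothetical helper, not part of this patch: print a one-line summary per thread
// for a given executor, falling back to a message when the dump cannot be fetched.
object ThreadDumpSketch {
  def printThreadSummary(sc: SparkContext, executorId: String): Unit = {
    sc.getExecutorThreadDump(executorId) match {
      case Some(threads) =>
        threads.foreach { t =>
          println(s"[${t.threadId}] ${t.threadName} (${t.threadState})")
        }
      case None =>
        println(s"Could not fetch a thread dump from executor $executorId")
    }
  }
}
```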
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 697154d762d41..3711824a40cfc 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -131,7 +131,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
// Create a new ActorSystem using driver's Spark properties to run the backend.
val driverConf = new SparkConf().setAll(props)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf))
+ SparkEnv.executorActorSystemName,
+ hostname, port, driverConf, new SecurityManager(driverConf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e24a15f015e1c..8b095e23f32ff 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -26,7 +26,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
-import akka.actor.ActorSystem
+import akka.actor.{Props, ActorSystem}
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
@@ -92,6 +92,10 @@ private[spark] class Executor(
}
}
+ // Create an actor for receiving RPCs from the driver
+ private val executorActor = env.actorSystem.actorOf(
+ Props(new ExecutorActor(executorId)), "ExecutorActor")
+
// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
@@ -131,6 +135,7 @@ private[spark] class Executor(
def stop() {
env.metricsSystem.report()
+ env.actorSystem.stop(executorActor)
isStopped = true
threadPool.shutdown()
if (!isLocal) {
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
new file mode 100644
index 0000000000000..41925f7e97e84
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import akka.actor.Actor
+import org.apache.spark.Logging
+
+import org.apache.spark.util.{Utils, ActorLogReceive}
+
+/**
+ * Driver -> Executor message to trigger a thread dump.
+ */
+private[spark] case object TriggerThreadDump
+
+/**
+ * Actor that runs inside of executors to enable driver -> executor RPC.
+ */
+private[spark]
+class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging {
+
+ override def receiveWithLogging = {
+ case TriggerThreadDump =>
+ sender ! Utils.getThreadDump()
+ }
+
+}
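For context, the driver-side counterpart of this actor is the ask shown in `SparkContext.getExecutorThreadDump` earlier in this diff. A condensed, hedged sketch of that round trip, assuming the caller already knows the executor actor system's host and port (obtained via `BlockManagerMaster.getActorSystemHostPortForExecutor` in this patch):

```scala
package org.apache.spark

import akka.actor.ActorSystem

import org.apache.spark.executor.TriggerThreadDump
import org.apache.spark.util.{AkkaUtils, ThreadStackTrace}

// Condensed from SparkContext.getExecutorThreadDump in this patch; conf, host, port, and
// actorSystem are assumed to be supplied by the caller's environment.
object ExecutorRpcSketch {
  def askForThreadDump(
      conf: SparkConf,
      host: String,
      port: Int,
      actorSystem: ActorSystem): Array[ThreadStackTrace] = {
    val ref = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, actorSystem)
    AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, ref,
      AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))
  }
}
```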
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index d08e1419e3e41..b63c7f191155c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -88,6 +88,10 @@ class BlockManagerMaster(
askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
+ def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
+ }
+
/**
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 5e375a2553979..685b2e11440fb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetPeers(blockManagerId) =>
sender ! getPeers(blockManagerId)
+ case GetActorSystemHostPortForExecutor(executorId) =>
+ sender ! getActorSystemHostPortForExecutor(executorId)
+
case GetMemoryStatus =>
sender ! memoryStatus
@@ -412,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
Seq.empty
}
}
+
+ /**
+ * Returns the hostname and port of an executor's actor system, based on the Akka address of its
+ * BlockManagerSlaveActor.
+ */
+ private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ for (
+ blockManagerId <- blockManagerIdByExecutor.get(executorId);
+ info <- blockManagerInfo.get(blockManagerId);
+ host <- info.slaveActor.path.address.host;
+ port <- info.slaveActor.path.address.port
+ ) yield {
+ (host, port)
+ }
+ }
}
@DeveloperApi
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 291ddfcc113ac..3f32099d08cc9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -92,6 +92,8 @@ private[spark] object BlockManagerMessages {
case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+ case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
+
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
case object StopBlockManagerMaster extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
new file mode 100644
index 0000000000000..e9c755e36f716
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.exec
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.util.Try
+import scala.xml.{Text, Node}
+
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") {
+
+ private val sc = parent.sc
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val executorId = Option(request.getParameter("executorId")).getOrElse {
+ return Text(s"Missing executorId parameter")
+ }
+ val time = System.currentTimeMillis()
+ val maybeThreadDump = sc.get.getExecutorThreadDump(executorId)
+
+ val content = maybeThreadDump.map { threadDump =>
+ val dumpRows = threadDump.map { thread =>
+
+ }
+
+
+ }.getOrElse(Text("Error fetching thread dump"))
+ UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
index b0e3bb3b552fd..048fee3ce1ff4 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
@@ -41,7 +41,10 @@ private case class ExecutorSummaryInfo(
totalShuffleWrite: Long,
maxMemory: Long)
-private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
+private[ui] class ExecutorsPage(
+ parent: ExecutorsTab,
+ threadDumpEnabled: Boolean)
+ extends WebUIPage("") {
private val listener = parent.listener
def render(request: HttpServletRequest): Seq[Node] = {
@@ -75,6 +78,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
Shuffle Write
+ {if (threadDumpEnabled) Thread Dump | else Seq.empty}
{execInfoSorted.map(execRow)}
@@ -133,6 +137,15 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") {
{Utils.bytesToString(info.totalShuffleWrite)}
|
+ {
+ if (threadDumpEnabled) {
+
+ Thread Dump
+ |
+ } else {
+ Seq.empty
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index 9e0e71a51a408..ba97630f025c1 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -27,8 +27,14 @@ import org.apache.spark.ui.{SparkUI, SparkUITab}
private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") {
val listener = parent.executorsListener
+ val sc = parent.sc
+ val threadDumpEnabled =
+ sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true)
- attachPage(new ExecutorsPage(this))
+ attachPage(new ExecutorsPage(this, threadDumpEnabled))
+ if (threadDumpEnabled) {
+ attachPage(new ExecutorThreadDumpPage(this))
+ }
}
/**
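The pages above are gated on `spark.ui.threadDumpsEnabled` (default `true`, and only meaningful when a live `SparkContext` is attached to the UI). A small sketch of opting out; the application name is an arbitrary placeholder:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// REPL-style sketch: turn off the executor thread-dump pages via the flag that
// ExecutorsTab reads above. The application name below is just a placeholder.
val conf = new SparkConf()
  .setAppName("example-without-thread-dumps")
  .set("spark.ui.threadDumpsEnabled", "false")
val sc = new SparkContext(conf)
```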
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index b5207360510dd..e3223403c17f4 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -59,6 +59,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val failedStages = ListBuffer[StageInfo]()
val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData]
val stageIdToInfo = new HashMap[StageId, StageInfo]
+
+ // Numbers of completed and failed stages. These may not equal completedStages.size and
+ // failedStages.size respectively, because completedStages and failedStages only keep the
+ // most recent stages; earlier ones are dropped when there are too many stages, in order to
+ // limit memory usage.
+ var numCompletedStages = 0
+ var numFailedStages = 0
// Map from pool name to a hash map (map from stage id to StageInfo).
val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]()
@@ -110,9 +117,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
activeStages.remove(stage.stageId)
if (stage.failureReason.isEmpty) {
completedStages += stage
+ numCompletedStages += 1
trimIfNecessary(completedStages)
} else {
failedStages += stage
+ numFailedStages += 1
trimIfNecessary(failedStages)
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
index 6e718eecdd52a..83a7898071c9b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
@@ -34,7 +34,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
listener.synchronized {
val activeStages = listener.activeStages.values.toSeq
val completedStages = listener.completedStages.reverse.toSeq
+ val numCompletedStages = listener.numCompletedStages
val failedStages = listener.failedStages.reverse.toSeq
+ val numFailedStages = listener.numFailedStages
val now = System.currentTimeMillis
val activeStagesTable =
@@ -69,11 +71,11 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
Completed Stages:
- {completedStages.size}
+ {numCompletedStages}
Failed Stages:
- {failedStages.size}
+ {numFailedStages}
@@ -86,9 +88,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
}} ++
Active Stages ({activeStages.size})
++
activeStagesTable.toNodeSeq ++
- Completed Stages ({completedStages.size})
++
+ Completed Stages ({numCompletedStages})
++
completedStagesTable.toNodeSeq ++
- Failed Stages ({failedStages.size})
++
+ Failed Stages ({numFailedStages})
++
failedStagesTable.toNodeSeq
UIUtils.headerSparkPage("Spark Stages", content, parent)
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index 79e398eb8c104..10010bdfa1a51 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -212,4 +212,18 @@ private[spark] object AkkaUtils extends Logging {
logInfo(s"Connecting to $name: $url")
Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
}
+
+ def makeExecutorRef(
+ name: String,
+ conf: SparkConf,
+ host: String,
+ port: Int,
+ actorSystem: ActorSystem): ActorRef = {
+ val executorActorSystemName = SparkEnv.executorActorSystemName
+ Utils.checkHost(host, "Expected hostname")
+ val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name"
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ logInfo(s"Connecting to $name: $url")
+ Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
new file mode 100644
index 0000000000000..d4e0ad93b966a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * Used for shipping per-thread stacktraces from the executors to the driver.
+ */
+private[spark] case class ThreadStackTrace(
+ threadId: Long,
+ threadName: String,
+ threadState: Thread.State,
+ stackTrace: String)
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index a33046d2040d8..6ab94af9f3739 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import java.io._
+import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
import java.util.jar.Attributes.Name
@@ -1611,6 +1612,18 @@ private[spark] object Utils extends Logging {
s"$className: $desc\n$st"
}
+ /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */
+ def getThreadDump(): Array[ThreadStackTrace] = {
+ // We need to filter out null values here because dumpAllThreads() may return null array
+ // elements for threads that are dead / don't exist.
+ val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
+ threadInfos.sortBy(_.getThreadId).map { case threadInfo =>
+ val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n")
+ ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName,
+ threadInfo.getThreadState, stackTrace)
+ }
+ }
+
/**
* Convert all spark properties set in the given SparkConf to a sequence of java options.
*/
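Since `Utils.getThreadDump()` is also what the driver path of `SparkContext.getExecutorThreadDump` calls locally, it can be exercised on its own. A REPL-style sketch (note `Utils` is `private[spark]`, so this assumes code compiled inside the `org.apache.spark` package):

```scala
import org.apache.spark.util.Utils

// Take a local thread dump and count threads per state, mirroring what the
// driver-side branch of getExecutorThreadDump returns for DRIVER_IDENTIFIER.
val dump = Utils.getThreadDump()
dump.groupBy(_.threadState).toSeq.sortBy(_._1.toString).foreach { case (state, threads) =>
  println(s"$state: ${threads.length} thread(s)")
}
// Print the stack of the main thread, if present.
println(dump.find(_.threadName == "main").map(_.stackTrace).getOrElse("main thread not found"))
```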
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index a91c9ddeaef36..aec1e409db95c 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -177,6 +177,17 @@ class JsonProtocolSuite extends FunSuite {
deserializedBmRemoved)
}
+ test("FetchFailed backwards compatibility") {
+ // FetchFailed in Spark 1.1.0 does not have a "Message" property.
+ val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+ "ignored")
+ val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed)
+ .removeField({ _._1 == "Message" })
+ val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+ "Unknown reason")
+ assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent))
+ }
+
test("SparkListenerApplicationStart backwards compatibility") {
// SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property.
val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user")
@@ -185,6 +196,15 @@ class JsonProtocolSuite extends FunSuite {
assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent))
}
+ test("ExecutorLostFailure backward compatibility") {
+ // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property.
+ val executorLostFailure = ExecutorLostFailure("100")
+ val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure)
+ .removeField({ _._1 == "Executor ID" })
+ val expectedExecutorLostFailure = ExecutorLostFailure("Unknown")
+ assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent))
+ }
+
/** -------------------------- *
| Helper test running methods |
* --------------------------- */
diff --git a/dev/run-tests b/dev/run-tests
index 0e9eefa76a18b..de607e4344453 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -180,7 +180,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
if [ -n "$_SQL_TESTS_ONLY" ]; then
# This must be an array of individual arguments. Otherwise, having one long string
#+ will be interpreted as a single test, which doesn't work.
- SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test")
+ SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test")
else
SBT_MAVEN_TEST_ARGS=("test")
fi
diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py
new file mode 100644
index 0000000000000..540dae785f6ea
--- /dev/null
+++ b/examples/src/main/python/mllib/dataset_example.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+An example of how to use SchemaRDD as a dataset for ML. Run with::
+ bin/spark-submit examples/src/main/python/mllib/dataset_example.py
+"""
+
+import os
+import sys
+import tempfile
+import shutil
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.stat import Statistics
+
+
+def summarize(dataset):
+ print "schema: %s" % dataset.schema().json()
+ labels = dataset.map(lambda r: r.label)
+ print "label average: %f" % labels.mean()
+ features = dataset.map(lambda r: r.features)
+ summary = Statistics.colStats(features)
+ print "features average: %r" % summary.mean()
+
+if __name__ == "__main__":
+ if len(sys.argv) > 2:
+ print >> sys.stderr, "Usage: dataset_example.py <libsvm file>"
+ exit(-1)
+ sc = SparkContext(appName="DatasetExample")
+ sqlCtx = SQLContext(sc)
+ if len(sys.argv) == 2:
+ input = sys.argv[1]
+ else:
+ input = "data/mllib/sample_libsvm_data.txt"
+ points = MLUtils.loadLibSVMFile(sc, input)
+ dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache()
+ summarize(dataset0)
+ tempdir = tempfile.NamedTemporaryFile(delete=False).name
+ os.unlink(tempdir)
+ print "Save dataset as a Parquet file to %s." % tempdir
+ dataset0.saveAsParquetFile(tempdir)
+ print "Load it back and summarize it again."
+ dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache()
+ summarize(dataset1)
+ shutil.rmtree(tempdir)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
new file mode 100644
index 0000000000000..f8d83f4ec7327
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import java.io.File
+
+import com.google.common.io.Files
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
+
+/**
+ * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DatasetExample {
+
+ case class Params(
+ input: String = "data/mllib/sample_libsvm_data.txt",
+ dataFormat: String = "libsvm") extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("DatasetExample") {
+ head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
+ opt[String]("input")
+ .text(s"input path to dataset")
+ .action((x, c) => c.copy(input = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ checkConfig { params =>
+ success
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._ // for implicit conversions
+
+ // Load input data
+ val origData: RDD[LabeledPoint] = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
+ }
+ println(s"Loaded ${origData.count()} instances from file: ${params.input}")
+
+ // Convert input data to SchemaRDD explicitly.
+ val schemaRDD: SchemaRDD = origData
+ println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
+ println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")
+
+ // Select columns, using implicit conversion to SchemaRDD.
+ val labelsSchemaRDD: SchemaRDD = origData.select('label)
+ val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
+ val numLabels = labels.count()
+ val meanLabel = labels.fold(0.0)(_ + _) / numLabels
+ println(s"Selected label column with average value $meanLabel")
+
+ val featuresSchemaRDD: SchemaRDD = origData.select('features)
+ val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
+ val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
+
+ val tmpDir = Files.createTempDir()
+ tmpDir.deleteOnExit()
+ val outputDir = new File(tmpDir, "dataset").toString
+ println(s"Saving to $outputDir as Parquet file.")
+ schemaRDD.saveAsParquetFile(outputDir)
+
+ println(s"Loading Parquet file with UDT from $outputDir.")
+ val newDataset = sqlContext.parquetFile(outputDir)
+
+ println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
+ val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
+ val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+
+ sc.stop()
+ }
+
+}
diff --git a/mllib/pom.xml b/mllib/pom.xml
index fb7239e779aae..87a7ddaba97f2 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -45,6 +45,11 @@
      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
    <dependency>
      <groupId>org.eclipse.jetty</groupId>
      <artifactId>jetty-server</artifactId>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6af225b7f49f7..ac217edc619ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -17,22 +17,26 @@
package org.apache.spark.mllib.linalg
-import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import java.util
+import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import scala.annotation.varargs
import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
-import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
+import org.apache.spark.mllib.util.NumericParser
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.catalyst.types._
/**
* Represents a numeric vector, whose index type is Int and value type is Double.
*
* Note: Users should not implement this interface.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
sealed trait Vector extends Serializable {
/**
@@ -74,6 +78,65 @@ sealed trait Vector extends Serializable {
}
}
+/**
+ * User-defined type for [[Vector]] which allows easy interaction with SQL
+ * via [[org.apache.spark.sql.SchemaRDD]].
+ */
+private[spark] class VectorUDT extends UserDefinedType[Vector] {
+
+ override def sqlType: StructType = {
+ // type: 0 = sparse, 1 = dense
+ // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
+ // vectors. The "values" field is nullable because we might want to add binary vectors later,
+ // which uses "size" and "indices", but not "values".
+ StructType(Seq(
+ StructField("type", ByteType, nullable = false),
+ StructField("size", IntegerType, nullable = true),
+ StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
+ }
+
+ override def serialize(obj: Any): Row = {
+ val row = new GenericMutableRow(4)
+ obj match {
+ case sv: SparseVector =>
+ row.setByte(0, 0)
+ row.setInt(1, sv.size)
+ row.update(2, sv.indices.toSeq)
+ row.update(3, sv.values.toSeq)
+ case dv: DenseVector =>
+ row.setByte(0, 1)
+ row.setNullAt(1)
+ row.setNullAt(2)
+ row.update(3, dv.values.toSeq)
+ }
+ row
+ }
+
+ override def deserialize(datum: Any): Vector = {
+ datum match {
+ case row: Row =>
+ require(row.length == 4,
+ s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
+ val tpe = row.getByte(0)
+ tpe match {
+ case 0 =>
+ val size = row.getInt(1)
+ val indices = row.getAs[Iterable[Int]](2).toArray
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new SparseVector(size, indices, values)
+ case 1 =>
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new DenseVector(values)
+ }
+ }
+ }
+
+ override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"
+
+ override def userClass: Class[Vector] = classOf[Vector]
+}
+
/**
* Factory methods for [[org.apache.spark.mllib.linalg.Vector]].
* We don't use the name `Vector` because Scala imports
@@ -191,6 +254,7 @@ object Vectors {
/**
* A dense vector represented by a value array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {
override def size: Int = values.length
@@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
* @param indices index array, assume to be strictly increasing.
* @param values value array, must have the same length as the index array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class SparseVector(
override val size: Int,
val indices: Array[Int],
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index cd651fe2d2ddf..93a84fe07b32a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite {
throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.")
}
}
+
+ test("VectorUDT") {
+ val dv0 = Vectors.dense(Array.empty[Double])
+ val dv1 = Vectors.dense(1.0, 2.0)
+ val sv0 = Vectors.sparse(2, Array.empty, Array.empty)
+ val sv1 = Vectors.sparse(2, Array(1), Array(2.0))
+ val udt = new VectorUDT()
+ for (v <- Seq(dv0, dv1, sv0, sv1)) {
+ assert(v === udt.deserialize(udt.serialize(v)))
+ }
+ }
}
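Because `Vector`, `DenseVector`, and `SparseVector` are now annotated with `@SQLUserDefinedType(udt = classOf[VectorUDT])`, MLlib vectors can round-trip through Spark SQL and Parquet without manual conversion. A condensed, REPL-style sketch adapted from the `DatasetExample` added in this same patch; it assumes an existing `SparkContext` named `sc` and uses a placeholder output path:

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}

val sqlContext = new SQLContext(sc)
import sqlContext._  // implicit RDD[LabeledPoint] -> SchemaRDD conversion

// LabeledPoint's `features` column is stored via VectorUDT's sqlType.
val points = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val schemaRDD: SchemaRDD = points
println(schemaRDD.schema.prettyJson)

// Round-trip through Parquet; "/tmp/labeled-points" is an arbitrary placeholder path.
schemaRDD.saveAsParquetFile("/tmp/labeled-points")
val reloaded = sqlContext.parquetFile("/tmp/labeled-points")
val features = reloaded.select('features).map { case Row(v: Vector) => v }
println(s"Reloaded ${features.count()} vectors")
```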
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
index c0a62e00432a3..5cb433232e714 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
@@ -30,7 +30,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext {
test("BaggedPoint RDD: without subsampling") {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
baggedRDD.collect().foreach { baggedPoint =>
assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
}
@@ -44,7 +44,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
@@ -60,7 +60,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
@@ -75,7 +75,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
@@ -91,7 +91,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index d0a0e102a1a07..c0c3dff31e7f8 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -29,6 +29,9 @@
import numpy as np
+from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
+ IntegerType, ByteType, Row
+
__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors']
@@ -106,7 +109,54 @@ def _format_float(f, digits=4):
return s
+class VectorUDT(UserDefinedType):
+ """
+ SQL user-defined type (UDT) for Vector.
+ """
+
+ @classmethod
+ def sqlType(cls):
+ return StructType([
+ StructField("type", ByteType(), False),
+ StructField("size", IntegerType(), True),
+ StructField("indices", ArrayType(IntegerType(), False), True),
+ StructField("values", ArrayType(DoubleType(), False), True)])
+
+ @classmethod
+ def module(cls):
+ return "pyspark.mllib.linalg"
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.mllib.linalg.VectorUDT"
+
+ def serialize(self, obj):
+ if isinstance(obj, SparseVector):
+ indices = [int(i) for i in obj.indices]
+ values = [float(v) for v in obj.values]
+ return (0, obj.size, indices, values)
+ elif isinstance(obj, DenseVector):
+ values = [float(v) for v in obj]
+ return (1, None, None, values)
+ else:
+ raise ValueError("cannot serialize %r of type %r" % (obj, type(obj)))
+
+ def deserialize(self, datum):
+ assert len(datum) == 4, \
+ "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
+ tpe = datum[0]
+ if tpe == 0:
+ return SparseVector(datum[1], datum[2], datum[3])
+ elif tpe == 1:
+ return DenseVector(datum[3])
+ else:
+ raise ValueError("do not recognize type %r" % tpe)
+
+
class Vector(object):
+
+ __UDT__ = VectorUDT()
+
"""
Abstract class for DenseVector and SparseVector
"""
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index d6fb87b378b4a..9fa4d6f6a2f5f 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -33,14 +33,14 @@
else:
import unittest
-from pyspark.serializers import PickleSerializer
-from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
+from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.serializers import PickleSerializer
+from pyspark.sql import SQLContext
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-
_have_scipy = False
try:
import scipy.sparse
@@ -221,6 +221,39 @@ def test_col_with_different_rdds(self):
self.assertEqual(10, summary.count())
+class VectorUDTTests(PySparkTestCase):
+
+ dv0 = DenseVector([])
+ dv1 = DenseVector([1.0, 2.0])
+ sv0 = SparseVector(2, [], [])
+ sv1 = SparseVector(2, [1], [2.0])
+ udt = VectorUDT()
+
+ def test_json_schema(self):
+ self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+ def test_serialization(self):
+ for v in [self.dv0, self.dv1, self.sv0, self.sv1]:
+ self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
+
+ def test_infer_schema(self):
+ sqlCtx = SQLContext(self.sc)
+ rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
+ srdd = sqlCtx.inferSchema(rdd)
+ schema = srdd.schema()
+ field = [f for f in schema.fields if f.name == "features"][0]
+ self.assertEqual(field.dataType, self.udt)
+ vectors = srdd.map(lambda p: p.features).collect()
+ self.assertEqual(len(vectors), 2)
+ for v in vectors:
+ if isinstance(v, SparseVector):
+ self.assertEqual(v, self.sv1)
+ elif isinstance(v, DenseVector):
+ self.assertEqual(v, self.dv1)
+ else:
+ raise ValueError("expecting a vector but got %r of type %r" % (v, type(v)))
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 675df084bf303..d16c18bc79fe4 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -417,6 +417,75 @@ def fromJson(cls, json):
return StructType([StructField.fromJson(f) for f in json["fields"]])
+class UserDefinedType(DataType):
+ """
+ :: WARN: Spark Internal Use Only ::
+ SQL User-Defined Type (UDT).
+ """
+
+ @classmethod
+ def typeName(cls):
+ return cls.__name__.lower()
+
+ @classmethod
+ def sqlType(cls):
+ """
+ Underlying SQL storage type for this UDT.
+ """
+ raise NotImplementedError("UDT must implement sqlType().")
+
+ @classmethod
+ def module(cls):
+ """
+ The Python module of the UDT.
+ """
+ raise NotImplementedError("UDT must implement module().")
+
+ @classmethod
+ def scalaUDT(cls):
+ """
+ The class name of the paired Scala UDT.
+ """
+ raise NotImplementedError("UDT must have a paired Scala UDT.")
+
+ def serialize(self, obj):
+ """
+ Converts a user-type object into a SQL datum.
+ """
+ raise NotImplementedError("UDT must implement serialize().")
+
+ def deserialize(self, datum):
+ """
+ Converts a SQL datum into a user-type object.
+ """
+ raise NotImplementedError("UDT must implement deserialize().")
+
+ def json(self):
+ return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
+
+ def jsonValue(self):
+ schema = {
+ "type": "udt",
+ "class": self.scalaUDT(),
+ "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+ "sqlType": self.sqlType().jsonValue()
+ }
+ return schema
+
+ @classmethod
+ def fromJson(cls, json):
+ pyUDT = json["pyClass"]
+ split = pyUDT.rfind(".")
+ pyModule = pyUDT[:split]
+ pyClass = pyUDT[split+1:]
+ m = __import__(pyModule, globals(), locals(), [pyClass], -1)
+ UDT = getattr(m, pyClass)
+ return UDT()
+
+ def __eq__(self, other):
+ return type(self) == type(other)
+
+
_all_primitive_types = dict((v.typeName(), v)
for v in globals().itervalues()
if type(v) is PrimitiveTypeSingleton and
@@ -469,6 +538,12 @@ def _parse_datatype_json_string(json_string):
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
True
+ >>> check_datatype(ExamplePointUDT())
+ True
+ >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> check_datatype(structtype_with_udt)
+ True
"""
return _parse_datatype_json_value(json.loads(json_string))
@@ -488,7 +563,13 @@ def _parse_datatype_json_value(json_value):
else:
raise ValueError("Could not parse datatype: %s" % json_value)
else:
- return _all_complex_types[json_value["type"]].fromJson(json_value)
+ tpe = json_value["type"]
+ if tpe in _all_complex_types:
+ return _all_complex_types[tpe].fromJson(json_value)
+ elif tpe == 'udt':
+ return UserDefinedType.fromJson(json_value)
+ else:
+ raise ValueError("not supported type: %s" % tpe)
# Mapping Python types to Spark SQL DataType
@@ -509,7 +590,18 @@ def _parse_datatype_json_value(json_value):
def _infer_type(obj):
- """Infer the DataType from obj"""
+ """Infer the DataType from obj
+
+ >>> p = ExamplePoint(1.0, 2.0)
+ >>> _infer_type(p)
+ ExamplePointUDT
+ """
+ if obj is None:
+ raise ValueError("Can not infer type for None")
+
+ if hasattr(obj, '__UDT__'):
+ return obj.__UDT__
+
dataType = _type_mappings.get(type(obj))
if dataType is not None:
return dataType()
@@ -558,6 +650,93 @@ def _infer_schema(row):
return StructType(fields)
+def _need_python_to_sql_conversion(dataType):
+ """
+ Checks whether we need python to sql conversion for the given type.
+ For now, only UDTs need this conversion.
+
+ >>> _need_python_to_sql_conversion(DoubleType())
+ False
+ >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
+ ... StructField("values", ArrayType(DoubleType(), False), False)])
+ >>> _need_python_to_sql_conversion(schema0)
+ False
+ >>> _need_python_to_sql_conversion(ExamplePointUDT())
+ True
+ >>> schema1 = ArrayType(ExamplePointUDT(), False)
+ >>> _need_python_to_sql_conversion(schema1)
+ True
+ >>> schema2 = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> _need_python_to_sql_conversion(schema2)
+ True
+ """
+ if isinstance(dataType, StructType):
+ return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+ elif isinstance(dataType, ArrayType):
+ return _need_python_to_sql_conversion(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ return _need_python_to_sql_conversion(dataType.keyType) or \
+ _need_python_to_sql_conversion(dataType.valueType)
+ elif isinstance(dataType, UserDefinedType):
+ return True
+ else:
+ return False
+
+
+def _python_to_sql_converter(dataType):
+ """
+ Returns a converter that converts a Python object into a SQL datum for the given type.
+
+ >>> conv = _python_to_sql_converter(DoubleType())
+ >>> conv(1.0)
+ 1.0
+ >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
+ >>> conv([1.0, 2.0])
+ [1.0, 2.0]
+ >>> conv = _python_to_sql_converter(ExamplePointUDT())
+ >>> conv(ExamplePoint(1.0, 2.0))
+ [1.0, 2.0]
+ >>> schema = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> conv = _python_to_sql_converter(schema)
+ >>> conv((1.0, ExamplePoint(1.0, 2.0)))
+ (1.0, [1.0, 2.0])
+ """
+ if not _need_python_to_sql_conversion(dataType):
+ return lambda x: x
+
+ if isinstance(dataType, StructType):
+ names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
+ converters = map(_python_to_sql_converter, types)
+
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+ elif isinstance(obj, tuple):
+ if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
+ return tuple(c(v) for c, v in zip(converters, obj))
+ elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
+ d = dict(obj)
+ return tuple(c(d.get(n)) for n, c in zip(names, converters))
+ else:
+ return tuple(c(v) for c, v in zip(converters, obj))
+ else:
+ raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ return converter
+ elif isinstance(dataType, ArrayType):
+ element_converter = _python_to_sql_converter(dataType.elementType)
+ return lambda a: [element_converter(v) for v in a]
+ elif isinstance(dataType, MapType):
+ key_converter = _python_to_sql_converter(dataType.keyType)
+ value_converter = _python_to_sql_converter(dataType.valueType)
+ return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+ elif isinstance(dataType, UserDefinedType):
+ return lambda obj: dataType.serialize(obj)
+ else:
+ raise ValueError("Unexpected type %r" % dataType)
+
+
def _has_nulltype(dt):
""" Return whether there is NullType in `dt` or not """
if isinstance(dt, StructType):
@@ -818,11 +997,22 @@ def _verify_type(obj, dataType):
Traceback (most recent call last):
...
ValueError:...
+ >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+ >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
"""
# all objects are nullable
if obj is None:
return
+ if isinstance(dataType, UserDefinedType):
+ if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
+ raise ValueError("%r is not an instance of type %r" % (obj, dataType))
+ _verify_type(dataType.serialize(obj), dataType.sqlType())
+ return
+
_type = type(dataType)
assert _type in _acceptable_types, "unkown datatype: %s" % dataType
@@ -897,6 +1087,8 @@ def _has_struct_or_date(dt):
return _has_struct_or_date(dt.valueType)
elif isinstance(dt, DateType):
return True
+ elif isinstance(dt, UserDefinedType):
+ return True
return False
@@ -967,6 +1159,9 @@ def Dict(d):
elif isinstance(dataType, DateType):
return datetime.date
+ elif isinstance(dataType, UserDefinedType):
+ return lambda datum: dataType.deserialize(datum)
+
elif not isinstance(dataType, StructType):
raise Exception("unexpected data type: %s" % dataType)
@@ -1244,6 +1439,10 @@ def applySchema(self, rdd, schema):
for row in rows:
_verify_type(row, schema)
+ # convert python objects to sql data
+ converter = _python_to_sql_converter(schema)
+ rdd = rdd.map(converter)
+
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
@@ -1877,6 +2076,7 @@ def _test():
# let doctest run in pyspark.sql, so DataTypes can be picklable
import pyspark.sql
from pyspark.sql import Row, SQLContext
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
globs = pyspark.sql.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
@@ -1888,6 +2088,8 @@ def _test():
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
+ globs['ExamplePoint'] = ExamplePoint
+ globs['ExamplePointUDT'] = ExamplePointUDT
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index dc5d1312fe994..6e7f24c6066a5 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -49,7 +49,8 @@
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
+ UserDefinedType, DoubleType
from pyspark import shuffle
_have_scipy = False
@@ -694,8 +695,65 @@ def heavy_foo(x):
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
+class ExamplePointUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for ExamplePoint.
+ """
+
+ @classmethod
+ def sqlType(self):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return 'pyspark.tests'
+
+ @classmethod
+ def scalaUDT(cls):
+ return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+ """
+ An example class to demonstrate UDT in Scala, Java, and Python.
+ """
+
+ __UDT__ = ExamplePointUDT()
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __repr__(self):
+ return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+ def __str__(self):
+ return "(%s,%s)" % (self.x, self.y)
+
+ def __eq__(self, other):
+ return isinstance(other, ExamplePoint) and \
+ other.x == self.x and other.y == self.y
+
+
class SQLTests(ReusedPySparkTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name)
+
def setUp(self):
self.sqlCtx = SQLContext(self.sc)
@@ -824,6 +882,39 @@ def test_convert_row_to_dict(self):
row = self.sqlCtx.sql("select l[0].a AS la from test").first()
self.assertEqual(1, row.asDict()["la"])
+ def test_infer_schema_with_udt(self):
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ srdd = self.sqlCtx.inferSchema(rdd)
+ schema = srdd.schema()
+ field = [f for f in schema.fields if f.name == "point"][0]
+ self.assertEqual(type(field.dataType), ExamplePointUDT)
+ srdd.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+ self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+ def test_apply_schema_with_udt(self):
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
+ row = (1.0, ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ srdd = self.sqlCtx.applySchema(rdd, schema)
+ point = srdd.first().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+ def test_parquet_with_udt(self):
+ from pyspark.tests import ExamplePoint
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ srdd0 = self.sqlCtx.inferSchema(rdd)
+ output_dir = os.path.join(self.tempdir.name, "labeled_point")
+ srdd0.saveAsParquetFile(output_dir)
+ srdd1 = self.sqlCtx.parquetFile(output_dir)
+ point = srdd1.first().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
class InputFormatTests(ReusedPySparkTestCase):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index fa1786e74bb3e..18c96da2f87fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -34,320 +34,366 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
override def toString = s"scalaUDF(${children.mkString(",")})"
+ // scalastyle:off
+
/** This method has been generated by this script
(1 to 22).map { x =>
val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _)
- val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _)
+ val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _)
s"""
case $x =>
function.asInstanceOf[($anys) => Any](
- $evals)
+ $evals)
"""
- }
+ }.foreach(println)
*/
- // scalastyle:off
override def eval(input: Row): Any = {
val result = children.size match {
case 0 => function.asInstanceOf[() => Any]()
- case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
+ case 1 =>
+ function.asInstanceOf[(Any) => Any](
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType))
+
+
case 2 =>
function.asInstanceOf[(Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType))
+
+
case 3 =>
function.asInstanceOf[(Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType))
+
+
case 4 =>
function.asInstanceOf[(Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType))
+
+
case 5 =>
function.asInstanceOf[(Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType))
+
+
case 6 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType))
+
+
case 7 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType))
+
+
case 8 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType))
+
+
case 9 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType))
+
+
case 10 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType))
+
+
case 11 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType))
+
+
case 12 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType))
+
+
case 13 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input),
- children(12).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+ ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType))
+
+
case 14 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input),
- children(12).eval(input),
- children(13).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+ ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+ ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType))
+
+
case 15 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input),
- children(12).eval(input),
- children(13).eval(input),
- children(14).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+ ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+ ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+ ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType))
+
+
case 16 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input),
- children(12).eval(input),
- children(13).eval(input),
- children(14).eval(input),
- children(15).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+ ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+ ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+ ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+ ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType))
+
+
case 17 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input),
- children(12).eval(input),
- children(13).eval(input),
- children(14).eval(input),
- children(15).eval(input),
- children(16).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+ ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+ ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+ ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+ ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+ ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType))
+
+
case 18 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input),
- children(12).eval(input),
- children(13).eval(input),
- children(14).eval(input),
- children(15).eval(input),
- children(16).eval(input),
- children(17).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+ ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+ ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+ ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+ ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+ ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
+ ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType))
+
+
case 19 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input),
- children(12).eval(input),
- children(13).eval(input),
- children(14).eval(input),
- children(15).eval(input),
- children(16).eval(input),
- children(17).eval(input),
- children(18).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+ ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+ ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+ ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+ ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+ ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
+ ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
+ ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType))
+
+
case 20 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input),
- children(12).eval(input),
- children(13).eval(input),
- children(14).eval(input),
- children(15).eval(input),
- children(16).eval(input),
- children(17).eval(input),
- children(18).eval(input),
- children(19).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+ ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+ ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+ ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+ ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+ ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
+ ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
+ ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType),
+ ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType))
+
+
case 21 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input),
- children(12).eval(input),
- children(13).eval(input),
- children(14).eval(input),
- children(15).eval(input),
- children(16).eval(input),
- children(17).eval(input),
- children(18).eval(input),
- children(19).eval(input),
- children(20).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+ ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+ ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+ ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+ ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+ ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
+ ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
+ ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType),
+ ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType),
+ ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType))
+
+
case 22 =>
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
- children(0).eval(input),
- children(1).eval(input),
- children(2).eval(input),
- children(3).eval(input),
- children(4).eval(input),
- children(5).eval(input),
- children(6).eval(input),
- children(7).eval(input),
- children(8).eval(input),
- children(9).eval(input),
- children(10).eval(input),
- children(11).eval(input),
- children(12).eval(input),
- children(13).eval(input),
- children(14).eval(input),
- children(15).eval(input),
- children(16).eval(input),
- children(17).eval(input),
- children(18).eval(input),
- children(19).eval(input),
- children(20).eval(input),
- children(21).eval(input))
+ ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+ ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+ ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+ ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+ ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+ ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+ ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+ ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+ ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+ ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+ ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+ ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+ ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+ ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+ ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+ ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+ ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
+ ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
+ ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType),
+ ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType),
+ ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType),
+ ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType))
+
}
// scalastyle:on
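
Note: every rewritten case arm above follows the same pattern: each argument is passed through ScalaReflection.convertToScala before the user's function is applied, so values stored in a UDT's internal representation reach the function as the user-facing type. A minimal sketch of that per-argument conversion, written as a loop instead of the arity-specific arms the generator script emits (the helper name is ours, not part of the patch; the real code presumably stays unrolled to avoid building an argument collection on the hot path):

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.catalyst.expressions.{Expression, Row}

    // Equivalent of one generated case arm: evaluate each child against the
    // input row, then convert the Catalyst value back to its Scala form.
    def convertedArgs(children: Seq[Expression], input: Row): Seq[Any] =
      children.map(c => ScalaReflection.convertToScala(c.eval(input), c.dataType))
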
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index e1b5992a36e5f..5dd19dd12d8dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -71,6 +71,8 @@ object DataType {
case JSortedObject(
("class", JString(udtClass)),
+ ("pyClass", _),
+ ("sqlType", _),
("type", JString("udt"))) =>
Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
}
@@ -593,6 +595,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
/** Underlying storage type for this UDT */
def sqlType: DataType
+  /** Paired Python UDT class, if one exists. */
+ def pyUDT: String = null
+
/**
* Convert the user type to a SQL datum
*
@@ -606,7 +611,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
override private[sql] def jsonValue: JValue = {
("type" -> "udt") ~
- ("class" -> this.getClass.getName)
+ ("class" -> this.getClass.getName) ~
+ ("pyClass" -> pyUDT) ~
+ ("sqlType" -> sqlType.jsonValue)
}
/**
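
Note: jsonValue for a UDT now emits four fields, which is why the JSortedObject extractor earlier in this file binds "pyClass" and "sqlType" placeholders (keys are matched in sorted order) even though parsing only needs "class". Purely as an illustration of the shape, not output copied from Spark, the schema JSON for the ExamplePointUDT introduced later in this patch would look roughly like:

    // Illustrative only; the rendering of the nested sqlType is an assumption
    // about ArrayType's JSON form, not text taken from the patch.
    val exampleUdtJson =
      """{"type":"udt",
        |"class":"org.apache.spark.sql.test.ExamplePointUDT",
        |"pyClass":"pyspark.tests.ExamplePointUDT",
        |"sqlType":{"type":"array","elementType":"double","containsNull":false}}""".stripMargin
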
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 9e61d18f7e926..84eaf401f240c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.UserDefinedType
import org.apache.spark.sql.execution.{SparkStrategies, _}
import org.apache.spark.sql.json._
import org.apache.spark.sql.parquet.ParquetRelation
@@ -483,6 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
case ArrayType(_, _) => true
case MapType(_, _, _) => true
case StructType(_) => true
+ case udt: UserDefinedType[_] => needsConversion(udt.sqlType)
case other => false
}
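
Note: the new case makes needsConversion recurse into a UDT's storage type, so rows arriving from Python are converted whenever the underlying sqlType would have required it. A standalone sketch of the recursion (the non-UDT cases shown in the surrounding context are kept; the remaining primitive cases from SQLContext are elided):

    import org.apache.spark.sql.catalyst.types._

    def needsConversionSketch(dataType: DataType): Boolean = dataType match {
      case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => true
      case udt: UserDefinedType[_] => needsConversionSketch(udt.sqlType) // new: follow the storage type
      case _ => false // other primitive cases from SQLContext omitted in this sketch
    }
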
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 997669051ed07..a83cf5d441d1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -135,6 +135,8 @@ object EvaluatePython {
case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
}.asJava
+ case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
+
case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal
// Pyrolite can handle Timestamp
@@ -177,6 +179,9 @@ object EvaluatePython {
case (c: java.util.Calendar, TimestampType) =>
new java.sql.Timestamp(c.getTime().getTime())
+ case (_, udt: UserDefinedType[_]) =>
+ fromJava(obj, udt.sqlType)
+
case (c: Int, ByteType) => c.toByte
case (c: Long, ByteType) => c.toByte
case (c: Int, ShortType) => c.toShort
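
Note: both directions of the Python boundary now understand UDTs. Outbound, a value is first serialized into its sqlType representation and then converted like any other SQL value; inbound, fromJava simply interprets the incoming object against sqlType, so the Python side only ever works with the storage type. A self-contained sketch of the outbound case (helper name is ours, remaining conversions omitted):

    import org.apache.spark.sql.catalyst.types.{DataType, UserDefinedType}

    // Outbound conversion: unwrap the user object via the UDT, then keep
    // converting against the storage type.
    def toJavaSketch(obj: Any, dataType: DataType): Any = (obj, dataType) match {
      case (ud, udt: UserDefinedType[_]) => toJavaSketch(udt.serialize(ud), udt.sqlType)
      case (other, _) => other
    }
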
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
new file mode 100644
index 0000000000000..b9569e96c0312
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * An example class to demonstrate UDT in Scala, Java, and Python.
+ * @param x x coordinate
+ * @param y y coordinate
+ */
+@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
+private[sql] class ExamplePoint(val x: Double, val y: Double)
+
+/**
+ * User-defined type for [[ExamplePoint]].
+ */
+private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
+
+ override def sqlType: DataType = ArrayType(DoubleType, false)
+
+ override def pyUDT: String = "pyspark.tests.ExamplePointUDT"
+
+ override def serialize(obj: Any): Seq[Double] = {
+ obj match {
+ case p: ExamplePoint =>
+ Seq(p.x, p.y)
+ }
+ }
+
+ override def deserialize(datum: Any): ExamplePoint = {
+ datum match {
+ case values: Seq[_] =>
+ val xy = values.asInstanceOf[Seq[Double]]
+ assert(xy.length == 2)
+ new ExamplePoint(xy(0), xy(1))
+ case values: util.ArrayList[_] =>
+ val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
+ new ExamplePoint(xy(0), xy(1))
+ }
+ }
+
+ override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
+}
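
Note: a quick round trip through the new UDT, as a hedged usage sketch rather than code from the patch: serialize maps the user object to its ArrayType(DoubleType) storage form and deserialize rebuilds it, which is the path the PySpark tests at the top of this section exercise via pyspark.tests.ExamplePointUDT. Since both classes are private[sql], the snippet assumes it runs inside the org.apache.spark.sql package.

    import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}

    val udt = new ExamplePointUDT
    val stored = udt.serialize(new ExamplePoint(1.0, 2.0)) // Seq(1.0, 2.0)
    val back = udt.deserialize(stored)                     // an ExamplePoint again
    assert(back.x == 1.0 && back.y == 2.0)
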
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 1bc15146f0fe8..3fa4a7c6481d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.UserDefinedType
-
protected[sql] object DataTypeConversions {
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 666235e57f812..1806a1dd82023 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -60,13 +60,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
}
class UserDefinedTypeSuite extends QueryTest {
+ val points = Seq(
+ MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
+ MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
+ val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)
- test("register user type: MyDenseVector for MyLabeledPoint") {
- val points = Seq(
- MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
- MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
- val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)
+ test("register user type: MyDenseVector for MyLabeledPoint") {
val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
@@ -80,4 +80,12 @@ class UserDefinedTypeSuite extends QueryTest {
assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0))))
assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0))))
}
+
+ test("UDTs and UDFs") {
+ registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
+ pointsRDD.registerTempTable("points")
+ checkAnswer(
+ sql("SELECT testType(features) from points"),
+ Seq(Row(true), Row(true)))
+ }
}