diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8b4db783979ec..40444c237b738 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,9 +21,8 @@ import scala.language.implicitConversions import java.io._ import java.net.URI -import java.util.Arrays +import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger -import java.util.{Properties, UUID} import java.util.UUID.randomUUID import scala.collection.{Map, Set} import scala.collection.generic.Growable @@ -41,6 +40,7 @@ import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} +import org.apache.spark.executor.TriggerThreadDump import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ @@ -51,7 +51,7 @@ import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI import org.apache.spark.ui.jobs.JobProgressListener -import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} +import org.apache.spark.util._ /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -361,6 +361,29 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { override protected def childValue(parent: Properties): Properties = new Properties(parent) } + /** + * Called by the web UI to obtain executor thread dumps. This method may be expensive. + * Logs an error and returns None if we failed to obtain a thread dump, which could occur due + * to an executor being dead or unresponsive or due to network issues while sending the thread + * dump message back to the driver. + */ + private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = { + try { + if (executorId == SparkContext.DRIVER_IDENTIFIER) { + Some(Utils.getThreadDump()) + } else { + val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get + val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) + Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, + AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + } + } catch { + case e: Exception => + logError(s"Exception getting thread dump from executor $executorId", e) + None + } + } + private[spark] def getLocalProperties: Properties = localProperties.get() private[spark] def setLocalProperties(props: Properties) { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 697154d762d41..3711824a40cfc 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -131,7 +131,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Create a new ActorSystem using driver's Spark properties to run the backend. 
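The `getExecutorThreadDump` method added to `SparkContext` above is the driver-side entry point for the whole feature. A minimal usage sketch (hedged: the method is `private[spark]`, so only Spark-internal code such as the web UI can call it, and the executor id `"3"` is hypothetical):

```scala
// Assuming a live SparkContext `sc` inside Spark-internal code.
// Returns None if the executor is dead, unresponsive, or unreachable.
val maybeDump: Option[Array[ThreadStackTrace]] = sc.getExecutorThreadDump("3")
maybeDump.foreach { threads =>
  // ThreadStackTrace (added in util/ThreadStackTrace.scala below) carries
  // threadId, threadName, threadState, and a pre-rendered stack trace string.
  threads.filter(_.threadState == Thread.State.BLOCKED).foreach { t =>
    println(s"Blocked thread ${t.threadId} (${t.threadName}):\n${t.stackTrace}")
  }
}
```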
val driverConf = new SparkConf().setAll(props) val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf)) + SparkEnv.executorActorSystemName, + hostname, port, driverConf, new SecurityManager(driverConf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e24a15f015e1c..8b095e23f32ff 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.ActorSystem +import akka.actor.{Props, ActorSystem} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -92,6 +92,10 @@ private[spark] class Executor( } } + // Create an actor for receiving RPCs from the driver + private val executorActor = env.actorSystem.actorOf( + Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -131,6 +135,7 @@ private[spark] class Executor( def stop() { env.metricsSystem.report() + env.actorSystem.stop(executorActor) isStopped = true threadPool.shutdown() if (!isLocal) { diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala new file mode 100644 index 0000000000000..41925f7e97e84 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import akka.actor.Actor +import org.apache.spark.Logging + +import org.apache.spark.util.{Utils, ActorLogReceive} + +/** + * Driver -> Executor message to trigger a thread dump. + */ +private[spark] case object TriggerThreadDump + +/** + * Actor that runs inside of executors to enable driver -> executor RPC. + */ +private[spark] +class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { + case TriggerThreadDump => + sender ! 
Utils.getThreadDump() + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index d08e1419e3e41..b63c7f191155c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -88,6 +88,10 @@ class BlockManagerMaster( askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } + def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + } + /** * Remove a block from the slaves that have it. This can only be used to remove * blocks that the driver knows about. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 5e375a2553979..685b2e11440fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetPeers(blockManagerId) => sender ! getPeers(blockManagerId) + case GetActorSystemHostPortForExecutor(executorId) => + sender ! getActorSystemHostPortForExecutor(executorId) + case GetMemoryStatus => sender ! memoryStatus @@ -412,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus Seq.empty } } + + /** + * Returns the hostname and port of an executor's actor system, based on the Akka address of its + * BlockManagerSlaveActor. + */ + private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + for ( + blockManagerId <- blockManagerIdByExecutor.get(executorId); + info <- blockManagerInfo.get(blockManagerId); + host <- info.slaveActor.path.address.host; + port <- info.slaveActor.path.address.port + ) yield { + (host, port) + } + } } @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 291ddfcc113ac..3f32099d08cc9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -92,6 +92,8 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class RemoveExecutor(execId: String) extends ToBlockManagerMaster case object StopBlockManagerMaster extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala new file mode 100644 index 0000000000000..e9c755e36f716 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.exec + +import javax.servlet.http.HttpServletRequest + +import scala.util.Try +import scala.xml.{Text, Node} + +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") { + + private val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val executorId = Option(request.getParameter("executorId")).getOrElse { + return Text(s"Missing executorId parameter") + } + val time = System.currentTimeMillis() + val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) + + val content = maybeThreadDump.map { threadDump => + val dumpRows = threadDump.map { thread => +
+        <div class="accordion-group">
+          <div class="accordion-heading" onclick="$(this).next().toggleClass('hidden')">
+            <a class="accordion-toggle">
+              Thread {thread.threadId}: {thread.threadName} ({thread.threadState})
+            </a>
+          </div>
+          <div class="accordion-body hidden">
+            <div class="accordion-inner">
+              <pre>{thread.stackTrace}</pre>
+            </div>
+          </div>
+        </div>
+      }
+
+      <div class="row-fluid">
+        <p>Updated at {UIUtils.formatDate(time)}</p>
+        {
+          // scalastyle:off
+          <p><a class="expandbutton"
+                onClick="$('.accordion-body').removeClass('hidden'); $('.expandbutton').toggleClass('hidden')">
+            Expand All
+          </a></p>
+          <p><a class="expandbutton hidden"
+                onClick="$('.accordion-body').addClass('hidden'); $('.expandbutton').toggleClass('hidden')">
+            Collapse All
+          </a></p>
+          // scalastyle:on
+        }
+        <div class="accordion">{dumpRows}</div>
+      </div>
+ }.getOrElse(Text("Error fetching thread dump")) + UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b0e3bb3b552fd..048fee3ce1ff4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -41,7 +41,10 @@ private case class ExecutorSummaryInfo( totalShuffleWrite: Long, maxMemory: Long) -private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { +private[ui] class ExecutorsPage( + parent: ExecutorsTab, + threadDumpEnabled: Boolean) + extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -75,6 +78,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { Shuffle Write + {if (threadDumpEnabled) Thread Dump else Seq.empty} {execInfoSorted.map(execRow)} @@ -133,6 +137,15 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { {Utils.bytesToString(info.totalShuffleWrite)} + { + if (threadDumpEnabled) { + + Thread Dump + + } else { + Seq.empty + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 9e0e71a51a408..ba97630f025c1 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -27,8 +27,14 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { val listener = parent.executorsListener + val sc = parent.sc + val threadDumpEnabled = + sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - attachPage(new ExecutorsPage(this)) + attachPage(new ExecutorsPage(this, threadDumpEnabled)) + if (threadDumpEnabled) { + attachPage(new ExecutorThreadDumpPage(this)) + } } /** diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index b5207360510dd..e3223403c17f4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -59,6 +59,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val failedStages = ListBuffer[StageInfo]() val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] val stageIdToInfo = new HashMap[StageId, StageInfo] + + // Number of completed and failed stages, may not actually equal to completedStages.size and + // failedStages.size respectively due to completedStage and failedStages only maintain the latest + // part of the stages, the earlier ones will be removed when there are too many stages for + // memory sake. + var numCompletedStages = 0 + var numFailedStages = 0 // Map from pool name to a hash map (map from stage id to StageInfo). 
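The `numCompletedStages`/`numFailedStages` counters added to `JobProgressListener` above are needed because `completedStages` and `failedStages` only retain the most recent entries (older ones are trimmed to bound memory), so their `.size` undercounts once trimming starts. A toy illustration of the invariant, not Spark's actual `trimIfNecessary`, with an assumed retained limit of 2:

```scala
import scala.collection.mutable.ListBuffer

val retainedLimit = 2                      // stand-in for spark.ui.retainedStages
val completedStages = ListBuffer[String]()
var numCompletedStages = 0

def onStageCompleted(stage: String): Unit = {
  completedStages += stage
  numCompletedStages += 1                  // total ever completed
  if (completedStages.size > retainedLimit) {
    completedStages.remove(0)              // drop the oldest retained entry
  }
}

Seq("stage 0", "stage 1", "stage 2").foreach(onStageCompleted)
assert(completedStages.size == 2)          // the UI can only list the retained tail
assert(numCompletedStages == 3)            // but the summary count stays exact
```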
val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() @@ -110,9 +117,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { activeStages.remove(stage.stageId) if (stage.failureReason.isEmpty) { completedStages += stage + numCompletedStages += 1 trimIfNecessary(completedStages) } else { failedStages += stage + numFailedStages += 1 trimIfNecessary(failedStages) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 6e718eecdd52a..83a7898071c9b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -34,7 +34,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") listener.synchronized { val activeStages = listener.activeStages.values.toSeq val completedStages = listener.completedStages.reverse.toSeq + val numCompletedStages = listener.numCompletedStages val failedStages = listener.failedStages.reverse.toSeq + val numFailedStages = listener.numFailedStages val now = System.currentTimeMillis val activeStagesTable = @@ -69,11 +71,11 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
            <li>
              <a href="#completed"><strong>Completed Stages:</strong></a>
-             {completedStages.size}
+             {numCompletedStages}
            </li>
            <li>
              <a href="#failed"><strong>Failed Stages:</strong></a>
-             {failedStages.size}
+             {numFailedStages}
            </li>
            <li>
@@ -86,9 +88,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
          }} ++
        <h4 id="active">Active Stages ({activeStages.size})</h4> ++
        activeStagesTable.toNodeSeq ++
-       <h4 id="completed">Completed Stages ({completedStages.size})</h4> ++
+       <h4 id="completed">Completed Stages ({numCompletedStages})</h4> ++
        completedStagesTable.toNodeSeq ++
-       <h4 id="failed">Failed Stages ({failedStages.size})</h4> ++
+       <h4 id="failed">Failed Stages ({numFailedStages})</h4>
    ++ failedStagesTable.toNodeSeq UIUtils.headerSparkPage("Spark Stages", content, parent) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 79e398eb8c104..10010bdfa1a51 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -212,4 +212,18 @@ private[spark] object AkkaUtils extends Logging { logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } + + def makeExecutorRef( + name: String, + conf: SparkConf, + host: String, + port: Int, + actorSystem: ActorSystem): ActorRef = { + val executorActorSystemName = SparkEnv.executorActorSystemName + Utils.checkHost(host, "Expected hostname") + val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name" + val timeout = AkkaUtils.lookupTimeout(conf) + logInfo(s"Connecting to $name: $url") + Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + } } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala new file mode 100644 index 0000000000000..d4e0ad93b966a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Used for shipping per-thread stacktraces from the executors to driver. + */ +private[spark] case class ThreadStackTrace( + threadId: Long, + threadName: String, + threadState: Thread.State, + stackTrace: String) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a33046d2040d8..6ab94af9f3739 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io._ +import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer import java.util.jar.Attributes.Name @@ -1611,6 +1612,18 @@ private[spark] object Utils extends Logging { s"$className: $desc\n$st" } + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ + def getThreadDump(): Array[ThreadStackTrace] = { + // We need to filter out null values here because dumpAllThreads() may return null array + // elements for threads that are dead / don't exist. 
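`makeExecutorRef`, added to `AkkaUtils` above, locates the executor-side actor purely by convention: actor-system name, host, port, and actor name are concatenated into an Akka URL. A sketch of the URL it builds, with a hypothetical host and port (`"sparkExecutor"` is assumed to be the value behind `SparkEnv.executorActorSystemName`):

```scala
// Hypothetical endpoint, for illustration only.
val executorActorSystemName = "sparkExecutor"
val (host, port, name) = ("worker-1.example.com", 50321, "ExecutorActor")
val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name"
// url == "akka.tcp://sparkExecutor@worker-1.example.com:50321/user/ExecutorActor"
```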
+ val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) + threadInfos.sortBy(_.getThreadId).map { case threadInfo => + val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") + ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, + threadInfo.getThreadState, stackTrace) + } + } + /** * Convert all spark properties set in the given SparkConf to a sequence of java options. */ diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a91c9ddeaef36..aec1e409db95c 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -177,6 +177,17 @@ class JsonProtocolSuite extends FunSuite { deserializedBmRemoved) } + test("FetchFailed backwards compatibility") { + // FetchFailed in Spark 1.1.0 does not have an "Message" property. + val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "ignored") + val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed) + .removeField({ _._1 == "Message" }) + val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "Unknown reason") + assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + test("SparkListenerApplicationStart backwards compatibility") { // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") @@ -185,6 +196,15 @@ class JsonProtocolSuite extends FunSuite { assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent)) } + test("ExecutorLostFailure backward compatibility") { + // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property. + val executorLostFailure = ExecutorLostFailure("100") + val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure) + .removeField({ _._1 == "Executor ID" }) + val expectedExecutorLostFailure = ExecutorLostFailure("Unknown") + assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ diff --git a/dev/run-tests b/dev/run-tests index 0e9eefa76a18b..de607e4344453 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -180,7 +180,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string #+ will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test") else SBT_MAVEN_TEST_ARGS=("test") fi diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py new file mode 100644 index 0000000000000..540dae785f6ea --- /dev/null +++ b/examples/src/main/python/mllib/dataset_example.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example of how to use SchemaRDD as a dataset for ML. Run with:: + bin/spark-submit examples/src/main/python/mllib/dataset_example.py +""" + +import os +import sys +import tempfile +import shutil + +from pyspark import SparkContext +from pyspark.sql import SQLContext +from pyspark.mllib.util import MLUtils +from pyspark.mllib.stat import Statistics + + +def summarize(dataset): + print "schema: %s" % dataset.schema().json() + labels = dataset.map(lambda r: r.label) + print "label average: %f" % labels.mean() + features = dataset.map(lambda r: r.features) + summary = Statistics.colStats(features) + print "features average: %r" % summary.mean() + +if __name__ == "__main__": + if len(sys.argv) > 2: + print >> sys.stderr, "Usage: dataset_example.py " + exit(-1) + sc = SparkContext(appName="DatasetExample") + sqlCtx = SQLContext(sc) + if len(sys.argv) == 2: + input = sys.argv[1] + else: + input = "data/mllib/sample_libsvm_data.txt" + points = MLUtils.loadLibSVMFile(sc, input) + dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + summarize(dataset0) + tempdir = tempfile.NamedTemporaryFile(delete=False).name + os.unlink(tempdir) + print "Save dataset as a Parquet file to %s." % tempdir + dataset0.saveAsParquetFile(tempdir) + print "Load it back and summarize it again." + dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + summarize(dataset1) + shutil.rmtree(tempdir) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala new file mode 100644 index 0000000000000..f8d83f4ec7327 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib + +import java.io.File + +import com.google.common.io.Files +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} + +/** + * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DatasetExample { + + case class Params( + input: String = "data/mllib/sample_libsvm_data.txt", + dataFormat: String = "libsvm") extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DatasetExample") { + head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + opt[String]("input") + .text(s"input path to dataset") + .action((x, c) => c.copy(input = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(input = x)) + checkConfig { params => + success + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ // for implicit conversions + + // Load input data + val origData: RDD[LabeledPoint] = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) + } + println(s"Loaded ${origData.count()} instances from file: ${params.input}") + + // Convert input data to SchemaRDD explicitly. + val schemaRDD: SchemaRDD = origData + println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") + println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + + // Select columns, using implicit conversion to SchemaRDD. 
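The `import sqlContext._` near the top of `run` is what makes these implicit conversions compile: it brings `createSchemaRDD` into scope, turning an `RDD` of a case class (here `LabeledPoint`, whose `features` field is handled by the new `VectorUDT`) into a `SchemaRDD`. A self-contained sketch of the same mechanism against the Spark 1.1/1.2-era API, with `sc` and `sqlContext` assumed to exist:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SchemaRDD

import sqlContext._  // createSchemaRDD plus the select('col) DSL

val points = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(0.1, 0.2)),
  LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(0.5)))))

val schemaRDD: SchemaRDD = points  // implicit RDD -> SchemaRDD conversion
schemaRDD.registerTempTable("points")
sql("SELECT label FROM points").collect().foreach(println)
```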
+ val labelsSchemaRDD: SchemaRDD = origData.select('label) + val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + val numLabels = labels.count() + val meanLabel = labels.fold(0.0)(_ + _) / numLabels + println(s"Selected label column with average value $meanLabel") + + val featuresSchemaRDD: SchemaRDD = origData.select('features) + val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val outputDir = new File(tmpDir, "dataset").toString + println(s"Saving to $outputDir as Parquet file.") + schemaRDD.saveAsParquetFile(outputDir) + + println(s"Loading Parquet file with UDT from $outputDir.") + val newDataset = sqlContext.parquetFile(outputDir) + + println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") + val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + + sc.stop() + } + +} diff --git a/mllib/pom.xml b/mllib/pom.xml index fb7239e779aae..87a7ddaba97f2 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -45,6 +45,11 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + org.eclipse.jetty jetty-server diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6af225b7f49f7..ac217edc619ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,22 +17,26 @@ package org.apache.spark.mllib.linalg -import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import java.util +import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException +import org.apache.spark.mllib.util.NumericParser +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.catalyst.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * * Note: Users should not implement this interface. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) sealed trait Vector extends Serializable { /** @@ -74,6 +78,65 @@ sealed trait Vector extends Serializable { } } +/** + * User-defined type for [[Vector]] which allows easy interaction with SQL + * via [[org.apache.spark.sql.SchemaRDD]]. + */ +private[spark] class VectorUDT extends UserDefinedType[Vector] { + + override def sqlType: StructType = { + // type: 0 = sparse, 1 = dense + // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse + // vectors. 
The "values" field is nullable because we might want to add binary vectors later, + // which uses "size" and "indices", but not "values". + StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("size", IntegerType, nullable = true), + StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) + } + + override def serialize(obj: Any): Row = { + val row = new GenericMutableRow(4) + obj match { + case sv: SparseVector => + row.setByte(0, 0) + row.setInt(1, sv.size) + row.update(2, sv.indices.toSeq) + row.update(3, sv.values.toSeq) + case dv: DenseVector => + row.setByte(0, 1) + row.setNullAt(1) + row.setNullAt(2) + row.update(3, dv.values.toSeq) + } + row + } + + override def deserialize(datum: Any): Vector = { + datum match { + case row: Row => + require(row.length == 4, + s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + val tpe = row.getByte(0) + tpe match { + case 0 => + val size = row.getInt(1) + val indices = row.getAs[Iterable[Int]](2).toArray + val values = row.getAs[Iterable[Double]](3).toArray + new SparseVector(size, indices, values) + case 1 => + val values = row.getAs[Iterable[Double]](3).toArray + new DenseVector(values) + } + } + } + + override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" + + override def userClass: Class[Vector] = classOf[Vector] +} + /** * Factory methods for [[org.apache.spark.mllib.linalg.Vector]]. * We don't use the name `Vector` because Scala imports @@ -191,6 +254,7 @@ object Vectors { /** * A dense vector represented by a value array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector { * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. 
*/ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class SparseVector( override val size: Int, val indices: Array[Int], diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index cd651fe2d2ddf..93a84fe07b32a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite { throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") } } + + test("VectorUDT") { + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0, 2.0) + val sv0 = Vectors.sparse(2, Array.empty, Array.empty) + val sv1 = Vectors.sparse(2, Array(1), Array(2.0)) + val udt = new VectorUDT() + for (v <- Seq(dv0, dv1, sv0, sv1)) { + assert(v === udt.deserialize(udt.serialize(v))) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index c0a62e00432a3..5cb433232e714 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -30,7 +30,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) baggedRDD.collect().foreach { baggedPoint => assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) } @@ -44,7 +44,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -60,7 +60,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) @@ -75,7 +75,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, 
epsilon = 0.01) @@ -91,7 +91,7 @@ class BaggedPointSuite extends FunSuite with LocalSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index d0a0e102a1a07..c0c3dff31e7f8 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,6 +29,9 @@ import numpy as np +from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ + IntegerType, ByteType, Row + __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] @@ -106,7 +109,54 @@ def _format_float(f, digits=4): return s +class VectorUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Vector. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("size", IntegerType(), True), + StructField("indices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.VectorUDT" + + def serialize(self, obj): + if isinstance(obj, SparseVector): + indices = [int(i) for i in obj.indices] + values = [float(v) for v in obj.values] + return (0, obj.size, indices, values) + elif isinstance(obj, DenseVector): + values = [float(v) for v in obj] + return (1, None, None, values) + else: + raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + + def deserialize(self, datum): + assert len(datum) == 4, \ + "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseVector(datum[1], datum[2], datum[3]) + elif tpe == 1: + return DenseVector(datum[3]) + else: + raise ValueError("do not recognize type %r" % tpe) + + class Vector(object): + + __UDT__ = VectorUDT() + """ Abstract class for DenseVector and SparseVector """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index d6fb87b378b4a..9fa4d6f6a2f5f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -33,14 +33,14 @@ else: import unittest -from pyspark.serializers import PickleSerializer -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.serializers import PickleSerializer +from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase - _have_scipy = False try: import scipy.sparse @@ -221,6 +221,39 @@ def test_col_with_different_rdds(self): self.assertEqual(10, summary.count()) +class VectorUDTTests(PySparkTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + 
self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + srdd = sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = srdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 675df084bf303..d16c18bc79fe4 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -417,6 +417,75 @@ def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) +class UserDefinedType(DataType): + """ + :: WARN: Spark Internal Use Only :: + SQL User-Defined Type (UDT). + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. + """ + raise NotImplementedError("UDT must implement deserialize().") + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass], -1) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + _all_primitive_types = dict((v.typeName(), v) for v in globals().itervalues() if type(v) is PrimitiveTypeSingleton and @@ -469,6 +538,12 @@ def _parse_datatype_json_string(json_string): ... complex_arraytype, False) >>> check_datatype(complex_maptype) True + >>> check_datatype(ExamplePointUDT()) + True + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... 
StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + True """ return _parse_datatype_json_value(json.loads(json_string)) @@ -488,7 +563,13 @@ def _parse_datatype_json_value(json_value): else: raise ValueError("Could not parse datatype: %s" % json_value) else: - return _all_complex_types[json_value["type"]].fromJson(json_value) + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) # Mapping Python types to Spark SQL DataType @@ -509,7 +590,18 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): - """Infer the DataType from obj""" + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ + if obj is None: + raise ValueError("Can not infer type for None") + + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + dataType = _type_mappings.get(type(obj)) if dataType is not None: return dataType() @@ -558,6 +650,93 @@ def _infer_schema(row): return StructType(fields) +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... 
StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + def _has_nulltype(dt): """ Return whether there is NullType in `dt` or not """ if isinstance(dt, StructType): @@ -818,11 +997,22 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... 
""" # all objects are nullable if obj is None: return + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + _type = type(dataType) assert _type in _acceptable_types, "unkown datatype: %s" % dataType @@ -897,6 +1087,8 @@ def _has_struct_or_date(dt): return _has_struct_or_date(dt.valueType) elif isinstance(dt, DateType): return True + elif isinstance(dt, UserDefinedType): + return True return False @@ -967,6 +1159,9 @@ def Dict(d): elif isinstance(dataType, DateType): return datetime.date + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) @@ -1244,6 +1439,10 @@ def applySchema(self, rdd, schema): for row in rows: _verify_type(row, schema) + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) @@ -1877,6 +2076,7 @@ def _test(): # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext + from pyspark.tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: @@ -1888,6 +2088,8 @@ def _test(): Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index dc5d1312fe994..6e7f24c6066a5 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -49,7 +49,8 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType from pyspark import shuffle _have_scipy = False @@ -694,8 +695,65 @@ def heavy_foo(x): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. 
+ """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + class SQLTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + def setUp(self): self.sqlCtx = SQLContext(self.sc) @@ -824,6 +882,39 @@ def test_convert_row_to_dict(self): row = self.sqlCtx.sql("select l[0].a AS la from test").first() self.assertEqual(1, row.asDict()["la"]) + def test_infer_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd = self.sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + srdd.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + srdd = self.sqlCtx.applySchema(rdd, schema) + point = srdd.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + srdd0.saveAsParquetFile(output_dir) + srdd1 = self.sqlCtx.parquetFile(output_dir) + point = srdd1.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + class InputFormatTests(ReusedPySparkTestCase): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index fa1786e74bb3e..18c96da2f87fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -34,320 +34,366 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi override def toString = s"scalaUDF(${children.mkString(",")})" + // scalastyle:off + /** This method has been generated by this script (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _) + val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _) s""" case $x => function.asInstanceOf[($anys) => Any]( - $evals) + $evals) """ - } + }.foreach(println) */ - // scalastyle:off override def eval(input: Row): Any = { val result = children.size match { case 0 => function.asInstanceOf[() => Any]() - case 1 => 
function.asInstanceOf[(Any) => Any](children(0).eval(input)) + case 1 => + function.asInstanceOf[(Any) => Any]( + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType)) + + case 2 => function.asInstanceOf[(Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType)) + + case 3 => function.asInstanceOf[(Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType)) + + case 4 => function.asInstanceOf[(Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType)) + + case 5 => function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType)) + + case 6 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType)) + + case 7 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType)) + + case 8 => 
       case 8 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType))
+
+
       case 9 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType))
+
+
       case 10 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType))
+
+
       case 11 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType))
+
+
       case 12 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType))
+
+
       case 13 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input),
-          children(12).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+          ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType))
+
+
       case 14 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input),
-          children(12).eval(input),
-          children(13).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+          ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+          ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType))
+
+
       case 15 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input),
-          children(12).eval(input),
-          children(13).eval(input),
-          children(14).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+          ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+          ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+          ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType))
+
+
       case 16 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input),
-          children(12).eval(input),
-          children(13).eval(input),
-          children(14).eval(input),
-          children(15).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+          ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+          ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+          ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+          ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType))
+
+
       case 17 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input),
-          children(12).eval(input),
-          children(13).eval(input),
-          children(14).eval(input),
-          children(15).eval(input),
-          children(16).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+          ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+          ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+          ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+          ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+          ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType))
+
+
       case 18 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input),
-          children(12).eval(input),
-          children(13).eval(input),
-          children(14).eval(input),
-          children(15).eval(input),
-          children(16).eval(input),
-          children(17).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+          ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+          ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+          ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+          ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+          ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
+          ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType))
+
+
       case 19 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input),
-          children(12).eval(input),
-          children(13).eval(input),
-          children(14).eval(input),
-          children(15).eval(input),
-          children(16).eval(input),
-          children(17).eval(input),
-          children(18).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+          ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+          ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+          ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+          ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+          ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
+          ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
+          ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType))
+
+
       case 20 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input),
-          children(12).eval(input),
-          children(13).eval(input),
-          children(14).eval(input),
-          children(15).eval(input),
-          children(16).eval(input),
-          children(17).eval(input),
-          children(18).eval(input),
-          children(19).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+          ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+          ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+          ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+          ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+          ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
+          ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
+          ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType),
+          ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType))
+
+
       case 21 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input),
-          children(12).eval(input),
-          children(13).eval(input),
-          children(14).eval(input),
-          children(15).eval(input),
-          children(16).eval(input),
-          children(17).eval(input),
-          children(18).eval(input),
-          children(19).eval(input),
-          children(20).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+          ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+          ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+          ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+          ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+          ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
+          ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
+          ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType),
+          ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType),
+          ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType))
+
+
       case 22 =>
         function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any](
-          children(0).eval(input),
-          children(1).eval(input),
-          children(2).eval(input),
-          children(3).eval(input),
-          children(4).eval(input),
-          children(5).eval(input),
-          children(6).eval(input),
-          children(7).eval(input),
-          children(8).eval(input),
-          children(9).eval(input),
-          children(10).eval(input),
-          children(11).eval(input),
-          children(12).eval(input),
-          children(13).eval(input),
-          children(14).eval(input),
-          children(15).eval(input),
-          children(16).eval(input),
-          children(17).eval(input),
-          children(18).eval(input),
-          children(19).eval(input),
-          children(20).eval(input),
-          children(21).eval(input))
+          ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType),
+          ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType),
+          ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType),
+          ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType),
+          ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType),
+          ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType),
+          ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType),
+          ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType),
+          ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType),
+          ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType),
+          ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType),
+          ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType),
+          ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType),
+          ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType),
+          ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType),
+          ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType),
+          ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType),
+          ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType),
+          ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType),
+          ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType),
+          ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType),
+          ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType))
+
     }
     // scalastyle:on
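The hunk above is mechanical: every argument a Scala UDF receives is now routed through ScalaReflection.convertToScala with the child expression's dataType. A simplified model of the conversion step, written as a hypothetical helper (the UDT branch here is an assumption inferred from the "UDTs and UDFs" test at the end of this patch, since eval() yields Catalyst-internal values while the user's function expects the user class):

    import org.apache.spark.sql.catalyst.types.{DataType, UserDefinedType}

    object ConvertArgSketch {
      // Hypothetical stand-in for the per-argument conversion above.
      def convertArg(value: Any, dataType: DataType): Any = dataType match {
        // A UDT column evaluates to its serialized sqlType form; hand the
        // user's function the deserialized user object instead.
        case udt: UserDefinedType[_] => udt.deserialize(value)
        // Other values pass through unchanged in this simplified model.
        case _ => value
      }
    }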
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index e1b5992a36e5f..5dd19dd12d8dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -71,6 +71,8 @@ object DataType {
     case JSortedObject(
         ("class", JString(udtClass)),
+        ("pyClass", _),
+        ("sqlType", _),
         ("type", JString("udt"))) =>
       Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
   }
@@ -593,6 +595,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
   /** Underlying storage type for this UDT */
   def sqlType: DataType
 
+  /** Paired Python UDT class, if one exists. */
+  def pyUDT: String = null
+
   /**
    * Convert the user type to a SQL datum
    *
@@ -606,7 +611,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
   override private[sql] def jsonValue: JValue = {
     ("type" -> "udt") ~
-      ("class" -> this.getClass.getName)
+      ("class" -> this.getClass.getName) ~
+      ("pyClass" -> pyUDT) ~
+      ("sqlType" -> sqlType.jsonValue)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 9e61d18f7e926..84eaf401f240c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.UserDefinedType
 import org.apache.spark.sql.execution.{SparkStrategies, _}
 import org.apache.spark.sql.json._
 import org.apache.spark.sql.parquet.ParquetRelation
@@ -483,6 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case ArrayType(_, _) => true
       case MapType(_, _, _) => true
       case StructType(_) => true
+      case udt: UserDefinedType[_] => needsConversion(udt.sqlType)
       case other => false
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 997669051ed07..a83cf5d441d1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -135,6 +135,8 @@ object EvaluatePython {
       case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
     }.asJava
 
+    case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
+
     case (dec: BigDecimal, dt: DecimalType) => dec.underlying()  // Pyrolite can handle BigDecimal
 
     // Pyrolite can handle Timestamp
@@ -177,6 +179,9 @@ object EvaluatePython {
     case (c: java.util.Calendar, TimestampType) =>
       new java.sql.Timestamp(c.getTime().getTime())
 
+    case (_, udt: UserDefinedType[_]) =>
+      fromJava(obj, udt.sqlType)
+
     case (c: Int, ByteType) => c.toByte
     case (c: Long, ByteType) => c.toByte
     case (c: Int, ShortType) => c.toShort
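Taken together, the EvaluatePython changes above make UDT columns cross the Python boundary in their sqlType form: toJava serializes the user object before pickling, and fromJava unwraps the UDT so the value is converted by its underlying sqlType, leaving reconstruction of the Python object to the pyUDT class recorded in the schema JSON. A rough sketch of the outbound half (simplified; the real code recurses through toJava, and the private[sql] visibility of the example classes is ignored):

    import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}

    object PythonSerdeSketch {
      def main(args: Array[String]): Unit = {
        val udt = new ExamplePointUDT
        // What gets pickled for Python is the serialized form, not the object:
        val forPython = udt.serialize(new ExamplePoint(1.0, 2.0)) // Seq(1.0, 2.0)
        println(forPython)
      }
    }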
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
new file mode 100644
index 0000000000000..b9569e96c0312
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * An example class to demonstrate UDT in Scala, Java, and Python.
+ * @param x x coordinate
+ * @param y y coordinate
+ */
+@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
+private[sql] class ExamplePoint(val x: Double, val y: Double)
+
+/**
+ * User-defined type for [[ExamplePoint]].
+ */
+private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
+
+  override def sqlType: DataType = ArrayType(DoubleType, false)
+
+  override def pyUDT: String = "pyspark.tests.ExamplePointUDT"
+
+  override def serialize(obj: Any): Seq[Double] = {
+    obj match {
+      case p: ExamplePoint =>
+        Seq(p.x, p.y)
+    }
+  }
+
+  override def deserialize(datum: Any): ExamplePoint = {
+    datum match {
+      case values: Seq[_] =>
+        val xy = values.asInstanceOf[Seq[Double]]
+        assert(xy.length == 2)
+        new ExamplePoint(xy(0), xy(1))
+      case values: util.ArrayList[_] =>
+        val xy = values.asInstanceOf[util.ArrayList[Double]].asScala
+        new ExamplePoint(xy(0), xy(1))
+    }
+  }
+
+  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 1bc15146f0fe8..3fa4a7c6481d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.types.UserDefinedType
 
-
 protected[sql] object DataTypeConversions {
 
   /**
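The suite below leans on the annotation-driven registration path: @SQLUserDefinedType ties a user class to its UDT so schema inference can pick the UDT up automatically. A minimal sketch of defining a new type this way (names are hypothetical, modeled on MyDenseVector from this suite):

    import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
    import org.apache.spark.sql.catalyst.types._

    // Hypothetical user type paired with its UDT via the annotation.
    @SQLUserDefinedType(udt = classOf[Point2DUDT])
    case class Point2D(x: Double, y: Double)

    class Point2DUDT extends UserDefinedType[Point2D] {
      override def sqlType: DataType = ArrayType(DoubleType, false)
      override def serialize(obj: Any): Seq[Double] = obj match {
        case Point2D(x, y) => Seq(x, y)
      }
      override def deserialize(datum: Any): Point2D = datum match {
        case values: Seq[_] =>
          val xy = values.asInstanceOf[Seq[Double]]
          require(xy.length == 2, "Point2D is stored as exactly two doubles")
          Point2D(xy(0), xy(1))
      }
      override def userClass: Class[Point2D] = classOf[Point2D]
    }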
pointsRDD.registerTempTable("points") + checkAnswer( + sql("SELECT testType(features) from points"), + Seq(Row(true), Row(true))) + } }