Commit
Merge branch 'master' of https://github.com/apache/spark into SPARK-1413
witgo committed Apr 10, 2014
2 parents 0d5f819 + 8ca3b2b commit 5e35d87
Showing 8 changed files with 218 additions and 14 deletions.
16 changes: 11 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1002,7 +1002,9 @@ class SparkContext(config: SparkConf) extends Logging {
require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
// There's no need to check this function for serializability,
// since it will be run right away.
val cleanedFunc = clean(func, false)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1135,14 +1137,18 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelAllJobs() {
dagScheduler.cancelAllJobs()
}

/**
* Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
*
* @param f closure to be cleaned and optionally serialized
* @param captureNow whether or not to serialize this closure and capture any free
* variables immediately; defaults to true. If this is set and f is not serializable,
* it will raise an exception.
*/
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
f
private[spark] def clean[F <: AnyRef : ClassTag](f: F, captureNow: Boolean = true): F = {
ClosureCleaner.clean(f, captureNow)
}

/**
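Not part of this commit: a minimal sketch of the two modes of the new clean signature. Because clean is private[spark], the sketch assumes it is compiled under an org.apache.spark sub-package; the package and object names are hypothetical.

package org.apache.spark.sketch  // hypothetical sub-package, only so the private[spark] clean is visible

import org.apache.spark.{SparkConf, SparkContext, SparkException}

object CleanModesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("clean-modes"))
    try {
      class Unserializable                 // deliberately not Serializable
      val box = new Unserializable

      // Default captureNow = true: the closure is serialized (and deserialized back)
      // right here, so a non-serializable free variable fails at definition time.
      try {
        sc.clean((x: Int) => (x, box))
      } catch {
        case e: SparkException => println("caught proactively: " + e.getMessage)
      }

      // captureNow = false: only reference cleaning is done; meant for closures
      // that run right away on the driver, as in runJob above.
      val localOnly = sc.clean((it: Iterator[Int]) => it.sum, false)
      println(localOnly(Iterator(1, 2, 3)))   // prints 6
    } finally {
      sc.stop()
    }
  }
}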
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -660,14 +660,16 @@ abstract class RDD[T: ClassTag](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}

/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
sc.runJob(this, (iter: Iterator[T]) => f(iter))
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}

/**
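Not part of this commit: a small driver program sketching the user-visible effect of cleaning the foreach closure, mirroring the new FailureSuite case further down. The NonSerializableThing and ForeachCleaningSketch names are hypothetical.

import org.apache.spark.{SparkConf, SparkContext, SparkException}

// Illustrative stand-in for FailureSuite's NonSerializable: no Serializable marker.
class NonSerializableThing

object ForeachCleaningSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("foreach-cleaning"))
    try {
      val bad = new NonSerializableThing
      try {
        // foreach now passes its closure through sc.clean, so the non-serializable
        // free variable surfaces as a SparkException when the action is invoked.
        sc.parallelize(1 to 10, 2).foreach(x => println(bad))
      } catch {
        case e: SparkException => println("rejected: " + e.getMessage)
      }
    } finally {
      sc.stop()
    }
  }
}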
21 changes: 20 additions & 1 deletion core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,10 +22,14 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.mutable.Map
import scala.collection.mutable.Set

import scala.reflect.ClassTag

import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._

import org.apache.spark.Logging
import org.apache.spark.SparkEnv
import org.apache.spark.SparkException

private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
@@ -101,7 +105,7 @@ private[spark] object ClosureCleaner extends Logging {
}
}

def clean(func: AnyRef) {
def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true): F = {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
@@ -150,6 +154,21 @@ private[spark] object ClosureCleaner extends Logging {
field.setAccessible(true)
field.set(func, outer)
}

if (captureNow) {
cloneViaSerializing(func)
} else {
func
}
}

private def cloneViaSerializing[T: ClassTag](func: T): T = {
try {
val serializer = SparkEnv.get.closureSerializer.newInstance()
serializer.deserialize[T](serializer.serialize[T](func))
} catch {
case ex: Exception => throw new SparkException("Task not serializable: " + ex.toString)
}
}

private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
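Not part of this commit: a self-contained sketch of the idea behind cloneViaSerializing, using plain JDK serialization in place of SparkEnv's closure serializer. Round-tripping a closure produces a deep snapshot, so later mutation of a captured array is not visible through the clone. All names here are hypothetical.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object CloneBySerializationSketch {
  // Serialize and immediately deserialize a value, producing a deep snapshot of it.
  def roundTrip[T <: AnyRef](value: T): T = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(value)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    in.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val data = Array(1, 2, 3)
    val f: Int => Int = i => data(i)   // closure capturing `data`
    val captured = roundTrip(f)        // "capture now": clone the closure and its state
    data(0) = 99                       // mutate the original afterwards
    println(f(0))                      // 99 -- the live closure sees the mutation
    println(captured(0))               // 1  -- the snapshot does not
  }
}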
17 changes: 16 additions & 1 deletion core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -107,7 +107,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}

test("failure because task closure is not serializable") {
test("failure because closure in final-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

@@ -118,13 +118,27 @@ class FailureSuite extends FunSuite with LocalSparkContext {
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))

FailureSuiteState.clear()
}

test("failure because closure in early-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
}
assert(thrown1.getClass === classOf[SparkException])
assert(thrown1.getMessage.contains("NotSerializableException"))

FailureSuiteState.clear()
}

test("failure because closure in foreach task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
@@ -135,5 +149,6 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}


// TODO: Need to add tests with shuffle fetch failures.
}
core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer;

import java.io.NotSerializableException

import org.scalatest.FunSuite

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkException
import org.apache.spark.SharedSparkContext

/* A trivial (but unserializable) container for trivial functions */
class UnserializableClass {
def op[T](x: T) = x.toString

def pred[T](x: T) = x.toString.length % 2 == 0
}

class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {

def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)

test("throws expected serialization exceptions on actions") {
val (data, uc) = fixture

val ex = intercept[SparkException] {
data.map(uc.op(_)).count
}

assert(ex.getMessage.matches(".*Task not serializable.*"))
}

// There is probably a cleaner way to eliminate boilerplate here, but we're
// iterating over a map from transformation names to functions that perform that
// transformation on a given RDD, creating one test case for each

for (transformation <-
Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _, "mapWith" -> mapWith _,
"mapPartitions" -> mapPartitions _, "mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
"mapPartitionsWithContext" -> mapPartitionsWithContext _, "filterWith" -> filterWith _)) {
val (name, xf) = transformation

test(s"$name transformations throw proactive serialization exceptions") {
val (data, uc) = fixture

val ex = intercept[SparkException] {
xf(data, uc)
}

assert(ex.getMessage.matches(".*Task not serializable.*"), s"RDD.$name doesn't proactively throw NotSerializableException")
}
}

def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.map(y => uc.op(y))

def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapWith(x => x.toString)((x,y) => x + uc.op(y))

def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.flatMap(y=>Seq(uc.op(y)))

def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filter(y=>uc.pred(y))

def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filterWith(x => x.toString)((x,y) => uc.pred(y))

def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitions(_.map(y => uc.op(y)))

def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))

def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))

}
core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -50,6 +50,27 @@ class ClosureCleanerSuite extends FunSuite {
val obj = new TestClassWithNesting(1)
assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
}

test("capturing free variables in closures at RDD definition") {
val obj = new TestCaptureVarClass()
val (ones, onesPlusZeroes) = obj.run()

assert(ones === onesPlusZeroes)
}

test("capturing free variable fields in closures at RDD definition") {
val obj = new TestCaptureFieldClass()
val (ones, onesPlusZeroes) = obj.run()

assert(ones === onesPlusZeroes)
}

test("capturing arrays in closures at RDD definition") {
val obj = new TestCaptureArrayEltClass()
val (observed, expected) = obj.run()

assert(observed === expected)
}
}

// A non-serializable class we create in closures to make sure that we aren't
@@ -143,3 +164,50 @@ class TestClassWithNesting(val y: Int) extends Serializable {
}
}
}

class TestCaptureFieldClass extends Serializable {
class ZeroBox extends Serializable {
var zero = 0
}

def run(): (Int, Int) = {
val zb = new ZeroBox

withSpark(new SparkContext("local", "test")) {sc =>
val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
val onesPlusZeroes = ones.map(_ + zb.zero)

zb.zero = 5

(ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
}
}
}

class TestCaptureArrayEltClass extends Serializable {
def run(): (Int, Int) = {
withSpark(new SparkContext("local", "test")) {sc =>
val rdd = sc.parallelize(1 to 10)
val data = Array(1, 2, 3)
val expected = data(0)
val mapped = rdd.map(x => data(0))
data(0) = 4
(mapped.first, expected)
}
}
}

class TestCaptureVarClass extends Serializable {
def run(): (Int, Int) = {
var zero = 0

withSpark(new SparkContext("local", "test")) {sc =>
val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
val onesPlusZeroes = ones.map(_ + zero)

zero = 5

(ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
}
}
}
graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -62,7 +62,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
assert( graph.edges.count() === rawEdges.size )
// Vertices not explicitly provided but referenced by edges should be created automatically
assert( graph.vertices.count() === 100)
graph.triplets.map { et =>
graph.triplets.collect.map { et =>
assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr))
assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr))
}
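Not part of this commit: a plausible reading of this one-line test fix is that map on an RDD is a lazy transformation, so assertions inside it never execute without an action (and, with the proactive check, its closure would now also have to be serializable), whereas collecting first runs the checks on the driver over a plain Scala collection. A hypothetical sketch of the difference:

import org.apache.spark.{SparkConf, SparkContext}

object LazyAssertSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("lazy-assert"))
    try {
      val rdd = sc.parallelize(1 to 10)
      rdd.map(x => assert(x > 0))              // lazy transformation: nothing is verified here
      rdd.collect.foreach(x => assert(x > 0))  // collect first, then assert on the driver
      println("assertions actually ran")
    } finally {
      sc.stop()
    }
  }
}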
streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -539,15 +539,15 @@ abstract class DStream[T: ClassTag] (
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false))
}

/**
* Return a new DStream in which each RDD is generated by applying a function
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
val cleanedF = context.sparkContext.clean(transformFunc)
val cleanedF = context.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 1)
cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -562,7 +562,7 @@ abstract class DStream[T: ClassTag] (
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
): DStream[V] = {
val cleanedF = ssc.sparkContext.clean(transformFunc)
val cleanedF = ssc.sparkContext.clean(transformFunc, false)
transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
}

@@ -573,7 +573,7 @@ abstract class DStream[T: ClassTag] (
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
): DStream[V] = {
val cleanedF = ssc.sparkContext.clean(transformFunc)
val cleanedF = ssc.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 2)
val rdd1 = rdds(0).asInstanceOf[RDD[T]]
