Merge remote-tracking branch 'upstream/master' into spark-1403
Bharath Bhushan committed Apr 5, 2014
2 parents 04b9662 + 60e18ce commit b3a053f
Showing 20 changed files with 241 additions and 50 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -395,7 +395,7 @@ class SparkContext(
* (a-hdfs-path/part-nnnnn, its content)
* }}}
*
* @note Small files are perferred, large file is also allowable, but may cause bad performance.
* @note Small files are preferred, as each file will be loaded fully in memory.
*/
def wholeTextFiles(path: String): RDD[(String, String)] = {
newAPIHadoopFile(
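A hedged usage sketch of the documented behavior (the HDFS path and the size computation are illustrative, not part of this commit); each record pairs a file path with that file's entire content, which is why the note warns that every file must fit in memory:

// Assuming an active SparkContext `sc`; the path is illustrative.
val files = sc.wholeTextFiles("hdfs://a-hdfs-path")   // RDD[(String, String)]
// Each value is an entire file, so every file must fit in memory on a single task.
val sizes = files.map { case (path, content) => (path, content.length) }
sizes.collect().foreach { case (path, size) => println(path + " -> " + size) }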
@@ -177,7 +177,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* (a-hdfs-path/part-nnnnn, its content)
* }}}
*
* @note Small files are perferred, large file is also allowable, but may cause bad performance.
* @note Small files are preferred, as each file will be loaded fully in memory.
*/
def wholeTextFiles(path: String): JavaPairRDD[String, String] =
new JavaPairRDD(sc.wholeTextFiles(path))
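For completeness, a minimal sketch of the same call through the Java-facing wrapper (the path is again illustrative); it returns a JavaPairRDD keyed by file path:

// Assuming an existing SparkContext `sc`.
val jsc = new JavaSparkContext(sc)
val pairs = jsc.wholeTextFiles("hdfs://a-hdfs-path")  // JavaPairRDD[String, String]
println(pairs.count())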
@@ -19,6 +19,7 @@ package org.apache.spark.api.python

import java.io._
import java.net._
import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
@@ -206,6 +207,7 @@ private object SpecialLengths {
}

private[spark] object PythonRDD {
val UTF8 = Charset.forName("UTF-8")

def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
@@ -266,7 +268,7 @@ private[spark] object PythonRDD {
}

def writeUTF(str: String, dataOut: DataOutputStream) {
val bytes = str.getBytes("UTF-8")
val bytes = str.getBytes(UTF8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
@@ -286,7 +288,7 @@ private[spark] object PythonRDD {

private
class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
}

/**
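The writeUTF helper above frames each string as a 4-byte length followed by its UTF-8 bytes. As a hedged sketch (readUTF is an illustrative helper, not part of this commit), the matching read side would look like:

import java.io.DataInputStream
import java.nio.charset.Charset

def readUTF(dataIn: DataInputStream): String = {
  val length = dataIn.readInt()        // the length prefix written by writeUTF
  val bytes = new Array[Byte](length)
  dataIn.readFully(bytes)              // the UTF-8 payload
  new String(bytes, Charset.forName("UTF-8"))
}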
64 changes: 61 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -17,6 +17,9 @@

package org.apache.spark.rdd

import java.io.File
import java.io.FilenameFilter
import java.io.IOException
import java.io.PrintWriter
import java.util.StringTokenizer

@@ -27,6 +30,7 @@ import scala.io.Source
import scala.reflect.ClassTag

import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.util.Utils


/**
@@ -38,7 +42,8 @@ class PipedRDD[T: ClassTag](
command: Seq[String],
envVars: Map[String, String],
printPipeContext: (String => Unit) => Unit,
printRDDElement: (T, String => Unit) => Unit)
printRDDElement: (T, String => Unit) => Unit,
separateWorkingDir: Boolean)
extends RDD[String](prev) {

// Similar to Runtime.exec(), if we are given a single string, split it into words
@@ -48,12 +53,24 @@ class PipedRDD[T: ClassTag](
command: String,
envVars: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
printRDDElement: (T, String => Unit) => Unit = null) =
this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement)
printRDDElement: (T, String => Unit) => Unit = null,
separateWorkingDir: Boolean = false) =
this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement,
separateWorkingDir)


override def getPartitions: Array[Partition] = firstParent[T].partitions

/**
* A FilenameFilter that accepts anything that isn't equal to the name passed in.
* @param filterName name of the file or directory to leave out
*/
class NotEqualsFileNameFilter(filterName: String) extends FilenameFilter {
def accept(dir: File, name: String): Boolean = {
!name.equals(filterName)
}
}

override def compute(split: Partition, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
// Add the environmental variables to the process.
@@ -67,6 +84,38 @@ class PipedRDD[T: ClassTag](
currentEnvVars.putAll(hadoopSplit.getPipeEnvVars())
}

// When the spark.worker.separated.working.directory option is turned on, each
// task will be run in a separate working directory. This should help resolve
// file access conflicts between tasks.
val taskDirectory = "./tasks/" + java.util.UUID.randomUUID.toString
var workInTaskDirectory = false
logDebug("taskDirectory = " + taskDirectory)
if (separateWorkingDir) {
val currentDir = new File(".")
logDebug("currentDir = " + currentDir.getAbsolutePath())
val taskDirFile = new File(taskDirectory)
taskDirFile.mkdirs()

try {
val tasksDirFilter = new NotEqualsFileNameFilter("tasks")

// Need to add symlinks to jars, files, and directories. On Yarn we could have
// directories and other files not known to the SparkContext that were added via the
// Hadoop distributed cache. We also don't want to symlink to the /tasks directories we
// are creating here.
for (file <- currentDir.list(tasksDirFilter)) {
val fileWithDir = new File(currentDir, file)
Utils.symlink(new File(fileWithDir.getAbsolutePath()),
new File(taskDirectory + "/" + fileWithDir.getName()))
}
pb.directory(taskDirFile)
workInTaskDirectory = true
} catch {
case e: Exception => logError("Unable to set up task working directory: " + e.getMessage +
" (" + taskDirectory + ")")
}
}

val proc = pb.start()
val env = SparkEnv.get

@@ -112,6 +161,15 @@ class PipedRDD[T: ClassTag](
if (exitStatus != 0) {
throw new Exception("Subprocess exited with status " + exitStatus)
}

// Clean up the task working directory if one was used
if (workInTaskDirectory) {
scala.util.control.Exception.ignoring(classOf[IOException]) {
Utils.deleteRecursively(new File(taskDirectory))
}
logDebug("Removed task working directory " + taskDirectory)
}

false
}
}
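Condensing the change above: each task gets a uniquely named directory under ./tasks, the existing working-directory entries are symlinked into it (skipping the tasks tree itself), the child process is started there, and the directory is deleted afterwards. A simplified sketch of that flow using the Utils helpers added in this commit, with error handling and logging omitted and "pwd" standing in for the pipe command:

import java.io.{File, FilenameFilter}
import java.util.UUID

val taskDirFile = new File("./tasks/" + UUID.randomUUID.toString)
taskDirFile.mkdirs()
// Symlink everything in the current directory except the tasks/ tree itself.
val filter = new FilenameFilter {
  def accept(dir: File, name: String): Boolean = name != "tasks"
}
for (name <- new File(".").list(filter)) {
  Utils.symlink(new File(name).getAbsoluteFile, new File(taskDirFile, name))
}
val pb = new ProcessBuilder("pwd")
pb.directory(taskDirFile)                 // the subprocess runs inside the task directory
val exitStatus = pb.start().waitFor()
Utils.deleteRecursively(taskDirFile)      // remove the scratch directory when done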
7 changes: 5 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -481,16 +481,19 @@ abstract class RDD[T: ClassTag](
* instead of constructing a huge String to concat all the elements:
* def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
* for (e <- record._2){f(e)}
* @param separateWorkingDir Use separate working directories for each task.
* @return the result RDD
*/
def pipe(
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
printRDDElement: (T, String => Unit) => Unit = null,
separateWorkingDir: Boolean = false): RDD[String] = {
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
if (printRDDElement ne null) sc.clean(printRDDElement) else null,
separateWorkingDir)
}

/**
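A usage sketch of the new flag, mirroring the test added in PipedRDDSuite below; each partition's subprocess gets its own ./tasks/<uuid> working directory, so commands that write local scratch files no longer collide:

val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
// With separateWorkingDir = true, `pwd` reports a per-task directory under ./tasks/.
val piped = nums.pipe(Seq("pwd"), separateWorkingDir = true)
piped.collect().foreach(println)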
45 changes: 44 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -26,6 +26,7 @@ import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.SortedSet
import scala.io.Source
import scala.reflect.ClassTag

@@ -43,6 +44,8 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream,
*/
private[spark] object Utils extends Logging {

val osName = System.getProperty("os.name")

/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -521,9 +524,10 @@

/**
* Delete a file or directory and its contents recursively.
* Don't follow directories if they are symlinks.
*/
def deleteRecursively(file: File) {
if (file.isDirectory) {
if ((file.isDirectory) && !isSymlink(file)) {
for (child <- listFilesSafely(file)) {
deleteRecursively(child)
}
@@ -536,6 +540,25 @@
}
}

/**
* Check to see if file is a symbolic link.
*/
def isSymlink(file: File): Boolean = {
if (file == null) throw new NullPointerException("File must not be null")
if (osName.startsWith("Windows")) return false
val fileInCanonicalDir = if (file.getParent() == null) {
file
} else {
new File(file.getParentFile().getCanonicalFile(), file.getName())
}

!fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile())
}

/**
* Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
*/
@@ -898,6 +921,26 @@
count
}

/**
* Creates a symlink. Note that JDK 1.7 has Files.createSymbolicLink, but it is not used here
* in order to support JDK 1.6. On Windows this does a copy; everywhere else it uses "ln -sf".
* @param src absolute path to the source
* @param dst relative path for the destination
*/
def symlink(src: File, dst: File) {
if (!src.isAbsolute()) {
throw new IOException("Source must be absolute")
}
if (dst.isAbsolute()) {
throw new IOException("Destination must be relative")
}
val linkCmd = if (osName.startsWith("Windows")) "copy" else "ln -sf"
import scala.sys.process._
(linkCmd + " " + src.getAbsolutePath() + " " + dst.getPath()) lines_! ProcessLogger(line =>
logInfo(line))
}


/** Return the class name of the given object, removing all dollar signs */
def getFormattedClassName(obj: AnyRef) = {
obj.getClass.getSimpleName.replace("$", "")
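The symlink doc comment above notes that JDK 1.7's Files.createSymbolicLink could replace the shell-out once JDK 1.6 support is dropped; a hedged sketch of that alternative (symlinkNio is an illustrative name, not part of this commit):

import java.io.File
import java.nio.file.Files

// Roughly equivalent to `ln -sf src dst` on platforms that support symlinks.
def symlinkNio(src: File, dst: File): Unit = {
  Files.deleteIfExists(dst.toPath)                  // mimic -f (force)
  Files.createSymbolicLink(dst.toPath, src.toPath)  // link first, then target
}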
28 changes: 27 additions & 1 deletion core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
@@ -17,8 +17,11 @@

package org.apache.spark

import org.scalatest.FunSuite
import java.io.File

import com.google.common.io.Files

import org.scalatest.FunSuite

import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition}
import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit}
@@ -126,6 +129,29 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
}

test("basic pipe with separate working directory") {
if (testCommandAvailable("cat")) {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("cat"), separateWorkingDir = true)
val c = piped.collect()
assert(c.size === 4)
assert(c(0) === "1")
assert(c(1) === "2")
assert(c(2) === "3")
assert(c(3) === "4")
val pipedPwd = nums.pipe(Seq("pwd"), separateWorkingDir = true)
val collectPwd = pipedPwd.collect()
assert(collectPwd(0).contains("tasks/"))
val pipedLs = nums.pipe(Seq("ls"), separateWorkingDir = true).collect()
// make sure symlinks were created
assert(pipedLs.length > 0)
// clean up top level tasks directory
new File("tasks").delete()
} else {
assert(true)
}
}

test("test pipe exports map_input_file") {
testExportInputFile("map_input_file")
}
@@ -55,8 +55,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Shou
}

test("test executor id to summary") {
val sc = new SparkContext("local", "test")
val listener = new JobProgressListener(sc.conf)
val conf = new SparkConf()
val listener = new JobProgressListener(conf)
val taskMetrics = new TaskMetrics()
val shuffleReadMetrics = new ShuffleReadMetrics()

44 changes: 42 additions & 2 deletions python/pyspark/context.py
@@ -28,7 +28,8 @@
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
@@ -257,6 +258,45 @@ def textFile(self, name, minSplits=None):
return RDD(self._jsc.textFile(name, minSplits), self,
UTF8Deserializer())

def wholeTextFiles(self, path):
"""
Read a directory of text files from HDFS, a local file system
(available on all nodes), or any Hadoop-supported file system
URI. Each file is read as a single record and returned in a
key-value pair, where the key is the path of each file, the
value is the content of each file.
For example, if you have the following files::

  hdfs://a-hdfs-path/part-00000
  hdfs://a-hdfs-path/part-00001
  ...
  hdfs://a-hdfs-path/part-nnnnn

Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")},
then C{rdd} contains::

  (a-hdfs-path/part-00000, its content)
  (a-hdfs-path/part-00001, its content)
  ...
  (a-hdfs-path/part-nnnnn, its content)

NOTE: Small files are preferred, as each file will be loaded
fully in memory.

>>> dirPath = os.path.join(tempdir, "files")
>>> os.mkdir(dirPath)
>>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
...    file1.write("1")
>>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
...    file2.write("2")
>>> textFiles = sc.wholeTextFiles(dirPath)
>>> sorted(textFiles.collect())
[(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
"""
return RDD(self._jsc.wholeTextFiles(path), self,
PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))

def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
return RDD(jrdd, self, input_deserializer)
@@ -425,7 +465,7 @@ def _test():
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
globs['tempdir'] = tempfile.mkdtemp()
atexit.register(lambda: shutil.rmtree(globs['tempdir']))
(failure_count, test_count) = doctest.testmod(globs=globs)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
2 changes: 1 addition & 1 deletion python/pyspark/serializers.py
@@ -290,7 +290,7 @@ class MarshalSerializer(FramedSerializer):

class UTF8Deserializer(Serializer):
"""
Deserializes streams written by getBytes.
Deserializes streams written by String.getBytes.
"""

def loads(self, stream):
