diff --git a/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala b/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala new file mode 100644 index 0000000000000..d6c6910c0ed84 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib + +import org.apache.spark.mllib.input.WholeTextFileInputFormat +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext + +/** + * Extra functions available on SparkContext of mllib through an implicit conversion. Import + * `org.apache.spark.mllib.MLContext._` at the top of your program to use these functions. + */ +class MLContext(self: SparkContext) { + + /** + * Read a directory of text files from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI. Each file is read as a single record and returned in a + * key-value pair, where the key is the path of each file, the value is the content of each file. + * + *
For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do `val rdd = mlContext.wholeTextFile("hdfs://a-hdfs-path")`, + * + *
then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + */ + def wholeTextFile(path: String): RDD[(String, String)] = { + self.newAPIHadoopFile( + path, + classOf[WholeTextFileInputFormat], + classOf[String], + classOf[String]) + } +} + +/** + * The MLContext object contains a number of implicit conversions and parameters for use with + * various mllib features. + */ +object MLContext { + implicit def sparkContextToMLContext(sc: SparkContext) = new MLContext(sc) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala b/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala new file mode 100644 index 0000000000000..28133618e3c10 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.input + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat +import org.apache.hadoop.mapreduce.RecordReader +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader +import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit + +/** + * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for + * reading whole text files. Each file is read as key-value pair, where the key is the file path and + * the value is the entire content of file. + */ + +private[mllib] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] { + override protected def isSplitable(context: JobContext, file: Path): Boolean = false + + override def createRecordReader( + split: InputSplit, + context: TaskAttemptContext): RecordReader[String, String] = { + + new CombineFileRecordReader[String, String]( + split.asInstanceOf[CombineFileSplit], + context, + classOf[WholeTextFileRecordReader]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala b/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala new file mode 100644 index 0000000000000..1fc668810332b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.input + +import com.google.common.io.{ByteStreams, Closeables} + +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit +import org.apache.hadoop.mapreduce.RecordReader +import org.apache.hadoop.mapreduce.TaskAttemptContext + +/** + * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file + * out in a key-value pair, where the key is the file path and the value is the entire content of + * the file. + */ +private[mllib] class WholeTextFileRecordReader( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends RecordReader[String, String] { + + private val path = split.getPath(index) + private val fs = path.getFileSystem(context.getConfiguration) + + // True means the current file has been processed, then skip it. + private var processed = false + + private val key = path.toString + private var value: String = null + + override def initialize(split: InputSplit, context: TaskAttemptContext) = {} + + override def close() = {} + + override def getProgress = if (processed) 1.0f else 0.0f + + override def getCurrentKey = key + + override def getCurrentValue = value + + override def nextKeyValue = { + if (!processed) { + val fileIn = fs.open(path) + val innerBuffer = ByteStreams.toByteArray(fileIn) + + value = new Text(innerBuffer).toString + Closeables.close(fileIn, false) + + processed = true + true + } else { + false + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala new file mode 100644 index 0000000000000..c79355fd26c6f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.input + +import java.io.DataOutputStream +import java.io.File +import java.io.FileOutputStream + +import scala.collection.immutable.IndexedSeq + +import com.google.common.io.Files + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.hadoop.io.Text + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.MLContext._ + +/** + * Tests the correctness of + * [[org.apache.spark.mllib.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary + * directory is created as fake input. Temporal storage would be deleted in the end. + */ +class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { + private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + } + + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte]) = { + val out = new DataOutputStream(new FileOutputStream(s"${inputDir.toString}/$fileName")) + out.write(contents, 0, contents.length) + out.close() + } + + /** + * This code will test the behaviors of WholeTextFileRecordReader based on local disk. There are + * three aspects to check: + * 1) Whether all files are read; + * 2) Whether paths are read correctly; + * 3) Does the contents be the same. + */ + test("Correctness of WholeTextFileRecordReader.") { + + val dir = Files.createTempDir() + println(s"Local disk address is ${dir.toString}.") + + WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents) + } + + val res = sc.wholeTextFile(dir.toString).collect() + + assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size, + "Number of files read out does not fit with the actual value.") + + for ((filename, contents) <- res) { + val shortName = filename.split('/').last + assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName), + s"Missing file name $filename.") + assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString, + s"file $filename contents can not match.") + } + + dir.delete() + } +} + +/** + * Files to be tested are defined here. + */ +object WholeTextFileRecordReaderSuite { + private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte) + + private val fileNames = Array("part-00000", "part-00001", "part-00002") + private val fileLengths = Array(10, 100, 1000) + + private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) => + filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray + }.toMap +}