diff --git a/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala b/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala
index eefca193ec53e..fb4d458cd8a09 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala
@@ -34,24 +34,37 @@ class MLContext(self: SparkContext) {
    * where the feature indices are converted to zero-based.
    *
    * @param path file or directory path in any Hadoop-supported file system URI
-   * @param numFeatures number of features
-   * @param labelParser parser for labels, default: _.toDouble
+   * @param numFeatures number of features, which will be determined from the input
+   *                    if a non-positive value is given
+   * @param labelParser parser for labels, default: _.toDouble
    * @return labeled data stored as an RDD[LabeledPoint]
    */
   def libSVMFile(
       path: String,
       numFeatures: Int,
       labelParser: String => Double = _.toDouble): RDD[LabeledPoint] = {
-    self.textFile(path).map(_.trim).filter(!_.isEmpty).map { line =>
-      val items = line.split(' ')
+    val parsed = self.textFile(path).map(_.trim).filter(!_.isEmpty).map(_.split(' '))
+    // Determine number of features.
+    val d = if (numFeatures > 0) {
+      numFeatures
+    } else {
+      parsed.map { items =>
+        if (items.length > 1) {
+          items.last.split(':')(0).toInt
+        } else {
+          0
+        }
+      }.reduce(math.max)
+    }
+    parsed.map { items =>
       val label = labelParser(items.head)
-      val features = Vectors.sparse(numFeatures, items.tail.map { item =>
+      val (indices, values) = items.tail.map { item =>
         val indexAndValue = item.split(':')
         val index = indexAndValue(0).toInt - 1
         val value = indexAndValue(1).toDouble
         (index, value)
-      })
-      LabeledPoint(label, features)
+      }.unzip
+      LabeledPoint(label, Vectors.sparse(d, indices.toArray, values.toArray))
     }
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala
index 6762f8c479e98..743102b54fa9e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala
@@ -33,17 +33,26 @@ class MLContextSuite extends FunSuite with LocalSparkContext {
     val lines =
       """
        |1 1:1.0 3:2.0 5:3.0
+       |0
        |0 2:4.0 4:5.0 6:6.0
      """.stripMargin
     val tempDir = Files.createTempDir()
     val file = new File(tempDir.getPath, "part-00000")
     Files.write(lines, file, Charsets.US_ASCII)
-    val points = sc.libSVMFile(tempDir.toURI.toString, 6).collect()
-    assert(points.length === 2)
-    assert(points(0).label === 1.0)
-    assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
-    assert(points(1).label === 0.0)
-    assert(points(1).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
+
+    val pointsWithNumFeatures = sc.libSVMFile(tempDir.toURI.toString, 6).collect()
+    val pointsWithoutNumFeatures = sc.libSVMFile(tempDir.toURI.toString, 0).collect()
+
+    for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
+      assert(points.length === 3)
+      assert(points(0).label === 1.0)
+      assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+      assert(points(1).label === 0.0)
+      assert(points(1).features === Vectors.sparse(6, Seq()))
+      assert(points(2).label === 0.0)
+      assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
+    }
+
     try {
       file.delete()
       tempDir.delete()
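
Usage sketch (not part of the patch): with this change, passing a non-positive numFeatures asks libSVMFile to infer the dimension as the maximum feature index seen in the input. A minimal example against the signature above, assuming a running SparkContext named sc and a hypothetical LIBSVM file at data/sample.libsvm:

    val ml = new MLContext(sc)

    // Explicit dimension, as before.
    val fixed = ml.libSVMFile("data/sample.libsvm", 6)

    // Non-positive numFeatures: the dimension is computed from the data.
    val inferred = ml.libSVMFile("data/sample.libsvm", 0)

    // Custom label parser, e.g. mapping {-1, +1} labels to {0.0, 1.0}.
    val binary = ml.libSVMFile("data/sample.libsvm", 0,
      label => if (label.toDouble > 0) 1.0 else 0.0)

Note that inferring the dimension runs a separate reduce job over the parsed input before the final map, and parsed is not cached, so callers who already know numFeatures should still pass it explicitly.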