diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala index 829c058ec03c6..c6ac527660709 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala @@ -93,11 +93,18 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable { def minOption(cmp: (Vector, Vector) => Boolean) = maxMinOption(!cmp(_, _)) - def rowShrink(): RDD[Vector] = { + def rowShrink(): RDD[Vector] = self.filter(x => x.toArray.sum != 0) + + def colShrink(): RDD[Vector] = { + val means = self.colMeans() + self.map( v => Vectors.dense(v.toArray.zip(means.toArray).filter{ case (x, m) => m != 0.0 }.map(_._1))) + } + + def colShrinkWithFilter(): (RDD[Vector], RDD[Boolean]) = { ??? } - def colShrink(): RDD[Vector] = { + def rowShrinkWithFilter(): (RDD[Vector], RDD[Boolean]) = { ??? } }