diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 1c4ae0b0ccafd..66f1a8066525b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -509,15 +509,24 @@ class RowMatrix( */ def multiply(B: Matrix): RowMatrix = { val n = numCols().toInt + val k = B.numCols require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}") require(B.isInstanceOf[DenseMatrix], s"Only support dense matrix at this time but found ${B.getClass.getName}.") - val Bb = rows.context.broadcast(B) + val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray) val AB = rows.mapPartitions({ iter => - val Bi = Bb.value.toBreeze.asInstanceOf[BDM[Double]] - iter.map(v => Vectors.fromBreeze(Bi.t * v.toBreeze)) + val Bi = Bb.value + iter.map(row => { + val v = BDV.zeros[Double](k) + var i = 0 + while (i < k) { + v(i) = row.toBreeze.dot(new BDV(Bi, i * n, 1, n)) + i += 1 + } + Vectors.fromBreeze(v) + }) }, preservesPartitioning = true) new RowMatrix(AB, nRows, B.numCols)