diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 8890662d99b52..33b5c5fa3f465 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} +import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg._ @@ -244,11 +245,30 @@ class IndexedRowMatrix @Since("1.0.0") ( */ @Since("1.0.0") def multiply(B: Matrix): IndexedRowMatrix = { - val mat = toRowMatrix().multiply(B) - val indexedRows = rows.map(_.index).zip(mat.rows).map { case (i, v) => - IndexedRow(i, v) + val n = numCols().toInt + val k = B.numCols + require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}") + + require(B.isInstanceOf[DenseMatrix], + s"Only support dense matrix at this time but found ${B.getClass.getName}.") + + val Bb = rows.context.broadcast(B.asBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray) + val AB = rows.mapPartitions { iter => + val Bi = Bb.value + iter.map { row => + val index = row.index + val vector = row.vector + val v = BDV.zeros[Double](k) + var i = 0 + while (i < k) { + v(i) = vector.asBreeze.dot(new BDV(Bi, i * n, 1, n)) + i += 1 + } + IndexedRow(index, Vectors.fromBreeze(v)) + } } - new IndexedRowMatrix(indexedRows, nRows, B.numCols) + + new IndexedRowMatrix(AB, 0L, B.numCols) } /** @@ -274,3 +294,4 @@ class IndexedRowMatrix @Since("1.0.0") ( mat } } +