Skip to content

Commit

Permalink
[SPARK-9175] [MLLIB] BLAS.gemm fails to update matrix C when alpha==0…
Browse files Browse the repository at this point in the history
… and beta!=1

Fix BLAS.gemm to update matrix C when alpha==0 and beta!=1
Also include unit tests to verify the fix.

mengxr brkyvz

Author: Meihua Wu <[email protected]>

Closes #7503 from rotationsymmetry/fix_BLAS_gemm and squashes the following commits:

fce199c [Meihua Wu] Fix BLAS.gemm to update C when alpha==0 and beta!=1

(cherry picked from commit ff3c72d)
Signed-off-by: Xiangrui Meng <[email protected]>
  • Loading branch information
rotationsymmetry authored and mengxr committed Jul 21, 2015
1 parent 429eedd commit 1c38d42
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ private[spark] object BLAS extends Serializable with Logging {
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
if (alpha == 0.0) {
logDebug("gemm: alpha is equal to 0. Returning C.")
if (alpha == 0.0 && beta == 1.0) {
logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
} else {
A match {
case sparse: SparseMatrix =>
Expand Down
15 changes: 15 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,14 @@ class BLASSuite extends FunSuite {
val C6 = C1.copy
val C7 = C1.copy
val C8 = C1.copy
val C13 = C1.copy
val C14 = C1.copy
val C15 = C1.copy
val C16 = C1.copy
val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
val expected5 = C1.copy

gemm(1.0, dA, B, 2.0, C1)
gemm(1.0, sA, B, 2.0, C2)
Expand Down Expand Up @@ -181,6 +187,15 @@ class BLASSuite extends FunSuite {
assert(C6 ~== expected2 absTol 1e-15)
assert(C7 ~== expected3 absTol 1e-15)
assert(C8 ~== expected3 absTol 1e-15)

gemm(0, dA, B, 5, C13)
gemm(0, sA, B, 5, C14)
gemm(0, dA, B, 1, C15)
gemm(0, sA, B, 1, C16)
assert(C13 ~== expected4 absTol 1e-15)
assert(C14 ~== expected4 absTol 1e-15)
assert(C15 ~== expected5 absTol 1e-15)
assert(C16 ~== expected5 absTol 1e-15)
}

test("gemv") {
Expand Down

0 comments on commit 1c38d42

Please sign in to comment.