Skip to content

Commit

Permalink
rename m/n to numRows/numCols for local matrix
Browse files Browse the repository at this point in the history
add tests for matrices
  • Loading branch information
mengxr committed Apr 7, 2014
1 parent b881506 commit be119fe
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 17 deletions.
42 changes: 29 additions & 13 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,44 +25,60 @@ import breeze.linalg.{Matrix => BM, DenseMatrix => BDM}
trait Matrix extends Serializable {

/** Number of rows. */
def m: Int
def numRows: Int

/** Number of columns. */
def n: Int
def numCols: Int

/** Converts to a dense array in column major. */
def toArray: Array[Double]

/** Converts to a breeze matrix. */
private[mllib] def toBreeze: BM[Double]

/** Gets the (i, j)-th element. */
private[mllib] def apply(i: Int, j: Int): Double = toBreeze(i, j)
}

/**
 * Column-majored dense matrix.
 *
 * The `values` array is stored by reference: `toArray` returns it directly and `toBreeze`
 * passes it to the breeze matrix constructor, so no copy is made here.
 *
 * @param numRows number of rows
 * @param numCols number of columns
 * @param values matrix entries in column major; its length must equal `numRows * numCols`
 * @throws IllegalArgumentException if `values.length` differs from `numRows * numCols`
 */
class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix {

  // Fail with a descriptive message instead of a bare "requirement failed".
  require(values.length == numRows * numCols,
    s"Expected ${numRows * numCols} values for a $numRows x $numCols matrix " +
      s"but found ${values.length}.")

  override def toArray: Array[Double] = values

  private[mllib] override def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values)
}

object Matrices {

def dense(m: Int, n: Int, values: Array[Double]): Matrix = {
new DenseMatrix(m, n, values)
/**
* Creates a dense matrix.
*
* @param numRows number of rows
* @param numCols number of columns
* @param values matrix entries in column major
*/
def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix = {
new DenseMatrix(numRows, numCols, values)
}

/**
* Creates a Matrix instance from a breeze matrix.
* @param breeze a breeze matrix
* @return a Matrix instance
*/
private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = {
breeze match {
case dm: BDM[Double] =>
require(dm.majorStride == dm.rows)
require(dm.majorStride == dm.rows,
"Do not support stride size different from the number of rows.")
new DenseMatrix(dm.rows, dm.cols, dm.data)
case _ =>
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ class RowRDDMatrix(
*/
def multiply(B: Matrix): RowRDDMatrix = {
val n = numCols().toInt
require(n == B.m, s"Dimension mismatch: $n vs ${B.m}")
require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}")

require(B.isInstanceOf[DenseMatrix],
s"Only support dense matrix at this time but found ${B.getClass.getName}.")
Expand All @@ -254,7 +254,7 @@ class RowRDDMatrix(
iter.map(v => Vectors.fromBreeze(Bi.t * v.toBreeze))
}, preservesPartitioning = true)

new RowRDDMatrix(AB, _m, B.n)
new RowRDDMatrix(AB, _m, B.numCols)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.linalg

import org.scalatest.FunSuite

import breeze.linalg.{DenseMatrix => BDM}

/**
 * Checks conversions between local matrices and breeze dense matrices, including that
 * the backing array is shared rather than copied.
 */
class BreezeMatrixConversionSuite extends FunSuite {

  test("dense matrix to breeze") {
    val local = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
    val converted = local.toBreeze.asInstanceOf[BDM[Double]]
    // Dimensions must carry over unchanged.
    assert(converted.rows === local.numRows)
    assert(converted.cols === local.numCols)
    // Reference equality: the breeze matrix must wrap the very same array.
    val backing = local.asInstanceOf[DenseMatrix].values
    assert(converted.data.eq(backing), "should not copy data")
  }

  test("dense breeze matrix to matrix") {
    val source = new BDM[Double](3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
    val converted = Matrices.fromBreeze(source).asInstanceOf[DenseMatrix]
    // Dimensions must carry over unchanged.
    assert(converted.numRows === source.rows)
    assert(converted.numCols === source.cols)
    // Reference equality: the local matrix must wrap the very same array.
    assert(converted.values.eq(source.data), "should not copy data")
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.linalg

import org.scalatest.FunSuite

/** Tests for constructing dense matrices through the `Matrices` factory. */
class MatricesSuite extends FunSuite {

  test("dense matrix construction") {
    val m = 3
    val n = 2
    val values = Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)
    val mat = Matrices.dense(m, n, values).asInstanceOf[DenseMatrix]
    assert(mat.numRows === m)
    assert(mat.numCols === n)
    // Reference equality: the matrix must wrap the given array, not a copy of it.
    assert(mat.values.eq(values), "should not copy data")
    assert(mat.toArray.eq(values), "toArray should not copy data")
  }

  test("dense matrix construction with wrong dimension") {
    // The size check is a `require`, which throws IllegalArgumentException. Intercepting
    // that exact type keeps unrelated runtime failures from satisfying this test.
    intercept[IllegalArgumentException] {
      // 3 values supplied for a 3 x 2 matrix.
      Matrices.dense(3, 2, Array(0.0, 1.0, 2.0))
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ class RowRDDMatrixSuite extends FunSuite with LocalSparkContext {
test("pca") {
  // Check every matrix fixture in the sequence for every k from 1 to n.
  for (mat <- Seq(denseMat, sparseMat); k <- 1 to n) {
    // Use the loop variable: calling denseMat here would test the same matrix twice
    // and never exercise sparseMat.
    val pc = mat.computePrincipalComponents(k)
    assert(pc.numRows === n)
    assert(pc.numCols === k)
    assertPrincipalComponentsEqual(pc, principalComponents, k)
  }
}
Expand Down

0 comments on commit be119fe

Please sign in to comment.