diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 71d8a5c4aa0ec..e70281038b683 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -64,7 +64,7 @@ case class Rating(val user: Int, val product: Int, val rating: Double) * Alternating Least Squares matrix factorization. * * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, - * `X` and `Y`, i.e. `Xt * Y = R`. Typically these approximations are called 'factor' matrices. + * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices. * The general approach is iterative. During each iteration, one of the factor matrices is held * constant, while the other is solved for using least squares. The newly-solved factor matrix is * then held constant while solving for the other factor matrix. @@ -381,8 +381,16 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l userXtX(us(i)).addi(tempXtX) SimpleBlas.axpy(rs(i), x, userXy(us(i))) case true => - userXtX(us(i)).addi(tempXtX.mul(alpha * rs(i))) - SimpleBlas.axpy(1 + alpha * rs(i), x, userXy(us(i))) + // Extension to the original paper to handle rs(i) < 0. confidence is a function + // of |rs(i)| instead so that it is never negative: + val confidence = 1 + alpha * abs(rs(i)) + userXtX(us(i)).addi(tempXtX.mul(confidence - 1)) + // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i) + // means we try to reconstruct 0. We add terms only where P = 1, so, term below + // is now only added for rs(i) > 0: + if (rs(i) > 0) { + SimpleBlas.axpy(confidence, x, userXy(us(i))) + } } } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index b40f552e0d0aa..b150334deb06c 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -19,7 +19,6 @@ import java.io.Serializable; import java.util.List; -import java.lang.Math; import org.junit.After; import org.junit.Assert; @@ -46,7 +45,7 @@ public void tearDown() { System.clearProperty("spark.driver.port"); } - void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, + static void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { DoubleMatrix predictedU = new DoubleMatrix(users, features); List> userFeatures = model.userFeatures().toJavaRDD().collect(); @@ -84,15 +83,15 @@ void validatePrediction(MatrixFactorizationModel model, int users, int products, for (int p = 0; p < products; ++p) { double prediction = predictedRatings.get(u, p); double truePref = truePrefs.get(u, p); - double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p); + double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p)); double err = confidence * (truePref - prediction) * (truePref - prediction); sqErr += err; - denom += 1.0; + denom += confidence; } } double rmse = Math.sqrt(sqErr / denom); Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f", - rmse, matchThreshold), Math.abs(rmse) < matchThreshold); + rmse, matchThreshold), rmse < matchThreshold); } } @@ -103,7 +102,7 @@ public void runALSUsingStaticMethods() { int users = 50; int products = 100; scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, false); + users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); @@ -117,7 +116,7 @@ public void runALSUsingConstructor() { int users = 100; int products = 200; scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, false); + users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); @@ -134,7 +133,7 @@ public void runImplicitALSUsingStaticMethods() { int users = 80; int products = 160; scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, true); + users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); @@ -148,7 +147,7 @@ public void runImplicitALSUsingConstructor() { int users = 100; int products = 200; scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, true); + users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); @@ -158,4 +157,19 @@ public void runImplicitALSUsingConstructor() { .run(data.rdd()); validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); } + + @Test + public void runImplicitALSWithNegativeWeight() { + int features = 2; + int iterations = 15; + int users = 80; + int products = 160; + scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, true, true); + + JavaRDD data = sc.parallelize(testData._1()); + MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); + validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 5dcec7dc3eb9b..45e7d2db00c42 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.mllib.recommendation import scala.collection.JavaConversions._ +import scala.math.abs import scala.util.Random -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.jblas._ @@ -34,7 +34,8 @@ object ALSSuite { products: Int, features: Int, samplingRate: Double, - implicitPrefs: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = { + implicitPrefs: Boolean, + negativeWeights: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = { val (sampledRatings, trueRatings, truePrefs) = generateRatings(users, products, features, samplingRate, implicitPrefs) (seqAsJavaList(sampledRatings), trueRatings, truePrefs) @@ -45,7 +46,8 @@ object ALSSuite { products: Int, features: Int, samplingRate: Double, - implicitPrefs: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = { + implicitPrefs: Boolean = false, + negativeWeights: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = { val rand = new Random(42) // Create a random matrix with uniform values from -1 to 1 @@ -56,7 +58,9 @@ object ALSSuite { val productMatrix = randomMatrix(features, products) val (trueRatings, truePrefs) = implicitPrefs match { case true => - val raw = new DoubleMatrix(users, products, Array.fill(users * products)(rand.nextInt(10).toDouble): _*) + // Generate raw values from [0,9], or if negativeWeights, from [-2,7] + val raw = new DoubleMatrix(users, products, + Array.fill(users * products)((if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble): _*) val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*) (raw, prefs) case false => (userMatrix.mmul(productMatrix), null) @@ -107,6 +111,10 @@ class ALSSuite extends FunSuite with LocalSparkContext { testALS(100, 200, 2, 15, 0.7, 0.4, true, true) } + test("rank-2 matrices implicit negative") { + testALS(100, 200, 2, 15, 0.7, 0.4, true, false, true) + } + /** * Test if we can correctly factorize R = U * P where U and P are of known rank. * @@ -118,13 +126,14 @@ class ALSSuite extends FunSuite with LocalSparkContext { * @param matchThreshold max difference allowed to consider a predicted rating correct * @param implicitPrefs flag to test implicit feedback * @param bulkPredict flag to test bulk prediciton + * @param negativeWeights whether the generated data can contain negative values */ def testALS(users: Int, products: Int, features: Int, iterations: Int, samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false, - bulkPredict: Boolean = false) + bulkPredict: Boolean = false, negativeWeights: Boolean = false) { val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, - features, samplingRate, implicitPrefs) + features, samplingRate, implicitPrefs, negativeWeights) val model = implicitPrefs match { case false => ALS.train(sc.parallelize(sampledRatings), features, iterations) case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations) @@ -166,13 +175,13 @@ class ALSSuite extends FunSuite with LocalSparkContext { for (u <- 0 until users; p <- 0 until products) { val prediction = predictedRatings.get(u, p) val truePref = truePrefs.get(u, p) - val confidence = 1 + 1.0 * trueRatings.get(u, p) + val confidence = 1 + 1.0 * abs(trueRatings.get(u, p)) val err = confidence * (truePref - prediction) * (truePref - prediction) sqErr += err - denom += 1 + denom += confidence } val rmse = math.sqrt(sqErr / denom) - if (math.abs(rmse) > matchThreshold) { + if (rmse > matchThreshold) { fail("Model failed to predict RMSE: %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( rmse, truePrefs, predictedRatings, predictedU, predictedP)) }