From 036b7a5cbced3d6a582ce7d3b7cdec4f4ab3a577 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 1 Apr 2014 20:53:48 +0800 Subject: [PATCH] fix the bug of Nan occur --- .../spark/mllib/rdd/VectorRDDFunctions.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala index a39b6f81cf6ed..029ef263d5d80 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.mllib.rdd -import breeze.linalg.{Vector => BV} +import breeze.linalg.{Vector => BV, axpy} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -92,8 +92,14 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable { VectorRDDStatisticalRing(mean2, m2n2, cnt2, nnz2, max2, min2)) => val totalCnt = cnt1 + cnt2 val deltaMean = mean2 - mean1 - val totalMean = ((mean1 :* nnz1) + (mean2 :* nnz2)) :/ (nnz1 + nnz2) - val totalM2n = m2n1 + m2n2 + ((deltaMean :* deltaMean) :* (nnz1 :* nnz2) :/ (nnz1 + nnz2)) + mean2.activeIterator.foreach { + case (id, 0.0) => + case (id, value) => mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id)) + } + m2n2.activeIterator.foreach { + case (id, 0.0) => + case (id, value) => m2n1(id) += value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id)+nnz2(id)) + } max2.activeIterator.foreach { case (id, value) => if (max1(id) < value) max1(id) = value @@ -102,7 +108,8 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable { case (id, value) => if (min1(id) > value) min1(id) = value } - VectorRDDStatisticalRing(totalMean, totalM2n, totalCnt, nnz1 + nnz2, max1, min1) + axpy(1.0, nnz2, nnz1) + VectorRDDStatisticalRing(mean1, m2n1, totalCnt, nnz1, max1, min1) } }