
Commit

fixed error, updated test
dwmclary committed Mar 17, 2014
1 parent 82cde0e commit fd3fd4b
Showing 4 changed files with 26 additions and 28 deletions.
2 changes: 0 additions & 2 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -484,7 +484,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * @return the maximum of the RDD
    * */
   def max(comp: Comparator[T]): T = {
-    import scala.collection.JavaConversions._
     rdd.max()(Ordering.comparatorToOrdering(comp))
   }

@@ -495,7 +494,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * @return the minimum of the RDD
    * */
   def min(comp: Comparator[T]): T = {
-    import scala.collection.JavaConversions._
     rdd.min()(Ordering.comparatorToOrdering(comp))
   }
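The removed JavaConversions wildcard import was never needed in these two methods: the Comparator-to-Ordering bridge is done explicitly by Ordering.comparatorToOrdering. A minimal, self-contained sketch of that bridge (illustrative names, not Spark code):

import java.util.Comparator

object ComparatorBridgeSketch {
  def main(args: Array[String]): Unit = {
    // A Java-style Comparator, as a caller of the Java API would supply one.
    val byLength: Comparator[String] = new Comparator[String] {
      override def compare(a: String, b: String): Int = a.length - b.length
    }
    // Ordering.comparatorToOrdering wraps it in a Scala Ordering explicitly,
    // so no wildcard implicit-conversion import is required at the call site.
    val ord: Ordering[String] = Ordering.comparatorToOrdering(byLength)
    println(Seq("rdd", "spark", "counter").max(ord)) // prints "counter"
  }
}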

26 changes: 13 additions & 13 deletions core/src/main/scala/org/apache/spark/util/StatCounter.scala
@@ -29,8 +29,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
   private var n: Long = 0 // Running count of our values
   private var mu: Double = 0 // Running mean of our values
   private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
-  private var max_v: Double = Double.NegativeInfinity // Running max of our values
-  private var min_v: Double = Double.PositiveInfinity // Running min of our values
+  private var maxValue: Double = Double.NegativeInfinity // Running max of our values
+  private var minValue: Double = Double.PositiveInfinity // Running min of our values
 
   merge(values)

@@ -43,8 +43,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
     n += 1
     mu += delta / n
     m2 += delta * (value - mu)
-    max_v = math.max(max_v, value)
-    min_v = math.min(min_v, value)
+    maxValue = math.max(maxValue, value)
+    minValue = math.min(minValue, value)
     this
   }
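For context, merge(value) is a Welford-style online update: mu tracks the running mean and m2 accumulates the sum of squared deviations, so the variance comes out as m2 / n in a single pass. A standalone sketch of the same arithmetic (illustrative, not Spark code):

object RunningStatsSketch {
  def main(args: Array[String]): Unit = {
    val values = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)
    var n = 0L
    var mu = 0.0
    var m2 = 0.0
    var maxValue = Double.NegativeInfinity
    var minValue = Double.PositiveInfinity
    for (value <- values) {
      // Same update as StatCounter.merge(value) above.
      val delta = value - mu
      n += 1
      mu += delta / n
      m2 += delta * (value - mu)
      maxValue = math.max(maxValue, value)
      minValue = math.min(minValue, value)
    }
    // Variance falls out in one pass as m2 / n.
    println(s"n=$n mean=$mu variance=${m2 / n} max=$maxValue min=$minValue")
    // n=8 mean=5.0 variance=4.0 max=9.0 min=2.0 (up to floating-point rounding)
  }
}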

@@ -63,8 +63,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
         mu = other.mu
         m2 = other.m2
         n = other.n
-        max_v = other.max_v
-        min_v = other.min_v
+        maxValue = other.maxValue
+        minValue = other.minValue
       } else if (other.n != 0) {
         val delta = other.mu - mu
         if (other.n * 10 < n) {
@@ -76,8 +76,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
         }
         m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
         n += other.n
-        max_v = math.max(max_v, other.max_v)
-        min_v = math.min(min_v, other.min_v)
+        maxValue = math.max(maxValue, other.maxValue)
+        minValue = math.min(minValue, other.minValue)
       }
       this
     }
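The two-counter merge combines partial results, e.g. from different partitions: the means are count-weighted, and m2 picks up the correction term delta^2 * nA * nB / (nA + nB). A sketch showing that the combine step agrees with a single pass over all the data (field names mirror StatCounter, but this is illustrative, not Spark code):

object MergeStatsSketch {
  final case class Stats(n: Long, mu: Double, m2: Double)

  // One-pass accumulation, as in the merge(value) hunk above.
  def ofValues(xs: Seq[Double]): Stats = {
    var n = 0L; var mu = 0.0; var m2 = 0.0
    for (x <- xs) { val d = x - mu; n += 1; mu += d / n; m2 += d * (x - mu) }
    Stats(n, mu, m2)
  }

  // Combine two partial results: count-weighted mean, corrected m2.
  def merge(a: Stats, b: Stats): Stats = {
    if (a.n == 0) b
    else if (b.n == 0) a
    else {
      val delta = b.mu - a.mu
      val n = a.n + b.n
      Stats(n, (a.mu * a.n + b.mu * b.n) / n,
        a.m2 + b.m2 + (delta * delta * a.n * b.n) / n)
    }
  }

  def main(args: Array[String]): Unit = {
    val xs = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)
    val (left, right) = xs.splitAt(3)
    // The merged result matches the single pass, up to floating-point rounding.
    println(merge(ofValues(left), ofValues(right)))
    println(ofValues(xs))
  }
}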
@@ -89,8 +89,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
     other.n = n
     other.mu = mu
     other.m2 = m2
-    other.max_v = max_v
-    other.min_v = min_v
+    other.maxValue = maxValue
+    other.minValue = minValue
     other
   }

@@ -100,9 +100,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
 
   def sum: Double = n * mu
 
-  def max: Double = max_v
+  def max: Double = maxValue
 
-  def min: Double = min_v
+  def min: Double = minValue
 
   /** Return the variance of the values. */
   def variance: Double = {
@@ -135,7 +135,7 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
   def sampleStdev: Double = math.sqrt(sampleVariance)
 
   override def toString: String = {
-    "(count: %d, mean: %f, stdev: %f, max: %f, min: $f)".format(count, mean, stdev, max, min)
+    "(count: %d, mean: %f, stdev: %f, max: %f, min: %f)".format(count, mean, stdev, max, min)
   }
 }
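This is the "fixed error" of the commit message: Scala's String.format uses printf-style '%' conversions, so the stray "$f" was emitted literally instead of substituting min (the extra argument is silently ignored). A quick illustration:

object FormatBugDemo extends App {
  // '%f' is a conversion; '$f' is plain text, so the old string
  // printed a literal "$f" where the minimum should appear.
  println("(min: $f)".format(2.0)) // (min: $f)       -- the bug
  println("(min: %f)".format(2.0)) // (min: 2.000000) -- the fix
}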

2 changes: 1 addition & 1 deletion core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -172,7 +172,7 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     assert(abs(1.0 - rdd.variance) < 0.01)
     assert(abs(1.0 - rdd.stdev) < 0.01)
     assert(stats.max === 4.0)
-    assert(stats.min === -1.0)
+    assert(stats.min === 2.0)
 
     // Add other tests here for classes that should be able to handle empty partitions correctly
   }
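The corrected expectation is the one consistent with the surrounding assertions. Assuming the test RDD holds something like Array(2.0, 4.0) (hypothetical here; the data isn't shown in this hunk), all four checks line up: mean 3.0, variance 1.0, stdev 1.0, max 4.0, min 2.0:

object TestExpectationSketch extends App {
  // Hypothetical data chosen to match the assertions above.
  val xs = Seq(2.0, 4.0)
  val mean = xs.sum / xs.size                                      // 3.0
  val variance = xs.map(x => math.pow(x - mean, 2)).sum / xs.size  // 1.0
  println((mean, variance, math.sqrt(variance), xs.max, xs.min))
  // (3.0, 1.0, 1.0, 4.0, 2.0)
}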
24 changes: 12 additions & 12 deletions python/pyspark/statcounter.py
@@ -26,8 +26,8 @@ def __init__(self, values=[]):
         self.n = 0L  # Running count of our values
         self.mu = 0.0  # Running mean of our values
         self.m2 = 0.0  # Running variance numerator (sum of (x - mean)^2)
-        self.max_v = float("-inf")
-        self.min_v = float("inf")
+        self.maxValue = float("-inf")
+        self.minValue = float("inf")
 
         for v in values:
             self.merge(v)
@@ -38,10 +38,10 @@ def merge(self, value):
         self.n += 1
         self.mu += delta / self.n
         self.m2 += delta * (value - self.mu)
-        if self.max_v < value:
-            self.max_v = value
-        if self.min_v > value:
-            self.min_v = value
+        if self.maxValue < value:
+            self.maxValue = value
+        if self.minValue > value:
+            self.minValue = value
 
         return self

@@ -57,8 +57,8 @@ def mergeStats(self, other):
                 self.mu = other.mu
                 self.m2 = other.m2
                 self.n = other.n
-                self.max_v = other.max_v
-                self.min_v = other.min_v
+                self.maxValue = other.maxValue
+                self.minValue = other.minValue
 
             elif other.n != 0:
                 delta = other.mu - self.mu
@@ -69,8 +69,8 @@ def mergeStats(self, other):
                 else:
                     self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)
 
-                self.max_v = max(self.max_v, other.max_v)
-                self.min_v = min(self.min_v, other.min_v)
+                self.maxValue = max(self.maxValue, other.maxValue)
+                self.minValue = min(self.minValue, other.minValue)
 
                 self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
                 self.n += other.n
@@ -90,10 +90,10 @@ def sum(self):
         return self.n * self.mu
 
     def min(self):
-        return self.min_v
+        return self.minValue
 
     def max(self):
-        return self.max_v
+        return self.maxValue
 
     # Return the variance of the values.
     def variance(self):
