Skip to content

Commit

Permalink
[SPARK-15079] Support average/count/sum in Long/DoubleAccumulator
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
This patch removes AverageAccumulator and adds the ability to compute average to LongAccumulator and DoubleAccumulator. The patch also improves documentation for the two accumulators.

## How was this patch tested?
Added unit tests for this.

Author: Reynold Xin <[email protected]>

Closes apache#12858 from rxin/SPARK-15079.
  • Loading branch information
rxin committed May 3, 2016
1 parent 8028f3a commit bb9ab56
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 101 deletions.
17 changes: 0 additions & 17 deletions core/src/main/scala/org/apache/spark/Accumulator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@

package org.apache.spark

import org.apache.spark.storage.{BlockId, BlockStatus}


/**
* A simpler value of [[Accumulable]] where the result type being accumulated is the same
* as the types of elements being merged, i.e. variables that are only "added" to through an
Expand Down Expand Up @@ -117,18 +114,4 @@ object AccumulatorParam {
def addInPlace(t1: String, t2: String): String = t2
def zero(initialValue: String): String = ""
}

// Note: this is expensive as it makes a copy of the list every time the caller adds an item.
// A better way to use this is to first accumulate the values yourself then add them all at once.
@deprecated("use AccumulatorV2", "2.0.0")
private[spark] class ListAccumulatorParam[T] extends AccumulatorParam[Seq[T]] {
// Concatenates the two sequences into a new Seq, with t1's elements first.
def addInPlace(t1: Seq[T], t2: Seq[T]): Seq[T] = t1 ++ t2
// The zero value is always the empty sequence; initialValue is ignored.
def zero(initialValue: Seq[T]): Seq[T] = Seq.empty[T]
}

// For the internal metric that records what blocks are updated in a particular task.
// Specializes ListAccumulatorParam to sequences of (BlockId, BlockStatus) pairs.
@deprecated("use AccumulatorV2", "2.0.0")
private[spark] object UpdatedBlockStatusesAccumulatorParam
extends ListAccumulatorParam[(BlockId, BlockStatus)]

}
137 changes: 91 additions & 46 deletions core/src/main/scala/org/apache/spark/AccumulatorV2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -257,23 +257,66 @@ private[spark] object AccumulatorContext {
}


/**
* An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers.
*
* @since 2.0.0
*/
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
private[this] var _sum = 0L
private[this] var _count = 0L

override def isZero: Boolean = _sum == 0
/**
* Returns true if this accumulator's count is zero, i.e. no elements have been added to it.
* @since 2.0.0
*/
override def isZero: Boolean = _count == 0L

override def copyAndReset(): LongAccumulator = new LongAccumulator

override def add(v: jl.Long): Unit = _sum += v
/**
* Adds v to the accumulator, i.e. increment sum by v and count by 1.
* @since 2.0.0
*/
override def add(v: jl.Long): Unit = {
_sum += v
_count += 1
}

/**
* Adds v to the accumulator, i.e. increment sum by v and count by 1.
* @since 2.0.0
*/
def add(v: Long): Unit = {
_sum += v
_count += 1
}

def add(v: Long): Unit = _sum += v
/**
* Returns the number of elements added to the accumulator.
* @since 2.0.0
*/
def count: Long = _count

/**
* Returns the sum of elements added to the accumulator.
* @since 2.0.0
*/
def sum: Long = _sum

/**
* Returns the average of elements added to the accumulator.
* @since 2.0.0
*/
def avg: Double = _sum.toDouble / _count

override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
case o: LongAccumulator => _sum += o.sum
case _ => throw new UnsupportedOperationException(
s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
case o: LongAccumulator =>
_sum += o.sum
_count += o.count
case _ =>
throw new UnsupportedOperationException(
s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}

private[spark] def setValue(newValue: Long): Unit = _sum = newValue
Expand All @@ -282,66 +325,68 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
}


/**
* An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for double-precision
* floating-point numbers.
*
* @since 2.0.0
*/
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
private[this] var _sum = 0.0

override def isZero: Boolean = _sum == 0.0

override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator

override def add(v: jl.Double): Unit = _sum += v

def add(v: Double): Unit = _sum += v

def sum: Double = _sum

override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
case o: DoubleAccumulator => _sum += o.sum
case _ => throw new UnsupportedOperationException(
s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}

private[spark] def setValue(newValue: Double): Unit = _sum = newValue

override def localValue: jl.Double = _sum
}


class AverageAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
private[this] var _sum = 0.0
private[this] var _count = 0L

override def isZero: Boolean = _sum == 0.0 && _count == 0
override def isZero: Boolean = _count == 0L

override def copyAndReset(): AverageAccumulator = new AverageAccumulator
override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator

/**
* Adds v to the accumulator, i.e. increment sum by v and count by 1.
* @since 2.0.0
*/
override def add(v: jl.Double): Unit = {
_sum += v
_count += 1
}

def add(d: Double): Unit = {
_sum += d
/**
* Adds v to the accumulator, i.e. increment sum by v and count by 1.
* @since 2.0.0
*/
def add(v: Double): Unit = {
_sum += v
_count += 1
}

/**
* Returns the number of elements added to the accumulator.
* @since 2.0.0
*/
def count: Long = _count

/**
* Returns the sum of elements added to the accumulator.
* @since 2.0.0
*/
def sum: Double = _sum

/**
* Returns the average of elements added to the accumulator.
* @since 2.0.0
*/
def avg: Double = _sum / _count

override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
case o: AverageAccumulator =>
case o: DoubleAccumulator =>
_sum += o.sum
_count += o.count
case _ => throw new UnsupportedOperationException(
s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}

override def localValue: jl.Double = if (_count == 0) {
Double.NaN
} else {
_sum / _count
case _ =>
throw new UnsupportedOperationException(
s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}

def sum: Double = _sum
private[spark] def setValue(newValue: Double): Unit = _sum = newValue

def count: Long = _count
override def localValue: jl.Double = _sum
}


Expand Down
22 changes: 0 additions & 22 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1340,28 +1340,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
acc
}

/**
* Create and register an average accumulator, which accumulates double inputs by recording the
* total sum and total count, and produces the output as sum / count. Note that Double.NaN is
* returned if no input is added.
*/
def averageAccumulator: AverageAccumulator = {
val acc = new AverageAccumulator
register(acc)
acc
}

/**
* Create an average accumulator and register it with the given name. It accumulates double
* inputs by recording the total sum and total count, and produces the output as sum / count.
* Note that Double.NaN is returned if no input is added.
*/
def averageAccumulator(name: String): AverageAccumulator = {
val acc = new AverageAccumulator
register(acc, name)
acc
}

/**
* Create and register a list accumulator, which starts with empty list and accumulates inputs
* by adding them into the inner list.
Expand Down
17 changes: 1 addition & 16 deletions core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import scala.util.control.NonFatal
import org.scalatest.Matchers
import org.scalatest.exceptions.TestFailedException

import org.apache.spark.AccumulatorParam.{ListAccumulatorParam, StringAccumulatorParam}
import org.apache.spark.AccumulatorParam.StringAccumulatorParam
import org.apache.spark.scheduler._
import org.apache.spark.serializer.JavaSerializer

Expand Down Expand Up @@ -234,21 +234,6 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
acc.merge("kindness")
assert(acc.value === "kindness")
}

test("list accumulator param") {
val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers"))
assert(acc.value === Seq.empty[Int])
acc.add(Seq(1, 2))
assert(acc.value === Seq(1, 2))
acc += Seq(3, 4)
assert(acc.value === Seq(1, 2, 3, 4))
acc ++= Seq(5, 6)
assert(acc.value === Seq(1, 2, 3, 4, 5, 6))
acc.merge(Seq(7, 8))
assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8))
acc.setValue(Seq(9, 10))
assert(acc.value === Seq(9, 10))
}
}

private[spark] object AccumulatorSuite {
Expand Down
89 changes: 89 additions & 0 deletions core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

import org.apache.spark.{DoubleAccumulator, LongAccumulator, SparkFunSuite}

/**
 * Unit tests for the sum / count / avg bookkeeping of [[LongAccumulator]] and
 * [[DoubleAccumulator]], covering both the primitive and the boxed `add` overloads as well as
 * merging two accumulators.
 */
class AccumulatorV2Suite extends SparkFunSuite {

  test("LongAccumulator add/avg/sum/count/isZero") {
    val acc = new LongAccumulator
    // A fresh accumulator is zero; its average is 0.0 / 0, i.e. NaN.
    assert(acc.isZero)
    assert(acc.count == 0)
    assert(acc.sum == 0)
    assert(acc.avg.isNaN)

    // Adding the value 0 must still flip isZero, because isZero is driven by count, not sum.
    acc.add(0)
    assert(!acc.isZero)
    assert(acc.count == 1)
    assert(acc.sum == 0)
    assert(acc.avg == 0.0)

    acc.add(1)
    assert(acc.count == 2)
    assert(acc.sum == 1)
    assert(acc.avg == 0.5)

    // Also test add using non-specialized add function.
    // valueOf is used instead of the deprecated boxed constructor; the argument is still a
    // java.lang.Long, so the boxed overload is selected.
    acc.add(java.lang.Long.valueOf(2L))
    assert(acc.count == 3)
    assert(acc.sum == 3)
    assert(acc.avg == 1.0)

    // Test merging: both sum and count of the other accumulator are folded in.
    val acc2 = new LongAccumulator
    acc2.add(2)
    acc.merge(acc2)
    assert(acc.count == 4)
    assert(acc.sum == 5)
    assert(acc.avg == 1.25)
  }

  test("DoubleAccumulator add/avg/sum/count/isZero") {
    val acc = new DoubleAccumulator
    // A fresh accumulator is zero; its average is 0.0 / 0, i.e. NaN.
    assert(acc.isZero)
    assert(acc.count == 0)
    assert(acc.sum == 0.0)
    assert(acc.avg.isNaN)

    // Adding the value 0.0 must still flip isZero, because isZero is driven by count, not sum.
    acc.add(0.0)
    assert(!acc.isZero)
    assert(acc.count == 1)
    assert(acc.sum == 0.0)
    assert(acc.avg == 0.0)

    acc.add(1.0)
    assert(acc.count == 2)
    assert(acc.sum == 1.0)
    assert(acc.avg == 0.5)

    // Also test add using non-specialized add function.
    // valueOf is used instead of the deprecated boxed constructor; the argument is still a
    // java.lang.Double, so the boxed overload is selected.
    acc.add(java.lang.Double.valueOf(2.0))
    assert(acc.count == 3)
    assert(acc.sum == 3.0)
    assert(acc.avg == 1.0)

    // Test merging: both sum and count of the other accumulator are folded in.
    val acc2 = new DoubleAccumulator
    acc2.add(2.0)
    acc.merge(acc2)
    assert(acc.count == 4)
    assert(acc.sum == 5.0)
    assert(acc.avg == 1.25)
  }
}

0 comments on commit bb9ab56

Please sign in to comment.