Skip to content

Commit

Permalink
[SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero cor…
Browse files Browse the repository at this point in the history
…rectly

## What changes were proposed in this pull request?

It's possible that Accumulators of Spark 1.x may no longer work with Spark 2.x. This is because `LegacyAccumulatorWrapper.isZero` may return wrong answer if `AccumulableParam` doesn't define equals/hashCode.

This PR fixes this by using reference equality check in `LegacyAccumulatorWrapper.isZero`.

## How was this patch tested?

a new test

Author: Wenchen Fan <[email protected]>

Closes #21229 from cloud-fan/accumulator.

(cherry picked from commit 4d5de4d)
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
cloud-fan committed May 4, 2018
1 parent d51c6aa commit a42dd00
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,9 @@ class LegacyAccumulatorWrapper[R, T](
param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
private[spark] var _value = initialValue // Current value on driver

override def isZero: Boolean = _value == param.zero(initialValue)
@transient private lazy val _zero = param.zero(initialValue)

override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef])

override def copy(): LegacyAccumulatorWrapper[R, T] = {
val acc = new LegacyAccumulatorWrapper(initialValue, param)
Expand All @@ -485,7 +487,7 @@ class LegacyAccumulatorWrapper[R, T](
}

override def reset(): Unit = {
_value = param.zero(initialValue)
_value = _zero
}

override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
Expand Down
19 changes: 19 additions & 0 deletions core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.util

import org.apache.spark._
import org.apache.spark.serializer.JavaSerializer

class AccumulatorV2Suite extends SparkFunSuite {

Expand Down Expand Up @@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite {
assert(acc3.isZero)
assert(acc3.value === "")
}

test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") {
class MyData(val i: Int) extends Serializable
val param = new AccumulatorParam[MyData] {
override def zero(initialValue: MyData): MyData = new MyData(0)
override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i)
}

val acc = new LegacyAccumulatorWrapper(new MyData(0), param)
acc.metadata = AccumulatorMetadata(
AccumulatorContext.newId(),
Some("test"),
countFailedValues = false)
AccumulatorContext.register(acc)

val ser = new JavaSerializer(new SparkConf).newInstance()
ser.serialize(acc)
}
}

0 comments on commit a42dd00

Please sign in to comment.