Skip to content
This repository has been archived by the owner on Jan 20, 2022. It is now read-only.

Add aggregate to summingbird #562

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License.

package com.twitter.summingbird.memory

import com.twitter.algebird.{ MapAlgebra, Monoid, Semigroup }
import com.twitter.algebird.{ Aggregator, MapAlgebra, Monoid, Semigroup }
import com.twitter.summingbird._
import com.twitter.summingbird.option.JobId
import org.scalacheck.{ Arbitrary, _ }
Expand Down Expand Up @@ -214,6 +214,32 @@ class MemoryLaws extends WordSpec {
assert(store1.toMap == ((0 to 100).groupBy(_ % 3).mapValues(_.sum)))
assert(store2.toMap == ((0 to 100).groupBy(_ % 3).mapValues(_.sum)))
}
"aggregate should work" in {
  // Values are fed in descending order, so for each key the running max is
  // reached on the very first element and never changes afterwards.
  val input = (0 to 100).reverse
  val source = Memory.toSource(input)
  val store = MutableMap.empty[Int, Int]
  // Collects every (previous, current) pair written per key. Pairs are
  // prepended, so each list ends up in reverse emission order.
  val buf = MutableMap.empty[Int, List[(Option[Int], Int)]]
  val prod = source.map { t => (t % 2, t) }
    .aggregate(store, Aggregator.max[Int].andThenPresent(_ * 2).composePrepare(_ / 2))
    .write { kv =>
      val (k, vs) = kv
      buf(k) = vs :: buf.getOrElse(k, Nil)
    }
  val mem = new Memory
  mem.run(mem.plan(prod))

  // Reference model built with plain collections: prepare (_ / 2), keep a
  // running max, present (_ * 2). The "previous" column is the "current"
  // column shifted by one emission, starting with None.
  def expectedEmissions(key: Int): List[(Option[Int], Int)] = {
    val prepared = input.filter(_ % 2 == key).map(_ / 2).toList
    val current = prepared.scanLeft(Int.MinValue)(_ max _).tail.map(_ * 2)
    val previous: List[Option[Int]] = None :: current.init.map(Some(_))
    previous.zip(current)
  }

  // The store holds the un-presented aggregate, i.e. the max of the prepared values.
  assert(store.keySet == Set(0, 1))
  assert(store(0) == (0 to 100).filter(_ % 2 == 0).map(_ / 2).max)
  assert(store(1) == (0 to 100).filter(_ % 2 == 1).map(_ / 2).max)

  // buf was built by prepending, so reverse it to get emission order.
  assert(buf.keySet == Set(0, 1))
  assert(buf(0).reverse == expectedEmissions(0))
  assert(buf(1).reverse == expectedEmissions(1))
}

"self also shouldn't duplicate work" in {
val platform = new Memory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.twitter.summingbird

import com.twitter.algebird.Semigroup
import com.twitter.algebird.{ Aggregator, Semigroup }

object Producer {

Expand Down Expand Up @@ -251,6 +251,23 @@ case class Summer[P <: Platform[P], K, V](
*/
sealed trait KeyedProducer[P <: Platform[P], K, V] extends Producer[P, (K, V)] {

/**
* Applies an Aggregator to the values of this producer, key by key.
*
* The result type resembles that of sumByKey, with one crucial difference:
* here each emitted tuple is (Option(previous aggregated value), current aggregated value),
* whereas sumByKey emits (Option(previous value), delta). Once agg.present has been
* applied, a delta can no longer be combined with the previous value and is not
* meaningful in the general case, so the fully aggregated current value is emitted instead.
*/
def aggregate[V1, V2](store: P#Store[K, V1], agg: Aggregator[V, V1, V2]): KeyedProducer[P, K, (Option[V2], V2)] = {
  // Hold the semigroup in a local so the closure below uses this one instance.
  val sg = agg.semigroup
  // Step 1: turn every incoming value into the aggregator's intermediate type.
  val prepared = mapValues(agg.prepare)
  // Step 2: sum the intermediates into the store using the aggregator's semigroup.
  val summed = prepared.sumByKey(store)(sg)
  // Step 3: rebuild the aggregate from (previous, delta), then present both
  // the previous aggregate (if there was one) and the new one.
  summed.mapValues {
    case (before, delta) =>
      val merged = before match {
        case Some(prior) => sg.plus(prior, delta)
        case None => delta
      }
      (before.map(agg.present), agg.present(merged))
  }
}

/** Builds a new KeyedProducer by applying a partial function to the keys of this one; elements whose key is outside the function's domain are dropped. */
def collectKeys[K2](pf: PartialFunction[K, K2]): KeyedProducer[P, K2, V] = {
  // Keep only the pairs whose key pf is defined at, replacing the key with pf's result.
  val rekeyed = collect { case (key, value) if pf.isDefinedAt(key) => (pf(key), value) }
  IdentityKeyedProducer(rekeyed)
}
Expand Down