From 4e0b83230332cdbe6c398489a86999f9af0480d4 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Wed, 14 Feb 2018 15:24:05 -1000 Subject: [PATCH] merge in cherry pick of dagon use --- build.sbt | 6 +- .../com/twitter/summingbird/graph/Expr.scala | 84 ---- .../summingbird/graph/ExpressionDag.scala | 368 ---------------- .../com/twitter/summingbird/graph/HMap.scala | 122 ------ .../summingbird/memory/ConcurrentMemory.scala | 11 +- .../twitter/summingbird/memory/Memory.scala | 2 +- .../planner/ComposedFunctions.scala | 16 + .../summingbird/planner/DagOptimizer.scala | 400 ++++++++---------- .../graph/ExpressionDagTests.scala | 205 --------- .../twitter/summingbird/graph/HMapTests.scala | 107 ----- .../summingbird/graph/LiteralTests.scala | 68 --- 11 files changed, 210 insertions(+), 1179 deletions(-) delete mode 100644 summingbird-core/src/main/scala/com/twitter/summingbird/graph/Expr.scala delete mode 100644 summingbird-core/src/main/scala/com/twitter/summingbird/graph/ExpressionDag.scala delete mode 100644 summingbird-core/src/main/scala/com/twitter/summingbird/graph/HMap.scala delete mode 100644 summingbird-core/src/test/scala/com/twitter/summingbird/graph/ExpressionDagTests.scala delete mode 100644 summingbird-core/src/test/scala/com/twitter/summingbird/graph/HMapTests.scala delete mode 100644 summingbird-core/src/test/scala/com/twitter/summingbird/graph/LiteralTests.scala diff --git a/build.sbt b/build.sbt index 43e09b4a5..7bb6a26cd 100644 --- a/build.sbt +++ b/build.sbt @@ -20,7 +20,7 @@ val bijectionVersion = "0.9.1" val chillVersion = "0.7.3" val commonsHttpClientVersion = "3.1" val commonsLangVersion = "2.6" -val finagleVersion = "6.27.0" +val dagonVersion = "0.3.0" val hadoopVersion = "1.2.1" val junitVersion = "4.11" val log4jVersion = "1.2.16" @@ -234,7 +234,9 @@ lazy val summingbirdClient = module("client").settings( ).dependsOn(summingbirdBatch) lazy val summingbirdCore = module("core").settings( - libraryDependencies += "com.twitter" %% "algebird-core" % algebirdVersion + libraryDependencies ++= Seq( + "com.twitter" %% "algebird-core" % algebirdVersion, + "com.stripe" %% "dagon-core" % dagonVersion) ) lazy val summingbirdOnline = module("online").settings( diff --git a/summingbird-core/src/main/scala/com/twitter/summingbird/graph/Expr.scala b/summingbird-core/src/main/scala/com/twitter/summingbird/graph/Expr.scala deleted file mode 100644 index d582bb620..000000000 --- a/summingbird-core/src/main/scala/com/twitter/summingbird/graph/Expr.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.summingbird.graph - -/** - * The Expressions are assigned Ids. Each Id is associated with - * an expression of inner type T. - * - * This is done to put an indirection in the ExpressionDag that - * allows us to rewrite nodes by simply replacing the expressions - * associated with given Ids. 
- * - * T is a phantom type used by the type system - */ -final case class Id[T](id: Int) - -/** - * Expr[T, N] is an expression of a graph of container nodes N[_] with - * result type N[T]. These expressions are like the Literal[T, N] graphs - * except that functions always operate with an indirection of a Id[T] - * where N[T] is the type of the input node. - * - * Nodes can be deleted from the graph by replacing an Expr at Id = idA - * with Var(idB) pointing to some upstream node. - * - * To add nodes to the graph, add depth to the final node returned in - * a Unary or Binary expression. - * - * TODO: see the approach here: https://gist.github.com/pchiusano/1369239 - * Which seems to show a way to do currying, so we can handle general - * arity - */ -sealed trait Expr[T, N[_]] { - def evaluate(idToExp: HMap[Id, ({ type E[t] = Expr[t, N] })#E]): N[T] = - Expr.evaluate(idToExp, this) -} -case class Const[T, N[_]](value: N[T]) extends Expr[T, N] { - override def evaluate(idToExp: HMap[Id, ({ type E[t] = Expr[t, N] })#E]): N[T] = value -} -case class Var[T, N[_]](name: Id[T]) extends Expr[T, N] -case class Unary[T1, T2, N[_]](arg: Id[T1], fn: N[T1] => N[T2]) extends Expr[T2, N] -case class Binary[T1, T2, T3, N[_]](arg1: Id[T1], - arg2: Id[T2], - fn: (N[T1], N[T2]) => N[T3]) extends Expr[T3, N] - -object Expr { - def evaluate[T, N[_]](idToExp: HMap[Id, ({ type E[t] = Expr[t, N] })#E], expr: Expr[T, N]): N[T] = - evaluate(idToExp, HMap.empty[({ type E[t] = Expr[t, N] })#E, N], expr)._2 - - private def evaluate[T, N[_]](idToExp: HMap[Id, ({ type E[t] = Expr[t, N] })#E], - cache: HMap[({ type E[t] = Expr[t, N] })#E, N], - expr: Expr[T, N]): (HMap[({ type E[t] = Expr[t, N] })#E, N], N[T]) = cache.get(expr) match { - case Some(node) => (cache, node) - case None => expr match { - case Const(n) => (cache + (expr -> n), n) - case Var(id) => - val (c1, n) = evaluate(idToExp, cache, idToExp(id)) - (c1 + (expr -> n), n) - case Unary(id, fn) => - val (c1, n1) = evaluate(idToExp, cache, idToExp(id)) - val n2 = fn(n1) - (c1 + (expr -> n2), n2) - case Binary(id1, id2, fn) => - val (c1, n1) = evaluate(idToExp, cache, idToExp(id1)) - val (c2, n2) = evaluate(idToExp, c1, idToExp(id2)) - val n3 = fn(n1, n2) - (c2 + (expr -> n3), n3) - } - } -} diff --git a/summingbird-core/src/main/scala/com/twitter/summingbird/graph/ExpressionDag.scala b/summingbird-core/src/main/scala/com/twitter/summingbird/graph/ExpressionDag.scala deleted file mode 100644 index 02ee40ed9..000000000 --- a/summingbird-core/src/main/scala/com/twitter/summingbird/graph/ExpressionDag.scala +++ /dev/null @@ -1,368 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.summingbird.graph - -///////////////////// -// There is no logical reason for Literal[T, N] to be here, -// but the scala compiler crashes in 2.9.3 if it is not. 
-// with: -// java.lang.Error: typeConstructor inapplicable for -// at scala.tools.nsc.symtab.SymbolTable.abort(SymbolTable.scala:34) -// at scala.tools.nsc.symtab.Symbols$Symbol.typeConstructor(Symbols.scala:880) -//////////////////// - -/** - * This represents literal expressions (no variable redirection) - * of container nodes of type N[T] - */ -sealed trait Literal[T, N[_]] { - def evaluate: N[T] = Literal.evaluate(this) -} -case class ConstLit[T, N[_]](override val evaluate: N[T]) extends Literal[T, N] -case class UnaryLit[T1, T2, N[_]](arg: Literal[T1, N], - fn: N[T1] => N[T2]) extends Literal[T2, N] { -} -case class BinaryLit[T1, T2, T3, N[_]](arg1: Literal[T1, N], arg2: Literal[T2, N], - fn: (N[T1], N[T2]) => N[T3]) extends Literal[T3, N] { -} - -object Literal { - /** - * This evaluates a literal formula back to what it represents - * being careful to handle diamonds by creating referentially - * equivalent structures (not just structurally equivalent) - */ - def evaluate[T, N[_]](lit: Literal[T, N]): N[T] = - evaluate(HMap.empty[({ type L[T] = Literal[T, N] })#L, N], lit)._2 - - // Memoized version of the above to handle diamonds - private def evaluate[T, N[_]](hm: HMap[({ type L[T] = Literal[T, N] })#L, N], lit: Literal[T, N]): (HMap[({ type L[T] = Literal[T, N] })#L, N], N[T]) = - hm.get(lit) match { - case Some(prod) => (hm, prod) - case None => - lit match { - case ConstLit(prod) => (hm + (lit -> prod), prod) - case UnaryLit(in, fn) => - val (h1, p1) = evaluate(hm, in) - val p2 = fn(p1) - (h1 + (lit -> p2), p2) - case BinaryLit(in1, in2, fn) => - val (h1, p1) = evaluate(hm, in1) - val (h2, p2) = evaluate(h1, in2) - val p3 = fn(p1, p2) - (h2 + (lit -> p3), p3) - } - } -} - -sealed trait ExpressionDag[N[_]] { self => - // Once we fix N above, we can make E[T] = Expr[T, N] - type E[t] = Expr[t, N] - type Lit[t] = Literal[t, N] - - /** - * These have package visibility to test - * the law that for all Expr, the node they - * evaluate to is unique - */ - protected[graph] def idToExp: HMap[Id, E] - protected def nodeToLiteral: GenFunction[N, Lit] - protected def roots: Set[Id[_]] - protected def nextId: Int - - private def copy(id2Exp: HMap[Id, E] = self.idToExp, - node2Literal: GenFunction[N, Lit] = self.nodeToLiteral, - gcroots: Set[Id[_]] = self.roots, - id: Int = self.nextId): ExpressionDag[N] = new ExpressionDag[N] { - def idToExp = id2Exp - def roots = gcroots - def nodeToLiteral = node2Literal - def nextId = id - } - - override def toString: String = - "ExpressionDag(idToExp = %s)".format(idToExp) - - // This is a cache of Id[T] => Option[N[T]] - private val idToN = - new HCache[Id, ({ type ON[T] = Option[N[T]] })#ON]() - private val nodeToId = - new HCache[N, ({ type OID[T] = Option[Id[T]] })#OID]() - - /** - * Add a GC root, or tail in the DAG, that can never be deleted - * currently, we only support a single root - */ - private def addRoot[_](id: Id[_]) = copy(gcroots = roots + id) - - /** - * Which ids are reachable from the roots - */ - private def reachableIds: Set[Id[_]] = { - // We actually don't care about the return type of the Set - // This is a constant function at the type level - type IdSet[t] = Set[Id[_]] - def expand(s: Set[Id[_]]): Set[Id[_]] = { - val partial = new GenPartial[HMap[Id, E]#Pair, IdSet] { - def apply[T] = { - case (id, Const(_)) if s(id) => s - case (id, Var(v)) if s(id) => s + v - case (id, Unary(id0, _)) if s(id) => s + id0 - case (id, Binary(id0, id1, _)) if s(id) => (s + id0) + id1 - } - } - // Note this Stream must always be non-empty as 
long as roots are - idToExp.collect[IdSet](partial) - .reduce(_ ++ _) - } - // call expand while we are still growing - def go(s: Set[Id[_]]): Set[Id[_]] = { - val step = expand(s) - if (step == s) s - else go(step) - } - go(roots) - } - - private def gc: ExpressionDag[N] = { - val goodIds = reachableIds - type BoolT[t] = Boolean - val toKeepI2E = idToExp.filter(new GenFunction[HMap[Id, E]#Pair, BoolT] { - def apply[T] = { idExp => goodIds(idExp._1) } - }) - copy(id2Exp = toKeepI2E) - } - - /** - * Apply the given rule to the given dag until - * the graph no longer changes. - */ - def apply(rule: Rule[N]): ExpressionDag[N] = { - // for some reason, scala can't optimize this with tailrec - var prev: ExpressionDag[N] = null - var curr: ExpressionDag[N] = this - while (!(curr eq prev)) { - prev = curr - curr = curr.applyOnce(rule) - } - curr - } - - protected def toExpr[T](n: N[T]): (ExpressionDag[N], Expr[T, N]) = { - val (dag, id) = ensure(n) - val exp = dag.idToExp(id) - (dag, exp) - } - - /** - * Convert a N[T] to a Literal[T, N] - */ - def toLiteral[T](n: N[T]): Literal[T, N] = nodeToLiteral.apply[T](n) - - /** - * apply the rule at the first place that satisfies - * it, and return from there. - */ - def applyOnce(rule: Rule[N]): ExpressionDag[N] = { - val getN = new GenPartial[HMap[Id, E]#Pair, HMap[Id, N]#Pair] { - def apply[U] = { - val fn = rule.apply[U](self) - - { - case (id, exp) if fn(exp.evaluate(idToExp)).isDefined => - // Sucks to have to call fn, twice, but oh well - (id, fn(exp.evaluate(idToExp)).get) - } - } - } - idToExp.collect[HMap[Id, N]#Pair](getN).headOption match { - case None => this - case Some(tup) => - // some type hand holding - def act[T](in: HMap[Id, N]#Pair[T]) = { - val (i, n) = in - val oldNode = evaluate(i) - val (dag, exp) = toExpr(n) - dag.copy(id2Exp = dag.idToExp + (i -> exp)) - } - // This cast should not be needed - act(tup.asInstanceOf[HMap[Id, N]#Pair[Any]]).gc - } - } - - // This is only called by ensure - private def addExp[T](node: N[T], exp: Expr[T, N]): (ExpressionDag[N], Id[T]) = { - val nodeId = Id[T](nextId) - (copy(id2Exp = idToExp + (nodeId -> exp), id = nextId + 1), nodeId) - } - - /** - * This finds the Id[T] in the current graph that is equivalent - * to the given N[T] - */ - def find[T](node: N[T]): Option[Id[T]] = nodeToId.getOrElseUpdate(node, { - val partial = new GenPartial[HMap[Id, E]#Pair, Id] { - def apply[T] = { case (thisId, expr) if node == expr.evaluate(idToExp) => thisId } - } - idToExp.collect(partial).headOption.asInstanceOf[Option[Id[T]]] - }) - - /** - * This throws if the node is missing, use find if this is not - * a logic error in your programming. With dependent types we could - * possibly get this to not compile if it could throw. - */ - def idOf[T](node: N[T]): Id[T] = - find(node) - .getOrElse(sys.error("could not get node: %s\n from %s".format(node, this))) - - /** - * ensure the given literal node is present in the Dag - * Note: it is important that at each moment, each node has - * at most one id in the graph. Put another way, for all - * Id[T] in the graph evaluate(id) is distinct. 
- */ - protected def ensure[T](node: N[T]): (ExpressionDag[N], Id[T]) = - find(node) match { - case Some(id) => (this, id) - case None => { - val lit: Lit[T] = toLiteral(node) - lit match { - case ConstLit(n) => - /** - * Since the code is not performance critical, but correctness critical, and we can't - * check this property with the typesystem easily, check it here - */ - assert(n == node, - "Equality or nodeToLiteral is incorrect: nodeToLit(%s) = ConstLit(%s)".format(node, n)) - addExp(node, Const(n)) - case UnaryLit(prev, fn) => - val (exp1, idprev) = ensure(prev.evaluate) - exp1.addExp(node, Unary(idprev, fn)) - case BinaryLit(n1, n2, fn) => - val (exp1, id1) = ensure(n1.evaluate) - val (exp2, id2) = exp1.ensure(n2.evaluate) - exp2.addExp(node, Binary(id1, id2, fn)) - } - } - } - - /** - * After applying rules to your Dag, use this method - * to get the original node type. - * Only call this on an Id[T] that was generated by - * this dag or a parent. - */ - def evaluate[T](id: Id[T]): N[T] = - evaluateOption(id).getOrElse(sys.error("Could not evaluate: %s\nin %s".format(id, this))) - - def evaluateOption[T](id: Id[T]): Option[N[T]] = - idToN.getOrElseUpdate(id, { - val partial = new GenPartial[HMap[Id, E]#Pair, N] { - def apply[T] = { case (thisId, expr) if (id == thisId) => expr.evaluate(idToExp) } - } - idToExp.collect(partial).headOption.asInstanceOf[Option[N[T]]] - }) - - /** - * Return the number of nodes that depend on the - * given Id, TODO we might want to cache these. - * We need to garbage collect nodes that are - * no longer reachable from the root - */ - def fanOut(id: Id[_]): Int = { - // We make a fake IntT[T] which is just Int - val partial = new GenPartial[E, ({ type IntT[T] = Int })#IntT] { - def apply[T] = { - case Var(id1) if (id1 == id) => 1 - case Unary(id1, fn) if (id1 == id) => 1 - case Binary(id1, id2, fn) if (id1 == id) && (id2 == id) => 2 - case Binary(id1, id2, fn) if (id1 == id) || (id2 == id) => 1 - case _ => 0 - } - } - idToExp.collectValues[({ type IntT[T] = Int })#IntT](partial).sum - } - - /** - * Returns 0 if the node is absent, which is true - * use .contains(n) to check for containment - */ - def fanOut(node: N[_]): Int = find(node).map(fanOut(_)).getOrElse(0) - def contains(node: N[_]): Boolean = find(node).isDefined -} - -object ExpressionDag { - private def empty[N[_]](n2l: GenFunction[N, ({ type L[t] = Literal[t, N] })#L]): ExpressionDag[N] = - new ExpressionDag[N] { - val idToExp = HMap.empty[Id, ({ type E[t] = Expr[t, N] })#E] - val nodeToLiteral = n2l - val roots = Set.empty[Id[_]] - val nextId = 0 - } - - /** - * This creates a new ExpressionDag rooted at the given tail node - */ - def apply[T, N[_]](n: N[T], - nodeToLit: GenFunction[N, ({ type L[t] = Literal[t, N] })#L]): (ExpressionDag[N], Id[T]) = { - val (dag, id) = empty(nodeToLit).ensure(n) - (dag.addRoot(id), id) - } - - /** - * This is the most useful function. Given a N[T] and a way to convert to Literal[T, N], - * apply the given rule until it no longer applies, and return the N[T] which is - * equivalent under the given rule - */ - def applyRule[T, N[_]](n: N[T], - nodeToLit: GenFunction[N, ({ type L[t] = Literal[t, N] })#L], - rule: Rule[N]): N[T] = { - val (dag, id) = apply(n, nodeToLit) - dag(rule).evaluate(id) - } -} - -/** - * This implements a simplification rule on ExpressionDags - */ -trait Rule[N[_]] { self => - /** - * If the given Id can be replaced with a simpler expression, - * return Some(expr) else None. 
- * - * If it is convenient, you might write a partial function - * and then call .lift to get the correct Function type - */ - def apply[T](on: ExpressionDag[N]): (N[T] => Option[N[T]]) - - // If the current rule cannot apply, then try the argument here - def orElse(that: Rule[N]): Rule[N] = new Rule[N] { - def apply[T](on: ExpressionDag[N]) = { n => - self.apply(on)(n).orElse(that.apply(on)(n)) - } - } -} - -/** - * Often a partial function is an easier way to express rules - */ -trait PartialRule[N[_]] extends Rule[N] { - final def apply[T](on: ExpressionDag[N]) = applyWhere[T](on).lift - def applyWhere[T](on: ExpressionDag[N]): PartialFunction[N[T], N[T]] -} - diff --git a/summingbird-core/src/main/scala/com/twitter/summingbird/graph/HMap.scala b/summingbird-core/src/main/scala/com/twitter/summingbird/graph/HMap.scala deleted file mode 100644 index 4926565cb..000000000 --- a/summingbird-core/src/main/scala/com/twitter/summingbird/graph/HMap.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.summingbird.graph - -/** - * This is a weak heterogenous map. It uses equals on the keys, - * so it is your responsibilty that if k: K[_] == k2: K[_] then - * the types are actually equal (either be careful or store a - * type identifier). 
- */ -sealed abstract class HMap[K[_], V[_]] { - type Pair[t] = (K[t], V[t]) - protected val map: Map[K[_], V[_]] - override def toString: String = - "H%s".format(map) - - override def equals(that: Any): Boolean = that match { - case null => false - case h: HMap[_, _] => map.equals(h.map) - case _ => false - } - override def hashCode = map.hashCode - - def +[T](kv: (K[T], V[T])): HMap[K, V] = - HMap.from[K, V](map + kv) - - def -(k: K[_]): HMap[K, V] = - HMap.from[K, V](map - k) - - def apply[T](id: K[T]): V[T] = get(id).get - - def contains[T](id: K[T]): Boolean = get(id).isDefined - - def filter(pred: GenFunction[Pair, ({ type BoolT[T] = Boolean })#BoolT]): HMap[K, V] = { - val filtered = map.asInstanceOf[Map[K[Any], V[Any]]].filter(pred.apply[Any]) - HMap.from[K, V](filtered.asInstanceOf[Map[K[_], V[_]]]) - } - - def get[T](id: K[T]): Option[V[T]] = - map.get(id).asInstanceOf[Option[V[T]]] - - def keysOf[T](v: V[T]): Set[K[T]] = map.collect { - case (k, w) if v == w => - k.asInstanceOf[K[T]] - }.toSet - - // go through all the keys, and find the first key that matches this - // function and apply - def updateFirst(p: GenPartial[K, V]): Option[(HMap[K, V], K[_])] = { - def collector[T]: PartialFunction[(K[T], V[T]), (K[T], V[T])] = { - val pf = p.apply[T] - - { - case (kv: (K[T], V[T])) if pf.isDefinedAt(kv._1) => - val v2 = pf(kv._1) - (kv._1, v2) - } - } - - map.asInstanceOf[Map[K[Any], V[Any]]].collectFirst(collector) - .map { kv => - (this + kv, kv._1) - } - } - - def collect[R[_]](p: GenPartial[Pair, R]): Stream[R[_]] = - map.toStream.asInstanceOf[Stream[(K[Any], V[Any])]].collect(p.apply) - - def collectValues[R[_]](p: GenPartial[V, R]): Stream[R[_]] = - map.values.toStream.asInstanceOf[Stream[V[Any]]].collect(p.apply) -} - -// This is a function that preserves the inner type -trait GenFunction[T[_], R[_]] { - def apply[U]: (T[U] => R[U]) -} - -trait GenPartial[T[_], R[_]] { - def apply[U]: PartialFunction[T[U], R[U]] -} - -object HMap { - def empty[K[_], V[_]]: HMap[K, V] = from[K, V](Map.empty[K[_], V[_]]) - private def from[K[_], V[_]](m: Map[K[_], V[_]]): HMap[K, V] = - new HMap[K, V] { override val map = m } -} - -/** - * This is a useful cache for memoizing heterogenously types functions - */ -class HCache[K[_], V[_]]() { - private var hmap: HMap[K, V] = HMap.empty[K, V] - - /** - * Get snapshot of the current state - */ - def snapshot: HMap[K, V] = hmap - - def getOrElseUpdate[T](k: K[T], v: => V[T]): V[T] = - hmap.get(k) match { - case Some(exists) => exists - case None => - val res = v - hmap = hmap + (k -> res) - res - } -} - diff --git a/summingbird-core/src/main/scala/com/twitter/summingbird/memory/ConcurrentMemory.scala b/summingbird-core/src/main/scala/com/twitter/summingbird/memory/ConcurrentMemory.scala index 7b927a37f..daaf0915d 100644 --- a/summingbird-core/src/main/scala/com/twitter/summingbird/memory/ConcurrentMemory.scala +++ b/summingbird-core/src/main/scala/com/twitter/summingbird/memory/ConcurrentMemory.scala @@ -16,14 +16,12 @@ package com.twitter.summingbird.memory -import com.twitter.summingbird.graph._ - import com.twitter.summingbird.planner.DagOptimizer import com.twitter.algebird.{ Monoid, Semigroup } import com.twitter.summingbird._ import com.twitter.summingbird.option.JobId -import scala.collection.mutable.Buffer +import com.stripe.dagon.HMap import scala.concurrent.{ ExecutionContext, Future } import java.util.concurrent.{ BlockingQueue, LinkedBlockingQueue, ConcurrentHashMap } @@ -134,7 +132,10 @@ object PhysicalNode { } class 
ConcurrentMemory(implicit jobID: JobId = JobId("default.concurrent.memory.jobId")) - extends Platform[ConcurrentMemory] with DagOptimizer[ConcurrentMemory] { + extends Platform[ConcurrentMemory] { + + private[this] val optimizer = DagOptimizer[ConcurrentMemory] + import optimizer._ type Source[T] = TraversableOnce[T] type Store[K, V] = ConcurrentHashMap[K, V] @@ -144,7 +145,7 @@ class ConcurrentMemory(implicit jobID: JobId = JobId("default.concurrent.memory. import PhysicalNode._ - type ProdCons[T] = Prod[Any] + private type ProdCons[T] = Prod[Any] def counter(group: Group, name: Name): Option[Long] = MemoryStatProvider.getCountersForJob(jobID).flatMap { _.get(group.getString + "/" + name.getString).map { _.get } } diff --git a/summingbird-core/src/main/scala/com/twitter/summingbird/memory/Memory.scala b/summingbird-core/src/main/scala/com/twitter/summingbird/memory/Memory.scala index 43be42c59..d1abdae14 100644 --- a/summingbird-core/src/main/scala/com/twitter/summingbird/memory/Memory.scala +++ b/summingbird-core/src/main/scala/com/twitter/summingbird/memory/Memory.scala @@ -16,8 +16,8 @@ package com.twitter.summingbird.memory +import com.stripe.dagon.HMap import com.twitter.summingbird._ -import com.twitter.summingbird.graph.HMap import com.twitter.summingbird.option.JobId import com.twitter.summingbird.planner.DagOptimizer import collection.mutable.{ Map => MutableMap } diff --git a/summingbird-core/src/main/scala/com/twitter/summingbird/planner/ComposedFunctions.scala b/summingbird-core/src/main/scala/com/twitter/summingbird/planner/ComposedFunctions.scala index 8b1b63f52..d1c099fc0 100644 --- a/summingbird-core/src/main/scala/com/twitter/summingbird/planner/ComposedFunctions.scala +++ b/summingbird-core/src/main/scala/com/twitter/summingbird/planner/ComposedFunctions.scala @@ -111,3 +111,19 @@ case class MergeResults[A, B](left: A => TraversableOnce[B], right: A => Travers def apply(a: A) = (left(a).toIterator) ++ (right(a).toIterator) def irreducibles = IrreducibleContainer.flatten(left, right) } + +/** + * flatMapping keys can be done by a normal flatmap + */ +case class KeyFlatMapFunction[K1, K2, V](fn: K1 => TraversableOnce[K2]) extends (Tuple2[K1, V] => TraversableOnce[(K2, V)]) with IrreducibleContainer { + def apply(kv: (K1, V)) = fn(kv._1).map((_, kv._2)) + def irreducibles = IrreducibleContainer.flatten(fn) +} + +/** + * flatMapping values can be done by a normal flatmap + */ +case class ValueFlatMapFunction[K, V1, V2](fn: V1 => TraversableOnce[V2]) extends (Tuple2[K, V1] => TraversableOnce[(K, V2)]) with IrreducibleContainer { + def apply(kv: (K, V1)) = fn(kv._2).map((kv._1, _)) + def irreducibles = IrreducibleContainer.flatten(fn) +} diff --git a/summingbird-core/src/main/scala/com/twitter/summingbird/planner/DagOptimizer.scala b/summingbird-core/src/main/scala/com/twitter/summingbird/planner/DagOptimizer.scala index 0835f58cf..4efc45f2d 100644 --- a/summingbird-core/src/main/scala/com/twitter/summingbird/planner/DagOptimizer.scala +++ b/summingbird-core/src/main/scala/com/twitter/summingbird/planner/DagOptimizer.scala @@ -1,26 +1,11 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - package com.twitter.summingbird.planner import com.twitter.summingbird._ -import com.twitter.summingbird.graph._ import com.twitter.algebird.Semigroup +import com.stripe.dagon.{Dag => DagonDag, _} +import java.io.Serializable -trait DagOptimizer[P <: Platform[P]] { +class DagOptimizer[P <: Platform[P]] extends Serializable { type Prod[T] = Producer[P, T] @@ -29,188 +14,108 @@ trait DagOptimizer[P <: Platform[P]] { * in converting from an AlsoProducer to a Literal[T, Prod] below, it is * not actually dangerous because we always use it in a safe position. */ - protected def mkAlso[T, U]: (Prod[T], Prod[U]) => Prod[U] = { - (left, right) => AlsoProducer(left.asInstanceOf[TailProducer[P, T]], right) + protected def mkAlso[T, U]: (Prod[T], Prod[U]) => Prod[U] = { (left, right) => + AlsoProducer(left.asInstanceOf[TailProducer[P, T]], right) } - protected def mkAlsoTail[T, U]: (Prod[T], Prod[U]) => Prod[U] = { - (left, right) => new AlsoTailProducer(left.asInstanceOf[TailProducer[P, T]], right.asInstanceOf[TailProducer[P, U]]) + protected def mkAlsoTail[T, U]: (Prod[T], Prod[U]) => Prod[U] = { (left, right) => + new AlsoTailProducer(left.asInstanceOf[TailProducer[P, T]], + right.asInstanceOf[TailProducer[P, U]]) } - protected def mkMerge[T]: (Prod[T], Prod[T]) => Prod[T] = { - (left, right) => MergedProducer(left, right) + protected def mkMerge[T]: (Prod[T], Prod[T]) => Prod[T] = { (left, right) => + MergedProducer(left, right) } - protected def mkNamed[T](name: String): (Prod[T] => Prod[T]) = { - prod => NamedProducer(prod, name) + protected def mkNamed[T](name: String): (Prod[T] => Prod[T]) = { prod => + NamedProducer(prod, name) } - protected def mkTPNamed[T](name: String): (Prod[T] => Prod[T]) = { - prod => new TPNamedProducer(prod.asInstanceOf[TailProducer[P, T]], name) + protected def mkTPNamed[T](name: String): (Prod[T] => Prod[T]) = { prod => + new TPNamedProducer(prod.asInstanceOf[TailProducer[P, T]], name) } - protected def mkIdentKey[K, V]: (Prod[(K, V)] => Prod[(K, V)]) = { - prod => IdentityKeyedProducer(prod) + protected def mkIdentKey[K, V]: (Prod[(K, V)] => Prod[(K, V)]) = { prod => + IdentityKeyedProducer(prod) } - protected def mkOptMap[T, U](fn: T => Option[U]): (Prod[T] => Prod[U]) = { - prod => OptionMappedProducer(prod, fn) + protected def mkOptMap[T, U](fn: T => Option[U]): (Prod[T] => Prod[U]) = { prod => + OptionMappedProducer(prod, fn) } - protected def mkFlatMapped[T, U](fn: T => TraversableOnce[U]): (Prod[T] => Prod[U]) = { - prod => FlatMappedProducer(prod, fn) + protected def mkFlatMapped[T, U](fn: T => TraversableOnce[U]): (Prod[T] => Prod[U]) = { prod => + FlatMappedProducer(prod, fn) } protected def mkKeyFM[T, U, V](fn: T => TraversableOnce[U]): (Prod[(T, V)] => Prod[(U, V)]) = { - prod => KeyFlatMappedProducer(prod, fn) + prod => + KeyFlatMappedProducer(prod, fn) } protected def mkValueFM[K, U, V](fn: U => TraversableOnce[V]): (Prod[(K, U)] => Prod[(K, V)]) = { - prod => ValueFlatMappedProducer(prod, fn) + prod => + ValueFlatMappedProducer(prod, fn) } - protected def mkWritten[T, U >: T](sink: P#Sink[U]): (Prod[T] => Prod[T]) = { - prod => 
WrittenProducer[P, T, U](prod, sink) + protected def mkWritten[T, U >: T](sink: P#Sink[U]): (Prod[T] => Prod[T]) = { prod => + WrittenProducer[P, T, U](prod, sink) } - protected def mkSrv[K, T, V](serv: P#Service[K, V]): (Prod[(K, T)] => Prod[(K, (T, Option[V]))]) = { - prod => LeftJoinedProducer(prod, serv) + protected def mkSrv[K, T, V]( + serv: P#Service[K, V]): (Prod[(K, T)] => Prod[(K, (T, Option[V]))]) = { prod => + LeftJoinedProducer(prod, serv) } - protected def mkSum[K, V](store: P#Store[K, V], sg: Semigroup[V]): (Prod[(K, V)] => Prod[(K, (Option[V], V))]) = { - prod => Summer(prod, store, sg) + protected def mkSum[K, V](store: P#Store[K, V], + sg: Semigroup[V]): (Prod[(K, V)] => Prod[(K, (Option[V], V))]) = { + prod => + Summer(prod, store, sg) } - type LitProd[T] = Literal[T, Prod] + type LitProd[T] = Literal[Prod, T] /** - * Convert a Producer graph into a Literal in the Dag rewriter + * Convert a Producer graph into a Literal in the DagonDag rewriter * This is where the tedious work comes in. */ - def toLiteral[T](prod: Producer[P, T]): Literal[T, Prod] = - toLiteral(HMap.empty[Prod, LitProd], prod)._2 + def toLiteral: FunctionK[Prod, LitProd] = + Memoize.functionK[Prod, LitProd](new Memoize.RecursiveK[Prod, LitProd] { + import Literal._ - protected def toLiteral[T](hm: HMap[Prod, LitProd], prod: Producer[P, T]): (HMap[Prod, LitProd], LitProd[T]) = { - // These get typed over and over below - type N[t] = Prod[t] - type M = HMap[Prod, LitProd] - type L[t] = Literal[t, N] - - /** - * All this shit is due to the scala compiler's inability to see the types - * in case matches. I can see this is unneeded, why can't scala? - */ - - def source[T1 <: T](t: Source[P, T1]): (M, L[T]) = { - val lit = ConstLit[T, N](t) - (hm + (t -> lit), lit) - } - def also[R](a: AlsoProducer[P, R, T]): (M, L[T]) = { - val (h1, l1) = toLiteral(hm, a.ensure) - val (h2, l2) = toLiteral(h1, a.result) - val lit = BinaryLit[R, T, T, N](l1, l2, mkAlso) - (h2 + (a -> lit), lit) - } - def alsoTail[R](a: AlsoTailProducer[P, R, T]): (M, L[T]) = { - val (h1, l1) = toLiteral(hm, a.ensure) - val (h2, l2) = toLiteral(h1, a.result) - val lit = BinaryLit[R, T, T, N](l1, l2, mkAlsoTail) - (h2 + (a -> lit), lit) - } - def merge(m: MergedProducer[P, T]): (M, L[T]) = { - val (h1, l1) = toLiteral(hm, m.left) - val (h2, l2) = toLiteral(h1, m.right) - val lit = BinaryLit[T, T, T, N](l1, l2, mkMerge) - (h2 + (m -> lit), lit) - } - def named(n: NamedProducer[P, T]): (M, L[T]) = { - val (h1, l1) = toLiteral(hm, n.producer) - val lit = UnaryLit[T, T, N](l1, mkNamed(n.id)) - (h1 + (n -> lit), lit) - } - def namedTP(n: TPNamedProducer[P, T]): (M, L[T]) = { - val (h1, l1) = toLiteral(hm, n.producer) - val lit = UnaryLit[T, T, N](l1, mkTPNamed(n.id)) - (h1 + (n -> lit), lit) - } - def ikp[K, V](ik: IdentityKeyedProducer[P, K, V]): (M, L[(K, V)]) = { - val (h1, l1) = toLiteral(hm, ik.producer) - val lit = UnaryLit[(K, V), (K, V), N](l1, mkIdentKey) - (h1 + (ik -> lit), lit) - } - def optm[T1](optm: OptionMappedProducer[P, T1, T]): (M, L[T]) = { - val (h1, l1) = toLiteral(hm, optm.producer) - val lit = UnaryLit[T1, T, N](l1, mkOptMap(optm.fn)) - (h1 + (optm -> lit), lit) - } - def flm[T1](fm: FlatMappedProducer[P, T1, T]): (M, L[T]) = { - val (h1, l1) = toLiteral(hm, fm.producer) - val lit = UnaryLit[T1, T, N](l1, mkFlatMapped(fm.fn)) - (h1 + (fm -> lit), lit) - } - def kfm[K, V, K2](kf: KeyFlatMappedProducer[P, K, V, K2]): (M, L[(K2, V)]) = { - val (h1, l1) = toLiteral(hm, kf.producer) - val lit = UnaryLit[(K, V), (K2, V), N](l1, 
mkKeyFM(kf.fn)) - (h1 + (kf -> lit), lit) - } - def vfm[K, V, V2](kf: ValueFlatMappedProducer[P, K, V, V2]): (M, L[(K, V2)]) = { - val (h1, l1) = toLiteral(hm, kf.producer) - val lit = UnaryLit[(K, V), (K, V2), N](l1, mkValueFM(kf.fn)) - (h1 + (kf -> lit), lit) - } - def writer[T1 <: T, U >: T1](w: WrittenProducer[P, T1, U]): (M, L[T]) = { - val (h1, l1) = toLiteral(hm, w.producer) - val lit = UnaryLit[T1, T, N](l1, mkWritten[T1, U](w.sink)) - (h1 + (w -> lit), lit) - } - def joined[K, V, U](join: LeftJoinedProducer[P, K, V, U]): (M, L[(K, (V, Option[U]))]) = { - val (h1, l1) = toLiteral(hm, join.left) - val lit = UnaryLit[(K, V), (K, (V, Option[U])), N](l1, mkSrv(join.joined)) - (h1 + (join -> lit), lit) - } - def summer[K, V](s: Summer[P, K, V]): (M, L[(K, (Option[V], V))]) = { - val (h1, l1) = toLiteral(hm, s.producer) - val lit = UnaryLit[(K, V), (K, (Option[V], V)), N](l1, mkSum(s.store, s.semigroup)) - (h1 + (s -> lit), lit) - } - - // the keyed have to be cast because - // all the keyed get inferred types (Any, Any), not - // (K, V) <: T, which is what they are - def cast[K, V](tup: (M, L[(K, V)])): (M, L[T]) = - tup.asInstanceOf[(M, L[T])] - - hm.get(prod) match { - case Some(lit) => (hm, lit) - case None => - prod match { - case s @ Source(_) => source(s) - case a: AlsoTailProducer[_, _, _] => alsoTail(a.asInstanceOf[AlsoTailProducer[P, _, T]]) - case a @ AlsoProducer(_, _) => also(a) - case m @ MergedProducer(l, r) => merge(m) - case n: TPNamedProducer[_, _] => namedTP(n.asInstanceOf[TPNamedProducer[P, T]]) - case n @ NamedProducer(producer, name) => named(n) - case w @ WrittenProducer(producer, sink) => writer(w) - case fm @ FlatMappedProducer(producer, fn) => flm(fm) - case om @ OptionMappedProducer(producer, fn) => optm(om) - // These casts can't fail due to the pattern match, - // but I can't convince scala of this without the cast. 
- case ik @ IdentityKeyedProducer(producer) => cast(ikp(ik)) - case kf @ KeyFlatMappedProducer(producer, fn) => cast(kfm(kf)) - case vf @ ValueFlatMappedProducer(producer, fn) => cast(vfm(vf)) - case j @ LeftJoinedProducer(producer, srv) => cast(joined(j)) - case s @ Summer(producer, store, sg) => cast(summer(s)) - } - } - } + def toFunction[T] = { + case (s @ Source(_), _) => Const[Prod, T](s) + case (a: AlsoTailProducer[P, a, T], rec) => + Binary[Prod, a, T, T](rec(a.ensure), rec(a.result), mkAlsoTail) + case (AlsoProducer(ensure, result), rec) => + Binary(rec(ensure), rec(result), mkAlso[Any, T]) + case (MergedProducer(l, r), rec) => + Binary(rec(l), rec(r), mkMerge) + case (n: TPNamedProducer[P, T], rec) => + Unary(rec(n.producer), mkTPNamed[T](n.id)) + case (NamedProducer(producer, name), rec) => + Unary(rec(producer), mkNamed(name)) + case (WrittenProducer(producer, sink), rec) => + Unary(rec(producer), mkWritten(sink)) + case (FlatMappedProducer(producer, fn), rec) => + Unary(rec(producer), mkFlatMapped(fn)) + case (OptionMappedProducer(producer, fn), rec) => + Unary(rec(producer), mkOptMap(fn)) + case (ik: IdentityKeyedProducer[P, k, v], rec) => + Unary(rec(ik.producer), mkIdentKey[k, v]) + case (kf: KeyFlatMappedProducer[P, t, u, v], rec) => + Unary(rec(kf.producer), mkKeyFM[t, v, u](kf.fn)) + case (vf: ValueFlatMappedProducer[P, k, v, u], rec) => + Unary(rec(vf.producer), mkValueFM[k, v, u](vf.fn)) + case (j: LeftJoinedProducer[P, k, v, u], rec) => + Unary(rec(j.left), mkSrv[k, v, u](j.joined)) + case (s @ Summer(producer, store, sg), rec) => + Unary(rec(producer), mkSum(store, sg)) + } + }) /** - * Create an ExpressionDag for the given node. This should be the + * Create an DagonDag for the given node. This should be the * final tail of the graph. You can apply optimizations on this - * Dag and then use the Id returned to evaluate it back to an + * DagonDag and then use the Id returned to evaluate it back to an * optimized producer */ - def expressionDag[T](p: Producer[P, T]): (ExpressionDag[Prod], Id[T]) = { - val prodToLit = new GenFunction[Prod, LitProd] { - def apply[T] = { p => toLiteral(p) } - } - ExpressionDag[T, Prod](p, prodToLit) - } + def dag[T](p: Producer[P, T]): (DagonDag[Prod], Id[T]) = + DagonDag[T, Prod](p, toLiteral) /** * Optimize the given producer according to the rule */ def optimize[T](p: Producer[P, T], rule: Rule[Prod]): Producer[P, T] = { - val (dag, id) = expressionDag(p) - dag(rule).evaluate(id) + val (d, id) = dag(p) + d(rule).evaluate(id) } /* @@ -224,7 +129,7 @@ trait DagOptimizer[P <: Platform[P]] { * the AST that we generate and optimize along the way */ object RemoveNames extends PartialRule[Prod] { - def applyWhere[T](on: ExpressionDag[Prod]) = { + def applyWhere[T](on: DagonDag[Prod]) = { case NamedProducer(p, _) => p } } @@ -234,9 +139,8 @@ trait DagOptimizer[P <: Platform[P]] { * types, they have no meaning at runtime. 
*/ object RemoveIdentityKeyed extends PartialRule[Prod] { - def applyWhere[T](on: ExpressionDag[Prod]) = { - // scala can't see that (K, V) <: T - case IdentityKeyedProducer(p) => p.asInstanceOf[Prod[T]] + def applyWhere[T](on: DagonDag[Prod]) = { + case IdentityKeyedProducer(p) => p } } @@ -244,19 +148,53 @@ trait DagOptimizer[P <: Platform[P]] { * a.flatMap(fn).flatMap(fn2) can be written as a.flatMap(compose(fn, fn2)) */ object FlatMapFusion extends PartialRule[Prod] { - def applyWhere[T](on: ExpressionDag[Prod]) = { + def applyWhere[T](on: DagonDag[Prod]) = { //Can't fuse flatMaps when on fanout - case FlatMappedProducer(in1 @ FlatMappedProducer(in0, fn0), fn1) if (on.fanOut(in1) == 1) => + case FlatMappedProducer(in1 @ FlatMappedProducer(in0, fn0), fn1) => FlatMappedProducer(in0, ComposedFlatMap(fn0, fn1)) } } - // a.optionMap(b).optionMap(c) == a.optionMap(compose(b, c)) - object OptionMapFusion extends PartialRule[Prod] { - def applyWhere[T](on: ExpressionDag[Prod]) = { + /** + * a.flatMapKeys(fn).flatMapKeys(fn2) can be written as a.flatMapKeys(compose(fn, fn2)) + */ + object FlatMapKeyFusion extends PartialRule[Prod] { + def applyWhere[T](on: DagonDag[Prod]) = { //Can't fuse flatMaps when on fanout - case OptionMappedProducer(in1 @ OptionMappedProducer(in0, fn0), fn1) if (on.fanOut(in1) == 1) => - OptionMappedProducer(in0, ComposedOptionMap(fn0, fn1)) + case KeyFlatMappedProducer(in1 @ KeyFlatMappedProducer(in0, fn0), fn1) if (on.fanOut(in1) == 1) => + // we know that (K, V) <: T due to the case match, but scala can't see it + def cast[K, V](p: Prod[(K, V)]): Prod[T] = p.asInstanceOf[Prod[T]] + cast(KeyFlatMappedProducer(in0, ComposedFlatMap(fn0, fn1))) + } + } + + /** + * a.flatMapValues(fn).flatMapValues(fn2) can be written as a.flatMapValues(compose(fn, fn2)) + */ + object FlatMapValuesFusion extends PartialRule[Prod] { + def applyWhere[T](on: DagonDag[Prod]) = { + //Can't fuse flatMaps when on fanout + case ValueFlatMappedProducer(in1 @ ValueFlatMappedProducer(in0, fn0), fn1) if (on.fanOut(in1) == 1) => + // we know that (K, V) <: T due to the case match, but scala can't see it + def cast[K, V](p: Prod[(K, V)]): Prod[T] = p.asInstanceOf[Prod[T]] + cast(ValueFlatMappedProducer(in0, ComposedFlatMap(fn0, fn1))) + } + } + + // a.optionMap(b).optionMap(c) == a.optionMap(compose(b, c)) + object OptionMapFusion extends Rule[Prod] { + def apply[T](on: DagonDag[Prod]) = { + case OptionMappedProducer(in1 @ OptionMappedProducer(in0, fn0), fn1) + if (in0.isInstanceOf[Source[_, _]]) => + if (on.fanOut(in1) == 1) { + // only merge options up if we can't merge with a source. 
Don't destroy the ability to merge + // with the source + Some(OptionMappedProducer(in0, ComposedOptionMap(fn0, fn1))) + } else None + case OptionMappedProducer(in1 @ OptionMappedProducer(in0, fn0), fn1) => + // otherwise always merge since the main cost is serialization + node overhead + Some(OptionMappedProducer(in0, ComposedOptionMap(fn0, fn1))) + case _ => None } } @@ -265,34 +203,41 @@ trait DagOptimizer[P <: Platform[P]] { * you can use this rule */ object OptionToFlatMap extends PartialRule[Prod] { - def applyWhere[T](on: ExpressionDag[Prod]) = { - //Can't fuse flatMaps when on fanout + def applyWhere[T](on: DagonDag[Prod]) = { case OptionMappedProducer(in, fn) => in.flatMap(OptionToFlat(fn)) } } + + object OptionThenFlatFusion extends PartialRule[Prod] { + def applyWhere[T](on: DagonDag[Prod]) = { + // don't mess with option-map before a source + case FlatMappedProducer(in1 @ OptionMappedProducer(in0, fn0), fn1) + if (!in0.isInstanceOf[Source[_, _]]) => + in0.flatMap(planner.ComposedOptionFlat(fn0, fn1)) + } + } + /** * If you can't optimize KeyFlatMaps, use this */ object KeyFlatMapToFlatMap extends PartialRule[Prod] { - def applyWhere[T](on: ExpressionDag[Prod]) = { - //Can't fuse flatMaps when on fanout - // TODO: we need to case class here to not lose the irreducible which may be named + def applyWhere[T](on: DagonDag[Prod]) = { case KeyFlatMappedProducer(in, fn) => // we know that (K, V) <: T due to the case match, but scala can't see it def cast[K, V](p: Prod[(K, V)]): Prod[T] = p.asInstanceOf[Prod[T]] - cast(in.flatMap { case (k, v) => fn(k).map((_, v)) }) + cast(in.flatMap(KeyFlatMapFunction(fn))) } } + /** * If you can't optimize ValueFlatMaps, use this */ object ValueFlatMapToFlatMap extends PartialRule[Prod] { - def applyWhere[T](on: ExpressionDag[Prod]) = { - // TODO: we need to case class here to not lose the irreducible which may be named + def applyWhere[T](on: DagonDag[Prod]) = { case ValueFlatMappedProducer(in, fn) => // we know that (K, V) <: T due to the case match, but scala can't see it def cast[K, V](p: Prod[(K, V)]): Prod[T] = p.asInstanceOf[Prod[T]] - cast(in.flatMap { case (k, v) => fn(v).map((k, _)) }) + cast(in.flatMap(ValueFlatMapFunction(fn))) } } @@ -302,11 +247,12 @@ trait DagOptimizer[P <: Platform[P]] { * On the other direction, you might not want to run optionMap with flatMap since some * platforms (storm) can't easily control source parallelism, so we don't want to push * big expansions up to sources + * + * Since we want to minimize the number of nodes, we always perform this optimization. 
*/ object FlatThenOptionFusion extends PartialRule[Prod] { - def applyWhere[T](on: ExpressionDag[Prod]) = { - //Can't fuse flatMaps when on fanout - case OptionMappedProducer(in1 @ FlatMappedProducer(in0, fn0), fn1) if (on.fanOut(in1) == 1) => + def applyWhere[T](on: DagonDag[Prod]) = { + case OptionMappedProducer(in1 @ FlatMappedProducer(in0, fn0), fn1) => FlatMappedProducer(in0, ComposedFlatMap(fn0, OptionToFlat(fn1))) } } @@ -315,10 +261,9 @@ trait DagOptimizer[P <: Platform[P]] { * (a.flatMap(f1) ++ a.flatMap(f2)) == a.flatMap { i => f1(i) ++ f2(i) } */ object DiamondToFlatMap extends PartialRule[Prod] { - def applyWhere[T](on: ExpressionDag[Prod]) = { - //Can't fuse flatMaps when on fanout + def applyWhere[T](on: DagonDag[Prod]) = { case MergedProducer(left @ FlatMappedProducer(inleft, fnleft), - right @ FlatMappedProducer(inright, fnright)) if (inleft == inright) && (on.fanOut(left) == 1) && (on.fanOut(right) == 1) => + right @ FlatMappedProducer(inright, fnright)) if (inleft == inright) => FlatMappedProducer(inleft, MergeResults(fnleft, fnright)) } } @@ -331,10 +276,10 @@ trait DagOptimizer[P <: Platform[P]] { */ object MergePullUp extends PartialRule[Prod] { //Can't do this operation if the merge fans out - def applyWhere[T](on: ExpressionDag[Prod]) = { - case OptionMappedProducer(m @ MergedProducer(a, b), fn) if (on.fanOut(m) == 1) => + def applyWhere[T](on: DagonDag[Prod]) = { + case OptionMappedProducer(m @ MergedProducer(a, b), fn) => (a.optionMap(fn)) ++ (b.optionMap(fn)) - case FlatMappedProducer(m @ MergedProducer(a, b), fn) if (on.fanOut(m) == 1) => + case FlatMappedProducer(m @ MergedProducer(a, b), fn) => (a.flatMap(fn)) ++ (b.flatMap(fn)) } } @@ -347,31 +292,52 @@ trait DagOptimizer[P <: Platform[P]] { * AlsoProducer(tail, fn(r)) */ object AlsoPullUp extends Rule[Prod] { - def apply[T](on: ExpressionDag[Prod]) = { + import Literal.Unary + + def apply[T](on: DagonDag[Prod]) = { case a @ AlsoProducer(_, _) => None // If we are already an also, we are done case MergedProducer(AlsoProducer(tail, l), r) => Some(AlsoProducer(tail, l ++ r)) case MergedProducer(l, AlsoProducer(tail, r)) => Some(AlsoProducer(tail, l ++ r)) - case node => on.toLiteral(node) match { - // There are a lot of unary operators, use the literal graph here: - // note that this cannot be an Also, due to the first case - case UnaryLit(alsoLit, fn) => - alsoLit.evaluate match { - case AlsoProducer(tail, rest) => - fn(rest) match { - case rightTail: TailProducer[_, _] => - // The type of the result must be T, but scala - // can't see this - val typedTail = rightTail.asInstanceOf[TailProducer[P, T]] - Some(new AlsoTailProducer(tail, typedTail)) - case nonTail => Some(AlsoProducer(tail, nonTail)) - } - case _ => None - } - case _ => None - } + case node => + on.toLiteral(node) match { + // There are a lot of unary operators, use the literal graph here: + // note that this cannot be an Also, due to the first case + case Unary(alsoLit, fn) => + alsoLit.evaluate match { + case AlsoProducer(tail, rest) => + fn(rest) match { + case rightTail: TailProducer[_, _] => + // The type of the result must be T, but scala + // can't see this + val typedTail = rightTail.asInstanceOf[TailProducer[P, T]] + Some(new AlsoTailProducer(tail, typedTail)) + case nonTail => Some(AlsoProducer(tail, nonTail)) + } + case _ => None + } + case _ => None + } } } + + /** + * We create a lot of merges followed by maps and merges + * which summingbird online does not deal with well. 
+ * Here we optimize those away to reduce the number of nodes + */ + val standardRule = RemoveIdentityKeyed + .orElse(MergePullUp) + .orElse(OptionMapFusion) + .orElse(FlatMapFusion) + .orElse(FlatMapKeyFusion) + .orElse(FlatMapValuesFusion) + .orElse(FlatThenOptionFusion) + .orElse(OptionThenFlatFusion) + .orElse(DiamondToFlatMap) } +object DagOptimizer { + def apply[P <: Platform[P]]: DagOptimizer[P] = new DagOptimizer[P] +} diff --git a/summingbird-core/src/test/scala/com/twitter/summingbird/graph/ExpressionDagTests.scala b/summingbird-core/src/test/scala/com/twitter/summingbird/graph/ExpressionDagTests.scala deleted file mode 100644 index 2a32f0dbf..000000000 --- a/summingbird-core/src/test/scala/com/twitter/summingbird/graph/ExpressionDagTests.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.summingbird.graph - -import org.scalacheck.Prop._ -import org.scalacheck.{ Gen, Properties } - -object ExpressionDagTests extends Properties("ExpressionDag") { - /* - * Here we test with a simple algebra optimizer - */ - - sealed trait Formula[T] { // we actually will ignore T - def evaluate: Int - def closure: Set[Formula[T]] - } - case class Constant[T](override val evaluate: Int) extends Formula[T] { - def closure = Set(this) - } - case class Inc[T](in: Formula[T], by: Int) extends Formula[T] { - def evaluate = in.evaluate + by - def closure = in.closure + this - } - case class Sum[T](left: Formula[T], right: Formula[T]) extends Formula[T] { - def evaluate = left.evaluate + right.evaluate - def closure = (left.closure ++ right.closure) + this - } - case class Product[T](left: Formula[T], right: Formula[T]) extends Formula[T] { - def evaluate = left.evaluate * right.evaluate - def closure = (left.closure ++ right.closure) + this - } - - def genForm: Gen[Formula[Int]] = Gen.frequency((1, genProd), - (1, genSum), - (4, genInc), - (4, genConst)) - - def genConst: Gen[Formula[Int]] = Gen.chooseNum(Int.MinValue, Int.MaxValue).map(Constant(_)) - def genInc: Gen[Formula[Int]] = for { - by <- Gen.chooseNum(Int.MinValue, Int.MaxValue) - f <- Gen.lzy(genForm) - } yield Inc(f, by) - - def genSum: Gen[Formula[Int]] = for { - left <- Gen.lzy(genForm) - // We have to make dags, so select from the closure of left sometimes - right <- Gen.oneOf(genForm, Gen.oneOf(left.closure.toSeq)) - } yield Sum(left, right) - def genProd: Gen[Formula[Int]] = for { - left <- Gen.lzy(genForm) - // We have to make dags, so select from the closure of left sometimes - right <- Gen.oneOf(genForm, Gen.oneOf(left.closure.toSeq)) - } yield Product(left, right) - - type L[T] = Literal[T, Formula] - - /** - * Here we convert our dag nodes into Literal[Formula, T] - */ - def toLiteral = new GenFunction[Formula, L] { - def apply[T] = { (form: Formula[T]) => - def recurse[T2](memo: HMap[Formula, L], f: Formula[T2]): (HMap[Formula, L], L[T2]) = memo.get(f) match { - case Some(l) => (memo, l) - case None => f match { - case c @ Constant(_) => - def 
makeLit[T1](c: Constant[T1]) = { - val lit: L[T1] = ConstLit(c) - (memo + (c -> lit), lit) - } - makeLit(c) - case inc @ Inc(_, _) => - def makeLit[T1](i: Inc[T1]) = { - val (m1, f1) = recurse(memo, i.in) - val lit = UnaryLit(f1, { f: Formula[T1] => Inc(f, i.by) }) - (m1 + (i -> lit), lit) - } - makeLit(inc) - case sum @ Sum(_, _) => - def makeLit[T1](s: Sum[T1]) = { - val (m1, fl) = recurse(memo, s.left) - val (m2, fr) = recurse(m1, s.right) - val lit = BinaryLit(fl, fr, { (f: Formula[T1], g: Formula[T1]) => Sum(f, g) }) - (m2 + (s -> lit), lit) - } - makeLit(sum) - case prod @ Product(_, _) => - def makeLit[T1](p: Product[T1]) = { - val (m1, fl) = recurse(memo, p.left) - val (m2, fr) = recurse(m1, p.right) - val lit = BinaryLit(fl, fr, { (f: Formula[T1], g: Formula[T1]) => Product(f, g) }) - (m2 + (p -> lit), lit) - } - makeLit(prod) - } - } - recurse(HMap.empty[Formula, L], form)._2 - } - } - - /** - * Inc(Inc(a, b), c) = Inc(a, b + c) - */ - object CombineInc extends Rule[Formula] { - def apply[T](on: ExpressionDag[Formula]) = { - case Inc(i @ Inc(a, b), c) if on.fanOut(i) == 1 => Some(Inc(a, b + c)) - case _ => None - } - } - - object RemoveInc extends PartialRule[Formula] { - def applyWhere[T](on: ExpressionDag[Formula]) = { - case Inc(f, by) => Sum(f, Constant(by)) - } - } - - //Check the Node[T] <=> Id[T] is an Injection for all nodes reachable from the root - - property("toLiteral/Literal.evaluate is a bijection") = forAll(genForm) { form => - toLiteral.apply(form).evaluate == form - } - - property("Going to ExpressionDag round trips") = forAll(genForm) { form => - val (dag, id) = ExpressionDag(form, toLiteral) - dag.evaluate(id) == form - } - - property("CombineInc does not change results") = forAll(genForm) { form => - val simplified = ExpressionDag.applyRule(form, toLiteral, CombineInc) - form.evaluate == simplified.evaluate - } - - property("RemoveInc removes all Inc") = forAll(genForm) { form => - val noIncForm = ExpressionDag.applyRule(form, toLiteral, RemoveInc) - def noInc(f: Formula[Int]): Boolean = f match { - case Constant(_) => true - case Inc(_, _) => false - case Sum(l, r) => noInc(l) && noInc(r) - case Product(l, r) => noInc(l) && noInc(r) - } - noInc(noIncForm) && (noIncForm.evaluate == form.evaluate) - } - - /** - * This law is important for the rules to work as expected, and not have equivalent - * nodes appearing more than once in the Dag - */ - property("Node structural equality implies Id equality") = forAll(genForm) { form => - val (dag, id) = ExpressionDag(form, toLiteral) - type BoolT[T] = Boolean // constant type function - dag.idToExp.collect(new GenPartial[HMap[Id, ExpressionDag[Formula]#E]#Pair, BoolT] { - def apply[T] = { - case (id, expr) => - val node = expr.evaluate(dag.idToExp) - dag.idOf(node) == id - } - }).forall(identity) - } - - // The normal Inc gen recursively calls the general dag Generator - def genChainInc: Gen[Formula[Int]] = for { - by <- Gen.chooseNum(Int.MinValue, Int.MaxValue) - chain <- genChain - } yield Inc(chain, by) - - def genChain: Gen[Formula[Int]] = Gen.frequency((1, genConst), (3, genChainInc)) - property("CombineInc compresses linear Inc chains") = forAll(genChain) { chain => - ExpressionDag.applyRule(chain, toLiteral, CombineInc) match { - case Constant(n) => true - case Inc(Constant(n), b) => true - case _ => false // All others should have been compressed - } - } - - /** - * We should be able to totally evaluate these formulas - */ - object EvaluationRule extends Rule[Formula] { - def apply[T](on: 
ExpressionDag[Formula]) = { - case Sum(Constant(a), Constant(b)) => Some(Constant(a + b)) - case Product(Constant(a), Constant(b)) => Some(Constant(a * b)) - case Inc(Constant(a), b) => Some(Constant(a + b)) - case _ => None - } - } - property("EvaluationRule totally evaluates") = forAll(genForm) { form => - ExpressionDag.applyRule(form, toLiteral, EvaluationRule) match { - case Constant(x) if x == form.evaluate => true - case _ => false - } - } -} diff --git a/summingbird-core/src/test/scala/com/twitter/summingbird/graph/HMapTests.scala b/summingbird-core/src/test/scala/com/twitter/summingbird/graph/HMapTests.scala deleted file mode 100644 index 487b75595..000000000 --- a/summingbird-core/src/test/scala/com/twitter/summingbird/graph/HMapTests.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.summingbird.graph - -import org.scalacheck.Prop._ -import org.scalacheck.{ Arbitrary, Gen, Properties } - -/** - * This tests the HMap. We use the type system to - * prove the types are correct and don't (yet?) engage - * in the problem of higher kinded Arbitraries. - */ -object HMapTests extends Properties("HMap") { - case class Key[T](key: Int) - case class Value[T](value: Int) - - implicit def keyGen: Gen[Key[Int]] = Gen.choose(Int.MinValue, Int.MaxValue).map(Key(_)) - implicit def valGen: Gen[Value[Int]] = Gen.choose(Int.MinValue, Int.MaxValue).map(Value(_)) - - def zip[T, U](g: Gen[T], h: Gen[U]): Gen[(T, U)] = for { - a <- g - b <- h - } yield (a, b) - - implicit def hmapGen: Gen[HMap[Key, Value]] = - Gen.listOf(zip(keyGen, valGen)).map { list => - list.foldLeft(HMap.empty[Key, Value]) { (hm, kv) => - hm + kv - } - } - - implicit def arb[T](implicit g: Gen[T]): Arbitrary[T] = Arbitrary(g) - - property("adding a pair works") = forAll { (hmap: HMap[Key, Value], k: Key[Int], v: Value[Int]) => - val initContains = hmap.contains(k) - val added = hmap + (k -> v) - // Adding puts the item in, and does not change the initial - (added.get(k) == Some(v)) && - (initContains == hmap.contains(k)) && - (initContains == hmap.get(k).isDefined) - } - property("removing a key works") = forAll { (hmap: HMap[Key, Value], k: Key[Int]) => - val initContains = hmap.get(k).isDefined - val next = hmap - k - // Adding puts the item in, and does not change the initial - (!next.contains(k)) && - (initContains == hmap.contains(k)) && - (next.get(k) == None) - } - - property("keysOf works") = forAll { (hmap: HMap[Key, Value], k: Key[Int], v: Value[Int]) => - val initKeys = hmap.keysOf(v) - val added = hmap + (k -> v) - val finalKeys = added.keysOf(v) - val sizeIsConsistent = (finalKeys -- initKeys).size match { - case 0 => hmap.contains(k) // initially present - case 1 => !hmap.contains(k) // initially absent - case _ => false // we can't change the count by more than 1. 
- } - - sizeIsConsistent && added.contains(k) - } - - property("updateFirst works") = forAll { (hmap: HMap[Key, Value]) => - val partial = new GenPartial[Key, Value] { - def apply[T] = { case Key(id) if (id % 2 == 0) => Value(0) } - } - hmap.updateFirst(partial) match { - case Some((updated, k)) => updated.get(k) == Some(Value(0)) - case None => true - } - } - - property("collect works") = forAll { (map: Map[Key[Int], Value[Int]]) => - val hm = map.foldLeft(HMap.empty[Key, Value])(_ + _) - val partial = new GenPartial[HMap[Key, Value]#Pair, Value] { - def apply[T] = { case (Key(k), Value(v)) if k > v => Value(k * v) } - } - val collected = hm.collect(partial).map { case Value(v) => v }.toSet - val mapCollected = map.collect(partial.apply[Int]).map { case Value(v) => v }.toSet - collected == mapCollected - } - - property("collectValues works") = forAll { (map: Map[Key[Int], Value[Int]]) => - val hm = map.foldLeft(HMap.empty[Key, Value])(_ + _) - val partial = new GenPartial[Value, Value] { - def apply[T] = { case Value(v) if v < 0 => Value(v * v) } - } - val collected = hm.collectValues(partial).map { case Value(v) => v }.toSet - val mapCollected = map.values.collect(partial.apply[Int]).map { case Value(v) => v }.toSet - collected == mapCollected - } -} diff --git a/summingbird-core/src/test/scala/com/twitter/summingbird/graph/LiteralTests.scala b/summingbird-core/src/test/scala/com/twitter/summingbird/graph/LiteralTests.scala deleted file mode 100644 index bab2446fa..000000000 --- a/summingbird-core/src/test/scala/com/twitter/summingbird/graph/LiteralTests.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ - -package com.twitter.summingbird.graph - -import org.scalacheck.Prop._ -import org.scalacheck.{ Arbitrary, Gen, Properties } - -object LiteralTests extends Properties("Literal") { - case class Box[T](get: T) - - def transitiveClosure[N[_]](l: Literal[_, N], acc: Set[Literal[_, N]] = Set.empty[Literal[_, N]]): Set[Literal[_, N]] = l match { - case c @ ConstLit(_) => acc + c - case u @ UnaryLit(prev, _) => if (acc(u)) acc else transitiveClosure(prev, acc + u) - case b @ BinaryLit(p1, p2, _) => if (acc(b)) acc else transitiveClosure(p2, transitiveClosure(p1, acc + b)) - } - - def genBox: Gen[Box[Int]] = Gen.chooseNum(Int.MinValue, Int.MaxValue).map(Box(_)) - - def genConst: Gen[Literal[Int, Box]] = genBox.map(ConstLit(_)) - def genUnary: Gen[Literal[Int, Box]] = for { - fn <- Arbitrary.arbitrary[(Int) => (Int)] - bfn = { case Box(b) => Box(fn(b)) }: Box[Int] => Box[Int] - input <- genLiteral - } yield UnaryLit(input, bfn) - - def genBinary: Gen[Literal[Int, Box]] = for { - fn <- Arbitrary.arbitrary[(Int, Int) => (Int)] - bfn = { case (Box(l), Box(r)) => Box(fn(l, r)) }: (Box[Int], Box[Int]) => Box[Int] - left <- genLiteral - // We have to make dags, so select from the closure of left sometimes - right <- Gen.oneOf(genLiteral, genChooseFrom(transitiveClosure[Box](left))) - } yield BinaryLit(left, right, bfn) - - def genChooseFrom[N[_]](s: Set[Literal[_, N]]): Gen[Literal[Int, N]] = - Gen.oneOf(s.toSeq.asInstanceOf[Seq[Literal[Int, N]]]) - - /* - * Create dags. Don't use binary too much as it can create exponentially growing dags - */ - def genLiteral: Gen[Literal[Int, Box]] = Gen.frequency((3, genConst), - (6, genUnary), (1, genBinary)) - - //This evaluates by recursively walking the tree without memoization - //as lit.evaluate should do - def slowEvaluate[T](lit: Literal[T, Box]): Box[T] = lit match { - case ConstLit(n) => n - case UnaryLit(in, fn) => fn(slowEvaluate(in)) - case BinaryLit(a, b, fn) => fn(slowEvaluate(a), slowEvaluate(b)) - } - - property("Literal.evaluate must match simple explanation") = forAll(genLiteral) { (l: Literal[Int, Box]) => - l.evaluate == slowEvaluate(l) - } -}