diff --git a/summingbird-core-test/src/test/scala/com/twitter/summingbird/planner/DagOptimizerTest.scala b/summingbird-core-test/src/test/scala/com/twitter/summingbird/planner/DagOptimizerTest.scala new file mode 100644 index 000000000..36073bb06 --- /dev/null +++ b/summingbird-core-test/src/test/scala/com/twitter/summingbird/planner/DagOptimizerTest.scala @@ -0,0 +1,153 @@ +package com.twitter.summingbird.planner + +import com.twitter.algebird.Semigroup +import com.twitter.summingbird._ +import com.twitter.summingbird.graph.{DependantGraph, Rule} +import com.twitter.summingbird.memory._ + +import org.scalatest.FunSuite +import org.scalacheck.{Arbitrary, Gen} +import Gen.oneOf + +import scala.collection.mutable +import org.scalatest.prop.GeneratorDrivenPropertyChecks._ + +class DagOptimizerTest extends FunSuite { + + implicit val generatorDrivenConfig = + PropertyCheckConfig(minSuccessful = 1000, maxDiscarded = 1000) // the producer generator uses filter, I think + //PropertyCheckConfig(minSuccessful = 100, maxDiscarded = 1000) // the producer generator uses filter, I think + + import TestGraphGenerators._ + import MemoryArbitraries._ + implicit def testStore: Memory#Store[Int, Int] = mutable.Map[Int, Int]() + implicit def testService: Memory#Service[Int, Int] = new mutable.HashMap[Int, Int]() with MemoryService[Int, Int] + implicit def sink1: Memory#Sink[Int] = ((_) => Unit) + implicit def sink2: Memory#Sink[(Int, Int)] = ((_) => Unit) + + def genProducer: Gen[Producer[Memory, _]] = oneOf(genProd1, genProd2, summed) + + test("DagOptimizer round trips") { + forAll { p: Producer[Memory, Int] => + val dagOpt = new DagOptimizer[Memory] { } + + assert(dagOpt.toLiteral(p).evaluate == p) + } + } + + val dagOpt = new DagOptimizer[Memory] { } + + test("ExpressionDag fanOut matches DependantGraph") { + forAll(genProducer) { p: Producer[Memory, _] => + val expDag = dagOpt.expressionDag(p)._1 + + // the expression considers an also a fanout, so + // we can't use the standard Dependants, we need to + // us parentsOf as the edge function + val deps = new DependantGraph[Producer[Memory, Any]] { + override lazy val nodes: List[Producer[Memory, Any]] = Producer.entireGraphOf(p) + override def dependenciesOf(p: Producer[Memory, Any]) = Producer.parentsOf(p) + } + + deps.nodes.foreach { n => + deps.fanOut(n) match { + case Some(fo) => assert(expDag.fanOut(n) == fo) + case None => fail(s"node $n has no fanOut value") + } + } + } + } + + + val allRules = { + import dagOpt._ + + List(RemoveNames, + RemoveIdentityKeyed, + FlatMapFusion, + OptionMapFusion, + OptionToFlatMap, + KeyFlatMapToFlatMap, + FlatMapKeyFusion, + ValueFlatMapToFlatMap, + FlatMapValuesFusion, + FlatThenOptionFusion, + DiamondToFlatMap, + MergePullUp, + AlsoPullUp) + } + + val genRule: Gen[Rule[dagOpt.Prod]] = + for { + n <- Gen.choose(1, allRules.size) + rs <- Gen.pick(n, allRules) // get n randomly selected + } yield rs.reduce(_.orElse(_)) + + test("Rules are idempotent") { + forAll(genProducer, genRule) { (p, r) => + val once = dagOpt.optimize(p, r) + val twice = dagOpt.optimize(once, r) + assert(once == twice) + } + } + + test("fanOut matches after optimization") { + + forAll(genProducer, genRule) { (p, r) => + + val once = dagOpt.optimize(p, r) + + val expDag = dagOpt.expressionDag(once)._1 + // the expression considers an also a fanout, so + // we can't use the standard Dependants, we need to + // us parentsOf as the edge function + val deps = new DependantGraph[Producer[Memory, Any]] { + override lazy val nodes: List[Producer[Memory, Any]] = Producer.entireGraphOf(once) + override def dependenciesOf(p: Producer[Memory, Any]) = Producer.parentsOf(p) + } + + deps.nodes.foreach { n => + deps.fanOut(n) match { + case Some(fo) => assert(expDag.fanOut(n) == fo, s"node: $n, in optimized: $once") + case None => fail(s"node $n has no fanOut value") + } + } + } + + } + + test("test some idempotency specific past failures") { + val list = List(-483916215) + val list2 = list + + val map1 = new MemoryService[Int, Int] { + val map = Map(1122506458 -> -422595330) + def get(i: Int) = map.get(i) + } + + val fn1 = { (i: Int) => List((i, i)) } + val fn2 = { (tup: (Int, Int)) => Option(tup) } + val fn3 = { i: (Int, (Int, Option[Int])) => List((i._1, i._2._1)) } + val fn4 = fn1 + val fn5 = fn2 + val mmap: Memory#Store[Int, Int] = collection.mutable.Map.empty[Int, Int] + + val arg0: Producer[Memory, (Int, (Option[Int], Int))] = + Summer[Memory, Int, Int](IdentityKeyedProducer(NamedProducer(IdentityKeyedProducer(MergedProducer(IdentityKeyedProducer(FlatMappedProducer(LeftJoinedProducer[Memory, Int, Int, Int](IdentityKeyedProducer(NamedProducer(IdentityKeyedProducer(NamedProducer(IdentityKeyedProducer(OptionMappedProducer(IdentityKeyedProducer(FlatMappedProducer(Source[Memory, Int](list), fn1)), fn2)),"tjiposzOlkplcu")),"tvpwpdyScehGnwcaVjjWvlfuwxatxhdjhozscucpbq")), map1), fn3)),IdentityKeyedProducer(OptionMappedProducer(IdentityKeyedProducer(FlatMappedProducer(Source[Memory, Int](list2), fn4)),fn5)))),"ncn")),mmap, implicitly[Semigroup[Int]]) + + val dagOpt = new DagOptimizer[Memory] { } + + val rule = { + import dagOpt._ + List[Rule[Prod]]( + RemoveNames, + RemoveIdentityKeyed, + FlatMapFusion, + OptionToFlatMap + ).reduce(_ orElse _) + } + val once = dagOpt.optimize(arg0, rule) + val twice = dagOpt.optimize(once, rule) + assert(twice == once) + } +} diff --git a/summingbird-core/src/main/scala/com/twitter/summingbird/graph/ExpressionDag.scala b/summingbird-core/src/main/scala/com/twitter/summingbird/graph/ExpressionDag.scala index d97216578..a76763aac 100644 --- a/summingbird-core/src/main/scala/com/twitter/summingbird/graph/ExpressionDag.scala +++ b/summingbird-core/src/main/scala/com/twitter/summingbird/graph/ExpressionDag.scala @@ -126,6 +126,7 @@ sealed trait ExpressionDag[N[_]] { self => } } // Note this Stream must always be non-empty as long as roots are + // TODO: we don't need to use collect here, just .get on each id in s idToExp.collect[IdSet](partial) .reduce(_ ++ _) } @@ -162,12 +163,6 @@ sealed trait ExpressionDag[N[_]] { self => curr } - protected def toExpr[T](n: N[T]): (ExpressionDag[N], Expr[T, N]) = { - val (dag, id) = ensure(n) - val exp = dag.idToExp(id) - (dag, exp) - } - /** * Convert a N[T] to a Literal[T, N] */ @@ -182,10 +177,19 @@ sealed trait ExpressionDag[N[_]] { self => def apply[U] = { val fn = rule.apply[U](self) + def ruleApplies(id: Id[U]): Boolean = { + val n = evaluate(id) + fn(n) match { + case Some(n1) => n != n1 + case None => false + } + } + + { - case (id, exp) if fn(exp.evaluate(idToExp)).isDefined => + case (id, _) if ruleApplies(id) => // Sucks to have to call fn, twice, but oh well - (id, fn(exp.evaluate(idToExp)).get) + (id, fn(evaluate(id)).get) } } } @@ -193,21 +197,37 @@ sealed trait ExpressionDag[N[_]] { self => case None => this case Some(tup) => // some type hand holding - def act[T](in: HMap[Id, N]#Pair[T]) = { + def act[T](in: HMap[Id, N]#Pair[T]): ExpressionDag[N] = { + /* + * We can't delete Ids which may have been shared + * publicly, and the ids may be embedded in many + * nodes. Instead we remap this i to be a pointer + * to the newid. + */ val (i, n) = in - val oldNode = evaluate(i) - val (dag, exp) = toExpr(n) - dag.copy(id2Exp = dag.idToExp + (i -> exp)) + val (dag, newId) = ensure(n) + dag.copy(id2Exp = dag.idToExp + (i -> Var[T, N](newId))) } // This cast should not be needed act(tup.asInstanceOf[HMap[Id, N]#Pair[Any]]).gc } } - // This is only called by ensure + /** + * This is only called by ensure + * + * Note, Expr must never be a Var + */ private def addExp[T](node: N[T], exp: Expr[T, N]): (ExpressionDag[N], Id[T]) = { - val nodeId = Id[T](nextId) - (copy(id2Exp = idToExp + (nodeId -> exp), id = nextId + 1), nodeId) + require(!exp.isInstanceOf[Var[T, N]]) + + find(node) match { + case None => + val nodeId = Id[T](nextId) + (copy(id2Exp = idToExp + (nodeId -> exp), id = nextId + 1), nodeId) + case Some(id) => + (this, id) + } } /** @@ -216,9 +236,18 @@ sealed trait ExpressionDag[N[_]] { self => */ def find[T](node: N[T]): Option[Id[T]] = nodeToId.getOrElseUpdate(node, { val partial = new GenPartial[HMap[Id, E]#Pair, Id] { - def apply[T1] = { case (thisId, expr) if node == expr.evaluate(idToExp) => thisId } + def apply[T1] = { + // Make sure to return the original Id, not a Id -> Var -> Expr + case (thisId, expr) if !expr.isInstanceOf[Var[_, N]] && node == expr.evaluate(idToExp) => thisId + } + } + idToExp.collect(partial).toList match { + case Nil => None + case id :: Nil => + // this cast is safe if node == expr.evaluate(idToExp) implies types match + Some(id).asInstanceOf[Option[Id[T]]] + case others => None//sys.error(s"logic error, should only be one mapping: $node -> $others") } - idToExp.collect(partial).headOption.asInstanceOf[Option[Id[T]]] }) /** @@ -247,7 +276,7 @@ sealed trait ExpressionDag[N[_]] { self => * Since the code is not performance critical, but correctness critical, and we can't * check this property with the typesystem easily, check it here */ - assert(n == node, + require(n == node, "Equality or nodeToLiteral is incorrect: nodeToLit(%s) = ConstLit(%s)".format(node, n)) addExp(node, Const(n)) case UnaryLit(prev, fn) => @@ -272,10 +301,7 @@ sealed trait ExpressionDag[N[_]] { self => def evaluateOption[T](id: Id[T]): Option[N[T]] = idToN.getOrElseUpdate(id, { - val partial = new GenPartial[HMap[Id, E]#Pair, N] { - def apply[T1] = { case (thisId, expr) if (id == thisId) => expr.evaluate(idToExp) } - } - idToExp.collect(partial).headOption.asInstanceOf[Option[N[T]]] + idToExp.get(id).map(_.evaluate(idToExp)) }) /** @@ -284,25 +310,31 @@ sealed trait ExpressionDag[N[_]] { self => * We need to garbage collect nodes that are * no longer reachable from the root */ - def fanOut(id: Id[_]): Int = { - // We make a fake IntT[T] which is just Int - val partial = new GenPartial[E, ({ type IntT[T] = Int })#IntT] { - def apply[T] = { - case Var(id1) if (id1 == id) => 1 - case Unary(id1, fn) if (id1 == id) => 1 - case Binary(id1, id2, fn) if (id1 == id) && (id2 == id) => 2 - case Binary(id1, id2, fn) if (id1 == id) || (id2 == id) => 1 - case _ => 0 - } - } - idToExp.collectValues[({ type IntT[T] = Int })#IntT](partial).sum + def fanOut(id: Id[_]): Int = + evaluateOption(id) + .map(fanOut) + .getOrElse(0) + + @annotation.tailrec + private def dependsOn(expr: Expr[_, N], node: N[_]): Boolean = expr match { + case Const(_) => false + case Var(id) => dependsOn(idToExp(id), node) + case Unary(id, _) => evaluate(id) == node + case Binary(id0, id1, _) => evaluate(id0) == node || evaluate(id1) == node } /** * Returns 0 if the node is absent, which is true * use .contains(n) to check for containment */ - def fanOut(node: N[_]): Int = find(node).map(fanOut(_)).getOrElse(0) + def fanOut(node: N[_]): Int = { + val pointsToNode = new GenPartial[HMap[Id, E]#Pair, N] { + def apply[T] = { + case (id, expr) if dependsOn(expr, node) => evaluate(id) + } + } + idToExp.collect[N](pointsToNode).toSet.size + } def contains(node: N[_]): Boolean = find(node).isDefined } @@ -355,6 +387,9 @@ trait Rule[N[_]] { self => def apply[T](on: ExpressionDag[N]) = { n => self.apply(on)(n).orElse(that.apply(on)(n)) } + + override def toString: String = + s"$self.orElse($that)" } } diff --git a/summingbird-core/src/main/scala/com/twitter/summingbird/graph/package.scala b/summingbird-core/src/main/scala/com/twitter/summingbird/graph/package.scala index 60da67d68..33c62a0ed 100644 --- a/summingbird-core/src/main/scala/com/twitter/summingbird/graph/package.scala +++ b/summingbird-core/src/main/scala/com/twitter/summingbird/graph/package.scala @@ -56,7 +56,8 @@ package object graph { } } // make sure the values are sets, not .mapValues is lazy in scala - .map { case (k, v) => (k, v.distinct) }; + .map { case (k, v) => (k, v.distinct) } + graph.getOrElse(_, Nil) } diff --git a/summingbird-core/src/test/scala/com/twitter/summingbird/graph/ExpressionDagTests.scala b/summingbird-core/src/test/scala/com/twitter/summingbird/graph/ExpressionDagTests.scala index 2a32f0dbf..2e0db8363 100644 --- a/summingbird-core/src/test/scala/com/twitter/summingbird/graph/ExpressionDagTests.scala +++ b/summingbird-core/src/test/scala/com/twitter/summingbird/graph/ExpressionDagTests.scala @@ -177,6 +177,7 @@ object ExpressionDagTests extends Properties("ExpressionDag") { } yield Inc(chain, by) def genChain: Gen[Formula[Int]] = Gen.frequency((1, genConst), (3, genChainInc)) + property("CombineInc compresses linear Inc chains") = forAll(genChain) { chain => ExpressionDag.applyRule(chain, toLiteral, CombineInc) match { case Constant(n) => true