This repository has been archived by the owner on Jan 20, 2022. It is now read-only.

Add a DagOptimizer test #745

Open
wants to merge 8 commits into develop
@@ -0,0 +1,153 @@
package com.twitter.summingbird.planner

import com.twitter.algebird.Semigroup
import com.twitter.summingbird._
import com.twitter.summingbird.graph.{DependantGraph, Rule}
import com.twitter.summingbird.memory._

import org.scalatest.FunSuite
import org.scalacheck.{Arbitrary, Gen}
import Gen.oneOf

import scala.collection.mutable
import org.scalatest.prop.GeneratorDrivenPropertyChecks._

class DagOptimizerTest extends FunSuite {

implicit val generatorDrivenConfig =
PropertyCheckConfig(minSuccessful = 1000, maxDiscarded = 1000) // the producer generator uses filter, I think
//PropertyCheckConfig(minSuccessful = 100, maxDiscarded = 1000) // the producer generator uses filter, I think
Collaborator:
remove?


import TestGraphGenerators._
import MemoryArbitraries._
implicit def testStore: Memory#Store[Int, Int] = mutable.Map[Int, Int]()
implicit def testService: Memory#Service[Int, Int] = new mutable.HashMap[Int, Int]() with MemoryService[Int, Int]
implicit def sink1: Memory#Sink[Int] = ((_) => Unit)
implicit def sink2: Memory#Sink[(Int, Int)] = ((_) => Unit)

def genProducer: Gen[Producer[Memory, _]] = oneOf(genProd1, genProd2, summed)

test("DagOptimizer round trips") {
forAll { p: Producer[Memory, Int] =>
val dagOpt = new DagOptimizer[Memory] { }

assert(dagOpt.toLiteral(p).evaluate == p)
}
}

val dagOpt = new DagOptimizer[Memory] { }

test("ExpressionDag fanOut matches DependantGraph") {
forAll(genProducer) { p: Producer[Memory, _] =>
val expDag = dagOpt.expressionDag(p)._1

// the expression considers an also a fanout, so
// we can't use the standard Dependants, we need to
// use parentsOf as the edge function
val deps = new DependantGraph[Producer[Memory, Any]] {
override lazy val nodes: List[Producer[Memory, Any]] = Producer.entireGraphOf(p)
override def dependenciesOf(p: Producer[Memory, Any]) = Producer.parentsOf(p)
}

deps.nodes.foreach { n =>
deps.fanOut(n) match {
case Some(fo) => assert(expDag.fanOut(n) == fo)
case None => fail(s"node $n has no fanOut value")
}
}
}
}


val allRules = {
import dagOpt._

List(RemoveNames,
RemoveIdentityKeyed,
FlatMapFusion,
OptionMapFusion,
OptionToFlatMap,
KeyFlatMapToFlatMap,
FlatMapKeyFusion,
ValueFlatMapToFlatMap,
FlatMapValuesFusion,
FlatThenOptionFusion,
DiamondToFlatMap,
MergePullUp,
AlsoPullUp)
}

val genRule: Gen[Rule[dagOpt.Prod]] =
for {
n <- Gen.choose(1, allRules.size)
rs <- Gen.pick(n, allRules) // get n randomly selected
} yield rs.reduce(_.orElse(_))

test("Rules are idempotent") {
forAll(genProducer, genRule) { (p, r) =>
val once = dagOpt.optimize(p, r)
val twice = dagOpt.optimize(once, r)
assert(once == twice)
}
}

test("fanOut matches after optimization") {

forAll(genProducer, genRule) { (p, r) =>

val once = dagOpt.optimize(p, r)

val expDag = dagOpt.expressionDag(once)._1
// the expression considers an also a fanout, so
// we can't use the standard Dependants, we need to
// use parentsOf as the edge function
val deps = new DependantGraph[Producer[Memory, Any]] {
override lazy val nodes: List[Producer[Memory, Any]] = Producer.entireGraphOf(once)
override def dependenciesOf(p: Producer[Memory, Any]) = Producer.parentsOf(p)
}

deps.nodes.foreach { n =>
deps.fanOut(n) match {
case Some(fo) => assert(expDag.fanOut(n) == fo, s"node: $n, in optimized: $once")
case None => fail(s"node $n has no fanOut value")
}
}
}

}

test("test some idempotency specific past failures") {
Collaborator:
Maybe it would be nice to add a comment above with the details of the past failure? It might be hard for folks reading the code to know otherwise.

Collaborator (author):
I don't know what to say. This was a hand-minimized example (it took me about an hour) from a failure case found by scalacheck. Since the failures were quite rare, it took a long time to even find a failure with scalacheck, so once I found it, I wanted to test that failure every time.

That's what I mean by "specific past failures".

I understand (somewhat) why this failed now, but I couldn't easily generate another that would also show the bug.

Can you suggest some specific text you would like to see me add?

Collaborator:
Yeah, I guess reading the title doesn't give you a sense of what the issue is and what it's testing. Would it make sense to add a TL;DR of your understanding of why it failed?

val list = List(-483916215)
val list2 = list

val map1 = new MemoryService[Int, Int] {
val map = Map(1122506458 -> -422595330)
def get(i: Int) = map.get(i)
}

val fn1 = { (i: Int) => List((i, i)) }
val fn2 = { (tup: (Int, Int)) => Option(tup) }
val fn3 = { i: (Int, (Int, Option[Int])) => List((i._1, i._2._1)) }
val fn4 = fn1
val fn5 = fn2
val mmap: Memory#Store[Int, Int] = collection.mutable.Map.empty[Int, Int]

val arg0: Producer[Memory, (Int, (Option[Int], Int))] =
Summer[Memory, Int, Int](IdentityKeyedProducer(NamedProducer(IdentityKeyedProducer(MergedProducer(IdentityKeyedProducer(FlatMappedProducer(LeftJoinedProducer[Memory, Int, Int, Int](IdentityKeyedProducer(NamedProducer(IdentityKeyedProducer(NamedProducer(IdentityKeyedProducer(OptionMappedProducer(IdentityKeyedProducer(FlatMappedProducer(Source[Memory, Int](list), fn1)), fn2)),"tjiposzOlkplcu")),"tvpwpdyScehGnwcaVjjWvlfuwxatxhdjhozscucpbq")), map1), fn3)),IdentityKeyedProducer(OptionMappedProducer(IdentityKeyedProducer(FlatMappedProducer(Source[Memory, Int](list2), fn4)),fn5)))),"ncn")),mmap, implicitly[Semigroup[Int]])

val dagOpt = new DagOptimizer[Memory] { }

val rule = {
import dagOpt._
List[Rule[Prod]](
RemoveNames,
RemoveIdentityKeyed,
FlatMapFusion,
OptionToFlatMap
).reduce(_ orElse _)
}
val once = dagOpt.optimize(arg0, rule)
val twice = dagOpt.optimize(once, rule)
assert(twice == once)
}
}
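
A minimal standalone sketch of the two properties the tests above rely on (round-tripping through toLiteral, and rule idempotency), assuming the DagOptimizer API shown in this diff; the sample producer, the rule choice, the object name, and the planner import path are illustrative only:

object DagOptimizerSketch {
  import com.twitter.summingbird._
  import com.twitter.summingbird.memory._
  // assumed location, matching the package of the test file above
  import com.twitter.summingbird.planner.DagOptimizer

  def main(args: Array[String]): Unit = {
    // a small in-memory producer graph (illustrative values)
    val p: Producer[Memory, Int] =
      Source[Memory, Int](List(1, 2, 3)).flatMap { i => List(i, i * 2) }

    val opt = new DagOptimizer[Memory] { }

    // round trip: lowering to a Literal and evaluating returns the same graph
    assert(opt.toLiteral(p).evaluate == p)

    // idempotency: applying the same rule twice must not change the graph further
    val rule = { import opt._; RemoveNames.orElse(OptionMapFusion) }
    val once = opt.optimize(p, rule)
    assert(opt.optimize(once, rule) == once)
  }
}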
@@ -126,6 +126,7 @@ sealed trait ExpressionDag[N[_]] { self =>
}
}
// Note this Stream must always be non-empty as long as roots are
// TODO: we don't need to use collect here, just .get on each id in s
idToExp.collect[IdSet](partial)
.reduce(_ ++ _)
}
@@ -162,12 +163,6 @@ sealed trait ExpressionDag[N[_]] { self =>
curr
}

protected def toExpr[T](n: N[T]): (ExpressionDag[N], Expr[T, N]) = {
val (dag, id) = ensure(n)
val exp = dag.idToExp(id)
(dag, exp)
}

/**
* Convert a N[T] to a Literal[T, N]
*/
@@ -182,32 +177,57 @@ sealed trait ExpressionDag[N[_]] { self =>
def apply[U] = {
val fn = rule.apply[U](self)

def ruleApplies(id: Id[U]): Boolean = {
val n = evaluate(id)
fn(n) match {
case Some(n1) => n != n1
case None => false
}
}


{
case (id, exp) if fn(exp.evaluate(idToExp)).isDefined =>
case (id, _) if ruleApplies(id) =>
// Sucks to have to call fn twice, but oh well
(id, fn(exp.evaluate(idToExp)).get)
(id, fn(evaluate(id)).get)
}
}
}
idToExp.collect[HMap[Id, N]#Pair](getN).headOption match {
case None => this
case Some(tup) =>
// some type hand holding
def act[T](in: HMap[Id, N]#Pair[T]) = {
def act[T](in: HMap[Id, N]#Pair[T]): ExpressionDag[N] = {
/*
* We can't delete Ids which may have been shared
* publicly, and the ids may be embedded in many
* nodes. Instead we remap this i to be a pointer
* to the newid.
*/
val (i, n) = in
val oldNode = evaluate(i)
val (dag, exp) = toExpr(n)
dag.copy(id2Exp = dag.idToExp + (i -> exp))
val (dag, newId) = ensure(n)
dag.copy(id2Exp = dag.idToExp + (i -> Var[T, N](newId)))
}
// This cast should not be needed
act(tup.asInstanceOf[HMap[Id, N]#Pair[Any]]).gc
}
}

// This is only called by ensure
/**
* This is only called by ensure
*
* Note, Expr must never be a Var
*/
private def addExp[T](node: N[T], exp: Expr[T, N]): (ExpressionDag[N], Id[T]) = {
val nodeId = Id[T](nextId)
(copy(id2Exp = idToExp + (nodeId -> exp), id = nextId + 1), nodeId)
require(!exp.isInstanceOf[Var[T, N]])

find(node) match {
case None =>
val nodeId = Id[T](nextId)
(copy(id2Exp = idToExp + (nodeId -> exp), id = nextId + 1), nodeId)
case Some(id) =>
(this, id)
}
}

/**
@@ -216,9 +236,18 @@ sealed trait ExpressionDag[N[_]] { self =>
*/
def find[T](node: N[T]): Option[Id[T]] = nodeToId.getOrElseUpdate(node, {
val partial = new GenPartial[HMap[Id, E]#Pair, Id] {
def apply[T1] = { case (thisId, expr) if node == expr.evaluate(idToExp) => thisId }
def apply[T1] = {
// Make sure to return the original Id, not a Id -> Var -> Expr
case (thisId, expr) if !expr.isInstanceOf[Var[_, N]] && node == expr.evaluate(idToExp) => thisId
}
}
idToExp.collect(partial).toList match {
case Nil => None
case id :: Nil =>
// this cast is safe if node == expr.evaluate(idToExp) implies types match
Some(id).asInstanceOf[Option[Id[T]]]
case others => None // sys.error(s"logic error, should only be one mapping: $node -> $others")
}
idToExp.collect(partial).headOption.asInstanceOf[Option[Id[T]]]
})

/**
Expand Down Expand Up @@ -247,7 +276,7 @@ sealed trait ExpressionDag[N[_]] { self =>
* Since the code is not performance critical, but correctness critical, and we can't
* check this property with the typesystem easily, check it here
*/
assert(n == node,
require(n == node,
"Equality or nodeToLiteral is incorrect: nodeToLit(%s) = ConstLit(%s)".format(node, n))
addExp(node, Const(n))
case UnaryLit(prev, fn) =>
@@ -272,10 +301,7 @@ sealed trait ExpressionDag[N[_]] { self =>

def evaluateOption[T](id: Id[T]): Option[N[T]] =
idToN.getOrElseUpdate(id, {
val partial = new GenPartial[HMap[Id, E]#Pair, N] {
def apply[T1] = { case (thisId, expr) if (id == thisId) => expr.evaluate(idToExp) }
}
idToExp.collect(partial).headOption.asInstanceOf[Option[N[T]]]
idToExp.get(id).map(_.evaluate(idToExp))
})

/**
@@ -284,25 +310,31 @@ sealed trait ExpressionDag[N[_]] { self =>
* We need to garbage collect nodes that are
* no longer reachable from the root
*/
def fanOut(id: Id[_]): Int = {
// We make a fake IntT[T] which is just Int
val partial = new GenPartial[E, ({ type IntT[T] = Int })#IntT] {
def apply[T] = {
case Var(id1) if (id1 == id) => 1
case Unary(id1, fn) if (id1 == id) => 1
case Binary(id1, id2, fn) if (id1 == id) && (id2 == id) => 2
case Binary(id1, id2, fn) if (id1 == id) || (id2 == id) => 1
case _ => 0
}
}
idToExp.collectValues[({ type IntT[T] = Int })#IntT](partial).sum
def fanOut(id: Id[_]): Int =
evaluateOption(id)
.map(fanOut)
.getOrElse(0)

@annotation.tailrec
private def dependsOn(expr: Expr[_, N], node: N[_]): Boolean = expr match {
case Const(_) => false
case Var(id) => dependsOn(idToExp(id), node)
case Unary(id, _) => evaluate(id) == node
case Binary(id0, id1, _) => evaluate(id0) == node || evaluate(id1) == node
}

/**
* Returns 0 if the node is absent, which is true
* use .contains(n) to check for containment
*/
def fanOut(node: N[_]): Int = find(node).map(fanOut(_)).getOrElse(0)
def fanOut(node: N[_]): Int = {
val pointsToNode = new GenPartial[HMap[Id, E]#Pair, N] {
def apply[T] = {
case (id, expr) if dependsOn(expr, node) => evaluate(id)
}
}
idToExp.collect[N](pointsToNode).toSet.size
}
def contains(node: N[_]): Boolean = find(node).isDefined
}

@@ -355,6 +387,9 @@ trait Rule[N[_]] { self =>
def apply[T](on: ExpressionDag[N]) = { n =>
self.apply(on)(n).orElse(that.apply(on)(n))
}

override def toString: String =
s"$self.orElse($that)"
}
}

@@ -56,7 +56,8 @@ package object graph {
}
}
// make sure the values are sets; note .mapValues is lazy in scala
.map { case (k, v) => (k, v.distinct) };
.map { case (k, v) => (k, v.distinct) }

graph.getOrElse(_, Nil)
}

@@ -177,6 +177,7 @@ object ExpressionDagTests extends Properties("ExpressionDag") {
} yield Inc(chain, by)

def genChain: Gen[Formula[Int]] = Gen.frequency((1, genConst), (3, genChainInc))

property("CombineInc compresses linear Inc chains") = forAll(genChain) { chain =>
ExpressionDag.applyRule(chain, toLiteral, CombineInc) match {
case Constant(n) => true