From ab78420c4c2722e85c41cd5e0292583e26f6999c Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Thu, 12 Jun 2014 22:17:49 -0700 Subject: [PATCH] Add a test. --- .../spark/sql/catalyst/trees/TreeNode.scala | 3 ++- .../sql/catalyst/trees/TreeNodeSuite.scala | 26 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9829b4a54ac0c..cd04bdf02cf84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -289,7 +289,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName?") + this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " + + s"Exception message: ${e.getMessage}.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 1ddc41a731ff5..0b40a2dcea7c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NullType} class TreeNodeSuite extends FunSuite { test("top node changed") { @@ -75,4 +76,29 @@ class TreeNodeSuite extends FunSuite { assert(expected === actual) } + + test("transform works on nodes with Option children") { + case class Dummy(optKey: Option[Expression]) extends Expression { + def children = optKey.toSeq + def references = Set.empty[Attribute] + def nullable = true + def dataType = NullType + override lazy val resolved = true + type EvaluatedType = Any + def eval(input: Row) = null.asInstanceOf[Any] + } + val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy2 = Dummy(None) + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + + var actual = dummy1 transformDown toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy1 transformUp toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy2 transform toZero + assert(actual === Dummy(None)) + } + }