diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index cb4771dd92f80..ac136dfb898ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -28,6 +28,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest { val nullInt = Literal(null, IntegerType) val nullString = Literal(null, StringType) val testRelation = LocalRelation($"a".string, $"b".string, $"c".string, $"d".string, $"e".int) + val testRelation2 = LocalRelation($"a".double, $"b".int, $"c".int, $"d".int, $"e".int) private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { case Aggregate(_, _, Aggregate(_, _, _: Expand)) => @@ -77,7 +78,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest { } test("SPARK-40382: eliminate multiple distinct groups due to superficial differences") { - val input = testRelation + val input = testRelation2 .groupBy($"a")( countDistinct($"b" + $"c").as("agg1"), countDistinct($"c" + $"b").as("agg2"), @@ -92,7 +93,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest { } test("SPARK-40382: reduce multiple distinct groups due to superficial differences") { - val input = testRelation + val input = testRelation2 .groupBy($"a")( countDistinct($"b" + $"c" + $"d").as("agg1"), countDistinct($"d" + $"c" + $"b").as("agg2"),