From 1d0a420257435f17f240694ff626edf1f55dd74f Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Fri, 25 Oct 2024 07:23:01 +0100 Subject: [PATCH] fix(spark): output columns of expand relation Fix the logic for deriving whether each output column of the expand relation is nullable or not. If one or more of the projections can assign null to the column, then it must be defined as nullable. Signed-off-by: Andrew Coleman # Conflicts: # spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala --- .../scala/io/substrait/spark/logical/ToLogicalPlan.scala | 7 +++++-- spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 9d49303b0..7525ccc5f 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -224,9 +224,12 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] throw new UnsupportedOperationException("ConsistentField not currently supported") } - val output = projections.head + // An output column is nullable if any of the projections can assign null to it + val types = projections.transpose.map(p => (p.head.dataType, p.exists(_.nullable))) + + val output = types .zip(names) - .map { case (t, name) => StructField(name, t.dataType, t.nullable) } + .map { case (t, name) => StructField(name, t._1, t._2) } .map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) Expand(projections, output, child) diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index e06ddfdb0..fd6fd0f00 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -33,7 +33,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { // spotless:off val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7", - "q11", "q13", "q14b", "q15", "q16", "q18", "q19", + "q11", "q13", "q14a", "q14b", "q15", "q16", "q18", "q19", "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q28", "q29", "q30", "q31", "q32", "q33", "q37", "q38", "q40", "q41", "q42", "q43", "q46", "q48",