From 782a87da5cf832e8d03e41ffa7a53e9a5efaab33 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 15:45:04 +0100 Subject: [PATCH] chore: show a bug in Expand conversion --- .../src/main/scala/io/substrait/spark/SparkExtension.scala | 2 ++ .../scala/io/substrait/spark/SubstraitPlanTestBase.scala | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala index 53b5bfaaf..c470c7a42 100644 --- a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -34,6 +34,8 @@ object SparkExtension { private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection = SimpleExtension.loadDefaults() + val COLLECTION: SimpleExtension.ExtensionCollection = EXTENSION_COLLECTION.merge(SparkImpls) + lazy val SparkScalarFunctions: Seq[SimpleExtension.ScalarFunctionVariant] = { val ret = new collection.mutable.ArrayBuffer[SimpleExtension.ScalarFunctionVariant]() ret.appendAll(EXTENSION_COLLECTION.scalarFunctions().asScala) diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala index 4fa9ec263..cbd7a151c 100644 --- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -26,7 +26,7 @@ import io.substrait.debug.TreePrinter import io.substrait.extension.ExtensionCollector import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter} import io.substrait.proto -import io.substrait.relation.RelProtoConverter +import io.substrait.relation.{ProtoRelConverter, RelProtoConverter} import org.scalactic.Equality import org.scalactic.source.Position import org.scalatest.Succeeded @@ -93,6 +93,10 @@ trait SubstraitPlanTestBase { self: SharedSparkSession => require(logicalPlan2.resolved); val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2) + val extensionCollector = new ExtensionCollector; + val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel) + new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) + pojoRel2.shouldEqualPlainly(pojoRel) logicalPlan2 }