From 48e8c054f55a77bfb0618f661595818d0aed9143 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 27 Jun 2023 17:48:41 -0500 Subject: [PATCH] Ensure we exhaust row iterator before closing the hostBatch --- .../spark/sql/rapids/execution/GpuBroadcastToRowExec.scala | 2 +- .../spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToRowExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToRowExec.scala index c1748299f74..d794ab76e45 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToRowExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToRowExec.scala @@ -110,7 +110,7 @@ case class GpuBroadcastToRowExec( BoundReference(idx, buildKeys(idx).dataType, buildKeys(idx).nullable)) rowProject(broadcastRow).copy().asInstanceOf[InternalRow] } - }.toArray + }.toArray // force evaluation so we don't close hostBatch too soon } gpuLongMetric("dataSize") += serBatch.dataSize diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala index 0080f0be448..1f18185f8ab 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala @@ -260,13 +260,13 @@ case class GpuSubqueryBroadcastExec( hostBatch.rowIterator().asScala.map { row => val broadcastRow = broadcastModeProject.map(_(row)).getOrElse(row) rowProject(broadcastRow).copy().asInstanceOf[InternalRow] - } + }.toArray // force evaluation so we don't close hostBatch too soon } gpuLongMetric("dataSize") += serBatch.dataSize gpuLongMetric(COLLECT_TIME) += System.nanoTime() - beforeCollect - result.toArray + result } protected override def doExecute(): RDD[InternalRow] = {