diff --git a/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala b/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala
index 0ca9921a6..340f8e53f 100644
--- a/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala
+++ b/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala
@@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
 import com.intel.oap.spark.sql.execution.datasources.v2.arrow._
 import org.apache.arrow.dataset.jni.NativeMemoryPool
 import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.memory.RootAllocator
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
@@ -40,15 +41,12 @@ object SparkMemoryUtils extends Logging {
     }
 
     val sharedMetrics = new NativeSQLMemoryMetrics()
 
-    val defaultAllocator: BufferAllocator = {
-      val globalAlloc = globalAllocator()
+    val defaultAllocator: BufferAllocator = {
       val al = new SparkManagedAllocationListener(
         new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
         sharedMetrics)
-      val parent = globalAlloc
-      parent.newChildAllocator("Spark Managed Allocator - " +
-        UUID.randomUUID().toString, al, 0, parent.getLimit)
+      new RootAllocator(al, Long.MaxValue)
     }
 
     val defaultMemoryPool: NativeMemoryPoolWrapper = {
@@ -77,7 +75,7 @@ object SparkMemoryUtils extends Logging {
       val al = new SparkManagedAllocationListener(
         new NativeSQLMemoryConsumer(getTaskMemoryManager(), spiller),
         sharedMetrics)
-      val parent = globalAllocator()
+      val parent = defaultAllocator
       val alloc = parent.newChildAllocator("Spark Managed Allocator - " +
         UUID.randomUUID().toString, al, 0, parent.getLimit).asInstanceOf[BufferAllocator]
       allocators.add(alloc)
@@ -121,7 +119,8 @@ object SparkMemoryUtils extends Logging {
     }
 
     def release(): Unit = {
-      for (allocator <- allocators.asScala) {
+      for (allocator <- allocators.asScala.reverse) {
+        // iterate in reverse so child allocators are closed before their parents
        val allocated = allocator.getAllocatedMemory
        if (allocated == 0L) {
          close(allocator)
@@ -212,9 +211,8 @@ object SparkMemoryUtils extends Logging {
   }
 
   def contextAllocator(): BufferAllocator = {
-    val globalAlloc = globalAllocator()
     if (!inSparkTask()) {
-      return globalAlloc
+      return globalAllocator()
     }
     getTaskMemoryResources().defaultAllocator
   }
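
For reviewers unfamiliar with the Arrow allocator hierarchy, here is a minimal standalone sketch (not part of the patch; the object name and the "task child" label are illustrative) of the two behaviors this change relies on: the `RootAllocator(AllocationListener, long)` constructor used to build the per-task `defaultAllocator`, and the requirement that child allocators created via `newChildAllocator` be closed before their parent, which is why `release()` now iterates the `allocators` list in reverse.

```scala
import java.util.{ArrayList => JArrayList}

import scala.collection.JavaConverters._

import org.apache.arrow.memory.{AllocationListener, BufferAllocator, RootAllocator}

object AllocatorCloseOrderSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for SparkManagedAllocationListener: RootAllocator accepts an
    // AllocationListener plus a limit, the same constructor shape the patch
    // uses for the per-task defaultAllocator.
    val listener = new AllocationListener {
      override def onAllocation(size: Long): Unit =
        println(s"allocated $size bytes")
    }
    val root: BufferAllocator = new RootAllocator(listener, Long.MaxValue)

    // Children are appended to the list after their parent, mirroring
    // allocators.add(alloc) in the patched per-task allocator path.
    val allocators = new JArrayList[BufferAllocator]()
    allocators.add(root)
    allocators.add(root.newChildAllocator("task child", 0, root.getLimit))

    // Closing in insertion order would close the root first and fail with an
    // IllegalStateException for outstanding child allocators; reversed
    // iteration closes children before their parents.
    for (allocator <- allocators.asScala.reverse) {
      allocator.close()
    }
  }
}
```

The sketch also hints at the intent of the first two hunks: each task now gets its own listener-backed `RootAllocator` rather than a child of the process-wide global allocator, so per-task children hang off `defaultAllocator` instead of `globalAllocator()`.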