diff --git a/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala b/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala
index 4aeafe87b..a5067098e 100644
--- a/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala
+++ b/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkMemoryUtils.scala
@@ -27,17 +27,15 @@ import com.intel.oap.spark.sql.execution.datasources.v2.arrow._
 import com.sun.xml.internal.messaging.saaj.util.ByteOutputStream
 import org.apache.arrow.dataset.jni.NativeMemoryPool
 import org.apache.arrow.memory.AllocationListener
-import org.apache.arrow.memory.AllocationOutcome
-import org.apache.arrow.memory.AutoBufferLedger
 import org.apache.arrow.memory.BufferAllocator
-import org.apache.arrow.memory.BufferLedger
-import org.apache.arrow.memory.DirectAllocationListener
-import org.apache.arrow.memory.ImmutableConfig
-import org.apache.arrow.memory.LegacyBufferLedger
+import org.apache.arrow.memory.MemoryChunkCleaner
+import org.apache.arrow.memory.MemoryChunkManager
 import org.apache.arrow.memory.RootAllocator
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.TaskCompletionListener
@@ -46,36 +44,6 @@ object SparkMemoryUtils extends Logging {
 
   private val DEBUG: Boolean = false
 
-  class AllocationListenerList(listeners: AllocationListener *)
-      extends AllocationListener {
-    override def onPreAllocation(size: Long): Unit = {
-      listeners.foreach(_.onPreAllocation(size))
-    }
-
-    override def onAllocation(size: Long): Unit = {
-      listeners.foreach(_.onAllocation(size))
-    }
-
-    override def onRelease(size: Long): Unit = {
-      listeners.foreach(_.onRelease(size))
-    }
-
-    override def onFailedAllocation(size: Long, outcome: AllocationOutcome): Boolean = {
-      listeners.forall(_.onFailedAllocation(size, outcome))
-    }
-
-    override def onChildAdded(parentAllocator: BufferAllocator,
-        childAllocator: BufferAllocator): Unit = {
-      listeners.foreach(_.onChildAdded(parentAllocator, childAllocator))
-
-    }
-
-    override def onChildRemoved(parentAllocator: BufferAllocator,
-        childAllocator: BufferAllocator): Unit = {
-      listeners.foreach(_.onChildRemoved(parentAllocator, childAllocator))
-    }
-  }
-
   class TaskMemoryResources {
     if (!inSparkTask()) {
       throw new IllegalStateException("Creating TaskMemoryResources instance out of Spark task")
@@ -88,19 +56,18 @@ object SparkMemoryUtils extends Logging {
       .getConfString("spark.oap.sql.columnar.autorelease", "false").toBoolean
     }
 
-    val ledgerFactory: BufferLedger.Factory = if (isArrowAutoReleaseEnabled) {
-      AutoBufferLedger.newFactory()
+    val memoryChunkManagerFactory: MemoryChunkManager.Factory = if (isArrowAutoReleaseEnabled) {
+      MemoryChunkCleaner.newFactory(MemoryChunkCleaner.Mode.HYBRID_WITH_LOG)
     } else {
-      LegacyBufferLedger.FACTORY
+      MemoryChunkManager.FACTORY
     }
 
     val sparkManagedAllocationListener = new SparkManagedAllocationListener(
      new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
      sharedMetrics)
-    val directAllocationListener = DirectAllocationListener.INSTANCE
 
     val allocListener: AllocationListener = if (isArrowAutoReleaseEnabled) {
-      new AllocationListenerList(sparkManagedAllocationListener, directAllocationListener)
+      MemoryChunkCleaner.gcTrigger(sparkManagedAllocationListener)
     } else {
       sparkManagedAllocationListener
     }
@@ -121,12 +88,13 @@ object SparkMemoryUtils extends Logging {
 
     private val memoryPools = new util.ArrayList[NativeMemoryPoolWrapper]()
 
-    val defaultAllocator: BufferAllocator = {
-      val alloc = new RootAllocator(ImmutableConfig.builder()
-        .maxAllocation(Long.MaxValue)
-        .bufferLedgerFactory(ledgerFactory)
-        .listener(allocListener)
-        .build)
+    val taskDefaultAllocator: BufferAllocator = {
+      val alloc = new RootAllocator(
+        RootAllocator.configBuilder()
+          .maxAllocation(Long.MaxValue)
+          .memoryChunkManagerFactory(memoryChunkManagerFactory)
+          .listener(allocListener)
+          .build)
       allocators.add(alloc)
       alloc
     }
@@ -154,7 +122,7 @@ object SparkMemoryUtils extends Logging {
       val al = new SparkManagedAllocationListener(
         new NativeSQLMemoryConsumer(getTaskMemoryManager(), spiller),
         sharedMetrics)
-      val parent = defaultAllocator
+      val parent = taskDefaultAllocator
       val alloc = parent.newChildAllocator("Spark Managed Allocator - " +
         UUID.randomUUID().toString, al, 0, parent.getLimit).asInstanceOf[BufferAllocator]
       allocators.add(alloc)
@@ -198,7 +166,7 @@ object SparkMemoryUtils extends Logging {
     }
 
     def release(): Unit = {
-      ledgerFactory match {
+      memoryChunkManagerFactory match {
        case closeable: AutoCloseable =>
          closeable.close()
        case _ =>
@@ -208,7 +176,11 @@ object SparkMemoryUtils extends Logging {
         if (allocated == 0L) {
           close(allocator)
         } else {
-          softClose(allocator)
+          if (isArrowAutoReleaseEnabled) {
+            close(allocator)
+          } else {
+            softClose(allocator)
+          }
         }
       }
       for (pool <- memoryPools.asScala) {
@@ -271,15 +243,19 @@ object SparkMemoryUtils extends Logging {
     }
   }
 
-  private val allocator = new RootAllocator(
-    ImmutableConfig.builder()
-      .maxAllocation(Long.MaxValue)
-      .bufferLedgerFactory(AutoBufferLedger.newFactory())
-      .listener(DirectAllocationListener.INSTANCE)
+  private val maxAllocationSize = {
+    SparkEnv.get.conf.get(MEMORY_OFFHEAP_SIZE)
+  }
+
+  private val globalAlloc = new RootAllocator(
+    RootAllocator.configBuilder()
+      .maxAllocation(maxAllocationSize)
+      .memoryChunkManagerFactory(MemoryChunkCleaner.newFactory())
+      .listener(MemoryChunkCleaner.gcTrigger())
       .build)
 
   def globalAllocator(): BufferAllocator = {
-    allocator
+    globalAlloc
   }
 
   def globalMemoryPool(): NativeMemoryPool = {
@@ -304,7 +280,7 @@ object SparkMemoryUtils extends Logging {
     if (!inSparkTask()) {
       return globalAllocator()
     }
-    getTaskMemoryResources().defaultAllocator
+    getTaskMemoryResources().taskDefaultAllocator
   }
 
   def contextMemoryPool(): NativeMemoryPool = {
diff --git a/arrow-data-source/script/build_arrow.sh b/arrow-data-source/script/build_arrow.sh
index 410e31070..7ed20fe44 100755
--- a/arrow-data-source/script/build_arrow.sh
+++ b/arrow-data-source/script/build_arrow.sh
@@ -62,7 +62,7 @@ echo "ARROW_SOURCE_DIR=${ARROW_SOURCE_DIR}"
 echo "ARROW_INSTALL_DIR=${ARROW_INSTALL_DIR}"
 mkdir -p $ARROW_SOURCE_DIR
 mkdir -p $ARROW_INSTALL_DIR
-git clone https://github.com/oap-project/arrow.git --branch arrow-4.0.0-oap $ARROW_SOURCE_DIR
+git clone https://github.com/zhztheplayer/arrow-1.git --branch oap-auto-release-hybrid $ARROW_SOURCE_DIR
 
 pushd $ARROW_SOURCE_DIR
 cmake ./cpp \
diff --git a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/JniUtils.java b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/JniUtils.java
index 9f55054ce..48d84a85a 100644
--- a/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/JniUtils.java
+++ b/native-sql-engine/core/src/main/java/com/intel/oap/vectorized/JniUtils.java
@@ -61,7 +61,7 @@ public static JniUtils getInstance(String tmp_dir) throws IOException {
           try {
             INSTANCE = new JniUtils(tmp_dir);
           } catch (IllegalAccessException ex) {
-            throw new IOException("IllegalAccess");
+            throw new IOException("IllegalAccess", ex);
           }
         }
       }
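
For illustration, a minimal standalone sketch of the allocator wiring this patch adopts. It assumes the patched Arrow fork cloned by build_arrow.sh above: RootAllocator.configBuilder(), memoryChunkManagerFactory(), MemoryChunkCleaner.newFactory() and MemoryChunkCleaner.gcTrigger() are APIs of that fork, not of upstream Arrow 4.0.0, and the object name ChunkCleanerSketch is hypothetical.

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.memory.MemoryChunkCleaner
import org.apache.arrow.memory.RootAllocator

object ChunkCleanerSketch {
  def main(args: Array[String]): Unit = {
    // Same shape as the globalAlloc set up in SparkMemoryUtils: the
    // MemoryChunkCleaner factory lets unreachable, never-closed buffers be
    // reclaimed via GC, and the gcTrigger listener is assumed to request a
    // JVM GC under allocation pressure so that cleanup keeps up.
    val alloc: BufferAllocator = new RootAllocator(
      RootAllocator.configBuilder()
        .maxAllocation(64L * 1024 * 1024) // small cap for the sketch
        .memoryChunkManagerFactory(MemoryChunkCleaner.newFactory())
        .listener(MemoryChunkCleaner.gcTrigger())
        .build)

    val buf = alloc.buffer(1024)
    // Explicit close() remains the deterministic path; the cleaner is a
    // safety net for buffers that never reach this line.
    buf.close()
    alloc.close()
  }
}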