diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala index fad40873..8d6900fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala @@ -1,5 +1,6 @@ package org.apache.spark.storage.pmof +import java.io.File import java.nio.ByteBuffer import org.apache.spark.internal.Logging @@ -25,6 +26,8 @@ private[spark] class PersistentMemoryHandler( // need to use a locked file to get which pmem device should be used. val pmMetaHandler: PersistentMemoryMetaHandler = new PersistentMemoryMetaHandler(root_dir) var device: String = pmMetaHandler.getShuffleDevice(shuffleId) + var poolFile = "" + var isFsdaxFile = false if(device == "") { //this shuffleId haven't been written before, choose a new device val path_array_list = new java.util.ArrayList[String](path_list.asJava) @@ -33,15 +36,17 @@ private[spark] class PersistentMemoryHandler( val dev = Paths.get(device) if (Files.isDirectory(dev)) { // this is fsdax, add a subfile - device += "/shuffle_block_" + UUID.randomUUID().toString() - logInfo("This is a fsdax, filename:" + device) + isFsdaxFile = true + poolFile = device + "/shuffle_block_" + UUID.randomUUID().toString() + logInfo("This is a fsdax, filename:" + poolFile) } else { - logInfo("This is a devdax, name:" + device) + poolFile = device + logInfo("This is a devdax, name:" + poolFile) poolSize = 0 } } - val pmpool = new PersistentMemoryPool(device, poolSize) + val pmpool = new PersistentMemoryPool(poolFile, poolSize) var rkey: Long = 0 @@ -84,8 +89,20 @@ private[spark] class PersistentMemoryHandler( } def close(): Unit = synchronized { - pmpool.close() - pmMetaHandler.remove() + if (isFsdaxFile) { + try { + if (new File(poolFile).delete()) { + logInfo("File deleted successfully: " + poolFile) + } else { + logWarning("Failed to delete file: " + poolFile) + } + } catch { + case e: Exception => e.printStackTrace() + } + } else { + pmpool.close() + pmMetaHandler.remove() + } } def getRootAddr(): Long = { diff --git a/native/src/lib_jni_pmdk.cpp b/native/src/lib_jni_pmdk.cpp index 2ce9132b..9d02a33b 100644 --- a/native/src/lib_jni_pmdk.cpp +++ b/native/src/lib_jni_pmdk.cpp @@ -5,7 +5,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_spark_storage_pmof_PersistentMemoryPool_nativeOpenDevice (JNIEnv *env, jclass obj, jstring path, jlong size) { const char *CStr = env->GetStringUTFChars(path, 0); - pmemkv* kv= new pmemkv(CStr); + pmemkv* kv= new pmemkv(CStr, size); env->ReleaseStringUTFChars(path, CStr); return (long)kv; } diff --git a/native/src/pmemkv.h b/native/src/pmemkv.h index 192f27ee..d7a4f606 100644 --- a/native/src/pmemkv.h +++ b/native/src/pmemkv.h @@ -95,8 +95,8 @@ key_3 --> block_meta_list_3[block_meta, block_meta, block_meta] */ class pmemkv { public: - explicit pmemkv(const char* dev_path_) : pmem_pool(nullptr), dev_path(dev_path_), bp(nullptr) { - if (create()) { + explicit pmemkv(const char* dev_path_, long size) : pmem_pool(nullptr), dev_path(dev_path_), bp(nullptr) { + if (create(size)) { int res = open(); if (res) { std::cout << "failed to open pmem pool, errmsg: " << pmemobj_errormsg() << std::endl; @@ -448,12 +448,12 @@ class pmemkv { return (uint64_t)pmem_pool; } private: - int create() { + int create(long size) { // debug setting int sds_write_value = 0; pmemobj_ctl_set(nullptr, "sds.at_create", &sds_write_value); - pmem_pool = pmemobj_create(dev_path, PMEMKV_LAYOUT_NAME, 0, 0666); + pmem_pool = pmemobj_create(dev_path, PMEMKV_LAYOUT_NAME, size, 0666); if (pmem_pool == nullptr) { return -1; }