SPARK-12729. PhantomReference to replace finalize in python broadcast… #11257

Status: Closed · wants to merge 2 commits
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala (65 changes: 51 additions & 14 deletions)
@@ -18,11 +18,14 @@
package org.apache.spark.api.python

import java.io._
import java.lang.ref.PhantomReference
import java.lang.ref.ReferenceQueue
import java.net._
import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.language.existentials
import scala.util.control.NonFatal

@@ -871,6 +874,44 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
 * write the data to disk after deserialization, so that Python can read it from disk.
 */
// scalastyle:off no.finalize

/**
 * A PhantomReference that remembers its referent's path so the backing file
 * can be deleted after the File object becomes unreachable. Only the path is
 * retained: holding on to the File itself would keep the referent strongly
 * reachable, and the reference would never be enqueued.
 */
private[spark] class FilePhantomReference(f: File, q: ReferenceQueue[File])
  extends PhantomReference[File](f, q) {

  private val path = f.getAbsolutePath

  def cleanup(): Unit = {
    new File(path).delete()
  }
}

/**
 * A daemon thread that polls the reference queue and deletes each backing
 * file once its FilePhantomReference has been enqueued by the garbage
 * collector. The thread exits when every tracked reference has been cleaned
 * up, or when it is interrupted via shutdownOnTaskCompletion().
 */
private[spark] class PhantomThread(
    threadName: String,
    queue: ReferenceQueue[File],
    phantomReferences: ListBuffer[FilePhantomReference])
  extends Thread(threadName) with Logging {

  setDaemon(true)

  def shutdownOnTaskCompletion(): Unit = {
    this.interrupt()
  }

  override def run(): Unit = Utils.logUncaughtExceptions {
    var interrupted = false
    while (!interrupted && phantomReferences.nonEmpty) {
      try {
        // remove() blocks until the collector enqueues a reference.
        val ref = queue.remove().asInstanceOf[FilePhantomReference]
        phantomReferences -= ref
        ref.cleanup()
      } catch {
        case ex: InterruptedException =>
          logDebug("Phantom cleanup thread interrupted, shutting down", ex)
          interrupted = true
      }
    }
  }
}
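
// A minimal usage sketch (illustrative only, not part of this patch) showing
// how the two helpers above fit together: register a temp file with the
// queue, start the cleanup thread, then drop the last strong reference so
// the collector can enqueue the phantom reference and the thread can delete
// the file. The object and thread names here are hypothetical.
private[spark] object PhantomCleanupExample {
  def demo(): Unit = {
    val queue = new ReferenceQueue[File]()
    val refs = new ListBuffer[FilePhantomReference]()
    var file: File = File.createTempFile("phantom-demo", "")
    refs += new FilePhantomReference(file, queue)
    new PhantomThread("phantom-demo-cleanup", queue, refs).start()
    file = null   // drop the only strong reference to the File object
    System.gc()   // a hint only; enqueueing is not guaranteed to be prompt
  }
}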

private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
  with Logging {

@@ -892,27 +933,23 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
    val file = File.createTempFile("broadcast", "", dir)
    path = file.getAbsolutePath
    // Register the file with a phantom reference and start a daemon thread
    // that deletes it once the File object has been garbage collected.
    val queue = new ReferenceQueue[File]()
    val phantomReferences = new ListBuffer[FilePhantomReference]()
    phantomReferences += new FilePhantomReference(file, queue)
    new PhantomThread("PythonBroadcast file cleanup", queue, phantomReferences).start()
    val out = new FileOutputStream(file)
    Utils.tryWithSafeFinally {
      Utils.copyStream(in, out)
    } {
      out.close()
    }
  }

  /**
   * Delete the file once the object is GCed. (This finalize-based cleanup is
   * what this patch removes, in favor of the phantom-reference path above.)
   */
  override def finalize() {
    if (!path.isEmpty) {
      val file = new File(path)
      if (file.exists()) {
        if (!file.delete()) {
          logWarning(s"Error deleting ${file.getPath}")
        }
      }
    }
  }
}
// scalastyle:on no.finalize
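
A note on the retention fix above: a java.lang.ref.Reference subclass that keeps a strong field pointing at its own referent (as the original @transient var f: File field did) prevents that referent from ever becoming phantom reachable, so the reference is never enqueued and no cleanup runs. The standalone, JDK-only sketch below shows the correct pattern of retaining only the data needed for cleanup; it is a minimal illustration with hypothetical names (PhantomFileCleanupSketch, PathReference), not code from the patch.

import java.io.File
import java.lang.ref.{PhantomReference, ReferenceQueue}

object PhantomFileCleanupSketch {
  // Retain only the path needed for cleanup, never the File itself, so the
  // referent can become phantom reachable and be enqueued.
  private class PathReference(f: File, q: ReferenceQueue[File])
    extends PhantomReference[File](f, q) {
    val path: String = f.getAbsolutePath
  }

  def main(args: Array[String]): Unit = {
    val queue = new ReferenceQueue[File]()
    var file: File = File.createTempFile("phantom-demo", ".tmp")
    // The reference object itself must stay strongly reachable until cleanup,
    // or it may be collected before it is ever enqueued.
    val ref = new PathReference(file, queue)

    file = null  // drop the last strong reference to the File object

    // Poll with a timeout and keep hinting the collector: a single
    // System.gc() call is not guaranteed to enqueue the reference.
    var enqueued: PathReference = null
    while (enqueued == null) {
      System.gc()
      enqueued = queue.remove(1000).asInstanceOf[PathReference]
    }
    new File(enqueued.path).delete()
    assert(enqueued eq ref) // also keeps `ref` reachable until this point
  }
}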