diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 3deed52be0be2..5cc4bbde39958 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -98,8 +98,28 @@ def _ensure_initialized(cls): # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing - cls._transformerSerializer = TransformFunctionSerializer( - SparkContext._active_spark_context, CloudPickleSerializer(), gw) + if cls._transformerSerializer is None: + transformer_serializer = TransformFunctionSerializer() + transformer_serializer.init( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) + # SPARK-12511 streaming driver with checkpointing unable to finalize leading to OOM + # There is an issue that Py4J's PythonProxyHandler.finalize blocks forever. + # (https://github.com/bartdag/py4j/pull/184) + # + # Py4j will create a PythonProxyHandler in Java for "transformer_serializer" when + # calling "registerSerializer". If we call "registerSerializer" twice, the second + # PythonProxyHandler will override the first one, then the first one will be GCed and + # trigger "PythonProxyHandler.finalize". To avoid that, we should not call + # "registerSerializer" more than once, so that "PythonProxyHandler" in Java side won't + # be GCed. + # + # TODO Once Py4J fixes this issue, we should upgrade Py4j to the latest version. + transformer_serializer.gateway.jvm.PythonDStream.registerSerializer( + transformer_serializer) + cls._transformerSerializer = transformer_serializer + else: + cls._transformerSerializer.init( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) @classmethod def getOrCreate(cls, checkpointPath, setupFunc): @@ -116,16 +136,13 @@ def getOrCreate(cls, checkpointPath, setupFunc): gw = SparkContext._gateway # Check whether valid checkpoint information exists in the given path - if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty(): + ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath) + if ssc_option.isEmpty(): ssc = setupFunc() ssc.checkpoint(checkpointPath) return ssc - try: - jssc = gw.jvm.JavaStreamingContext(checkpointPath) - except Exception: - print("failed to load StreamingContext from checkpoint", file=sys.stderr) - raise + jssc = gw.jvm.JavaStreamingContext(ssc_option.get()) # If there is already an active instance of Python SparkContext use it, or create a new one if not SparkContext._active_spark_context: diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index abbbf6eb9394f..e617fc9ce9eec 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -89,11 +89,10 @@ class TransformFunctionSerializer(object): it uses this class to invoke Python, which returns the serialized function as a byte array. """ - def __init__(self, ctx, serializer, gateway=None): + def init(self, ctx, serializer, gateway=None): self.ctx = ctx self.serializer = serializer self.gateway = gateway or self.ctx._gateway - self.gateway.jvm.PythonDStream.registerSerializer(self) self.failure = None def dumps(self, id): diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index c4a10aa2dd3b9..a5ab66697589b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -902,3 +902,15 @@ object StreamingContext extends Logging { result } } + +private class StreamingContextPythonHelper { + + /** + * This is a private method only for Python to implement `getOrCreate`. + */ + def tryRecoverFromCheckpoint(checkpointPath: String): Option[StreamingContext] = { + val checkpointOption = CheckpointReader.read( + checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, false) + checkpointOption.map(new StreamingContext(null, _, null)) + } +}