class Py4jCallbackConnectionCleaner(object):

    """
    A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617.
    It will scan all callback connections every 30 seconds and close the dead connections.

    The scan runs on a daemon ``threading.Timer`` that re-arms itself after every sweep, until
    :meth:`stop` is called.
    """

    # Seconds between two consecutive sweeps.
    _SCAN_INTERVAL = 30.0

    def __init__(self, gateway):
        """
        :param gateway: object whose ``_callback_server`` attribute is inspected on every sweep.
        """
        self._gateway = gateway
        self._stopped = False
        self._timer = None
        # Serializes updates of _stopped / _timer between start/stop callers and the timer thread.
        self._lock = RLock()

    def start(self):
        """
        Schedule the first sweep. Does nothing once the cleaner has been stopped.
        """
        if self._stopped:
            return

        self._start_timer(self._clean_and_reschedule)

    @staticmethod
    def _is_alive(connection):
        """
        Return whether `connection` reports itself alive.

        ``isAlive()`` was removed from ``threading.Thread`` in Python 3.9, so ``is_alive()`` is
        preferred and ``isAlive()`` is only used for objects that lack the modern spelling.
        """
        probe = getattr(connection, "is_alive", None)
        if probe is None:
            probe = connection.isAlive
        return probe()

    def _clean_closed_connections(self):
        """
        Run one sweep: close and forget every callback connection that is no longer alive.

        Returns immediately when the gateway has no callback server.
        """
        callback_server = getattr(self._gateway, "_callback_server", None)
        if callback_server is None:
            # No callback server to inspect; nothing to clean on this sweep.
            return

        # Imported lazily so that py4j is only needed once there is something to clean.
        from py4j.java_gateway import quiet_close, quiet_shutdown

        with callback_server.lock:
            # Snapshot first: the collection must not be mutated while it is being iterated.
            dead_connections = [connection for connection in callback_server.connections
                                if not self._is_alive(connection)]
            for connection in dead_connections:
                quiet_close(connection.input)
                quiet_shutdown(connection.socket)
                quiet_close(connection.socket)
                # Remove right away so a later failure cannot leave closed entries behind.
                callback_server.connections.remove(connection)

    def _clean_and_reschedule(self):
        """
        Timer callback: run one sweep, report (but do not propagate) any failure, and re-arm.
        """
        try:
            self._clean_closed_connections()
        except Exception:
            import traceback
            traceback.print_exc()
        finally:
            # Always re-arm, outside the callback server lock, so that a single failing sweep
            # does not silently end the periodic cleaning.
            self._start_timer(self._clean_and_reschedule)

    def _start_timer(self, f):
        """
        Arm a daemon timer that calls `f` after the scan interval, unless already stopped.
        """
        from threading import Timer

        with self._lock:
            if not self._stopped:
                self._timer = Timer(self._SCAN_INTERVAL, f)
                # Daemon thread: a pending sweep must not keep the interpreter alive.
                self._timer.daemon = True
                self._timer.start()

    def stop(self):
        """
        Stop the cleaner: cancel the pending timer and prevent any further re-arming.
        """
        with self._lock:
            self._stopped = True
            if self._timer is not None:
                self._timer.cancel()
                self._timer = None
SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm + _py4j_cleaner = Py4jCallbackConnectionCleaner(SparkContext._gateway) + _py4j_cleaner.start() if instance: if (SparkContext._active_spark_context and