From 32db1947c787302dac2375af8111b2251c4642c5 Mon Sep 17 00:00:00 2001 From: Cengguang Zhang Date: Mon, 22 Aug 2022 16:02:02 +0800 Subject: [PATCH 1/4] feat: add shutdown API to tf2 pyspark estimator. --- .../bigdl/orca/learn/tf2/pyspark_estimator.py | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py index 53b2bfff923..4ca7ce81978 100644 --- a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py +++ b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py @@ -94,7 +94,7 @@ def __init__(self, is_local = sc.master.startswith("local") self.need_to_log_to_driver = (not is_local) and log_to_driver if self.need_to_log_to_driver: - start_log_server(self.ip, self.port) + self.log_server_thread = start_log_server(self.ip, self.port) def _get_cluster_info(self, sc): cluster_info = self.workerRDD.barrier().mapPartitions(find_ip_and_free_port).collect() @@ -528,3 +528,32 @@ def get_model(self): @property def _model_saved_path(self): return os.path.join(self.model_dir, "{}_model.h5".format(self.application_id)) + + def shutdown(self): + """ + Shutdown estimator and release resources. + """ + if self.log_server_thread.is_alive(): + import inspect + import ctypes + import zmq + + def _async_raise(tid, exctype): + tid = ctypes.c_long(tid) + if not inspect.isclass(exctype): + exctype = type(exctype) + res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype)) + if res == 0: + raise ValueError("invalid thread id") + elif res != 1: + ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None) + raise SystemError("PyThreadState_SetAsyncExc failed") + + def stop_thread(thread): + _async_raise(thread.ident, SystemExit) + + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.connect("tcp://{}:{}".format(self.ip, self.port)) + socket.send_string("shutdown log server") + stop_thread(self.log_server_thread) From e77e0a181936625412402f27cbdf520e41c8c7e7 Mon Sep 17 00:00:00 2001 From: Cengguang Zhang Date: Mon, 22 Aug 2022 16:10:54 +0800 Subject: [PATCH 2/4] fix: replace raise with log4Error. --- python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py index 4ca7ce81978..4dd118d4e4d 100644 --- a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py +++ b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py @@ -41,7 +41,7 @@ from bigdl.orca.learn.log_monitor import start_log_server from bigdl.orca.data.shard import SparkXShards from bigdl.orca import OrcaContext -from bigdl.dllib.utils.log4Error import invalidInputError +from bigdl.dllib.utils.log4Error import invalidInputError, invalidOperationError logger = logging.getLogger(__name__) @@ -544,10 +544,10 @@ def _async_raise(tid, exctype): exctype = type(exctype) res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype)) if res == 0: - raise ValueError("invalid thread id") + invalidInputError(False, "invalid thread id") elif res != 1: ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None) - raise SystemError("PyThreadState_SetAsyncExc failed") + invalidOperationError(False, "PyThreadState_SetAsyncExc failed") def stop_thread(thread): _async_raise(thread.ident, SystemExit) From 14ba4e0a154b79e221ad111f5e302b5d89ec08f6 Mon Sep 17 00:00:00 2001 From: Cengguang Zhang Date: Tue, 23 Aug 2022 14:53:33 +0800 Subject: [PATCH 3/4] refactor: refactor the stop log server function. --- .../orca/src/bigdl/orca/learn/log_monitor.py | 26 +++++++++++++++++ .../bigdl/orca/learn/tf2/pyspark_estimator.py | 29 ++----------------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/python/orca/src/bigdl/orca/learn/log_monitor.py b/python/orca/src/bigdl/orca/learn/log_monitor.py index 16a1224e5d2..d8efe638a6e 100644 --- a/python/orca/src/bigdl/orca/learn/log_monitor.py +++ b/python/orca/src/bigdl/orca/learn/log_monitor.py @@ -164,3 +164,29 @@ def _print_logs(): logger_thread.daemon = True logger_thread.start() return logger_thread + +def stop_log_server(thread, ip, port): + if thread.is_alive(): + import inspect + import ctypes + import zmq + + def _async_raise(tid, exctype): + tid = ctypes.c_long(tid) + if not inspect.isclass(exctype): + exctype = type(exctype) + res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype)) + if res == 0: + invalidInputError(False, "invalid thread id") + elif res != 1: + ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None) + invalidOperationError(False, "PyThreadState_SetAsyncExc failed") + + def stop_thread(thread): + _async_raise(thread.ident, SystemExit) + + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.connect("tcp://{}:{}".format(ip, port)) + socket.send_string("shutdown log server") + stop_thread(thread) diff --git a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py index 4dd118d4e4d..d5f7dc2b7dd 100644 --- a/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py +++ b/python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py @@ -38,10 +38,10 @@ from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \ convert_predict_xshards_to_dataframe, make_data_creator, load_model, \ save_model, process_xshards_of_pandas_dataframe -from bigdl.orca.learn.log_monitor import start_log_server +from bigdl.orca.learn.log_monitor import start_log_server, stop_log_server from bigdl.orca.data.shard import SparkXShards from bigdl.orca import OrcaContext -from bigdl.dllib.utils.log4Error import invalidInputError, invalidOperationError +from bigdl.dllib.utils.log4Error import invalidInputError logger = logging.getLogger(__name__) @@ -533,27 +533,4 @@ def shutdown(self): """ Shutdown estimator and release resources. """ - if self.log_server_thread.is_alive(): - import inspect - import ctypes - import zmq - - def _async_raise(tid, exctype): - tid = ctypes.c_long(tid) - if not inspect.isclass(exctype): - exctype = type(exctype) - res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype)) - if res == 0: - invalidInputError(False, "invalid thread id") - elif res != 1: - ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None) - invalidOperationError(False, "PyThreadState_SetAsyncExc failed") - - def stop_thread(thread): - _async_raise(thread.ident, SystemExit) - - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.connect("tcp://{}:{}".format(self.ip, self.port)) - socket.send_string("shutdown log server") - stop_thread(self.log_server_thread) + stop_log_server(self.log_server_thread, self.ip, self.port) From fed091e8b3e92398193e38e06660da1927d1e6c1 Mon Sep 17 00:00:00 2001 From: Cengguang Zhang Date: Tue, 23 Aug 2022 15:20:13 +0800 Subject: [PATCH 4/4] fix: fix code style issue. --- python/orca/src/bigdl/orca/learn/log_monitor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/orca/src/bigdl/orca/learn/log_monitor.py b/python/orca/src/bigdl/orca/learn/log_monitor.py index d8efe638a6e..aec830fe56a 100644 --- a/python/orca/src/bigdl/orca/learn/log_monitor.py +++ b/python/orca/src/bigdl/orca/learn/log_monitor.py @@ -165,6 +165,7 @@ def _print_logs(): logger_thread.start() return logger_thread + def stop_log_server(thread, ip, port): if thread.is_alive(): import inspect