def stop_log_server(thread, ip, port):
    """
    Stop the log server thread listening on ``ip:port``.

    If ``thread`` is still alive, a "shutdown log server" message is sent to
    ``tcp://ip:port`` over a ZeroMQ REQ socket and a ``SystemExit`` is then
    scheduled asynchronously inside ``thread``. If the thread is not alive,
    nothing is done.

    :param thread: the ``threading.Thread`` running the log server.
    :param ip: address the log server is reachable at.
    :param port: port the log server is reachable at.
    :return: None
    """
    if not thread.is_alive():
        return

    import inspect
    import ctypes
    import zmq

    # Upper bound (milliseconds) for flushing the queued shutdown message when
    # the socket is closed, so that cleanup can never block indefinitely.
    linger_ms = 1000

    def _async_raise(tid, exctype):
        # Ask the interpreter to raise `exctype` inside the thread with id `tid`.
        if not inspect.isclass(exctype):
            exctype = type(exctype)
        tid = ctypes.c_long(tid)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
        if res == 0:
            invalidInputError(False, "invalid thread id")
        elif res != 1:
            # More than one thread state was affected: cancel the pending
            # exception again before reporting the failure.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
            invalidOperationError(False, "PyThreadState_SetAsyncExc failed")

    context = zmq.Context()
    try:
        socket = context.socket(zmq.REQ)
        try:
            socket.connect("tcp://{}:{}".format(ip, port))
            # NOTE(review): presumably this message wakes the server thread out
            # of a blocking receive so the pending SystemExit can be delivered;
            # the server loop is not visible here - confirm.
            socket.send_string("shutdown log server")
            _async_raise(thread.ident, SystemExit)
        finally:
            # Release the socket explicitly instead of relying on garbage
            # collection; the bounded linger still gives the message a chance
            # to be delivered.
            socket.close(linger=linger_ms)
    finally:
        context.term()
def shutdown(self):
    """
    Shutdown estimator and release resources.

    Stops the log server thread when one was stored on this estimator.
    ``log_server_thread`` is only assigned when logging to the driver is
    enabled, so its absence is treated as "nothing to stop" rather than
    an error.
    """
    log_server_thread = getattr(self, "log_server_thread", None)
    if log_server_thread is None:
        return
    stop_log_server(log_server_thread, self.ip, self.port)
    # Drop the handle so a repeated shutdown() call is a no-op.
    self.log_server_thread = None