Skip to content

Commit

Permalink
Orca: add shutdown API to tf2 pyspark estimator. (#5499)
Browse files Browse the repository at this point in the history
* feat: add shutdown API to tf2 pyspark estimator.

* fix: replace raise with log4Error.

* refactor: refactor the stop log server function.

* fix: fix code style issue.
  • Loading branch information
lalalapotter authored Aug 24, 2022
1 parent 79098b0 commit dbaab1e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
27 changes: 27 additions & 0 deletions python/orca/src/bigdl/orca/learn/log_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,30 @@ def _print_logs():
logger_thread.daemon = True
logger_thread.start()
return logger_thread


def stop_log_server(thread, ip, port):
    """
    Stop a running log server thread.

    A "shutdown log server" message is sent to ``tcp://<ip>:<port>`` over a
    ZeroMQ REQ socket and a ``SystemExit`` is then asynchronously raised in
    ``thread``. Nothing is done if ``thread`` is not alive.

    :param thread: the log server thread to stop (e.g. the value returned by
           ``start_log_server``).
    :param ip: ip address the log server is reachable at.
    :param port: port the log server is reachable at.
    """
    if not thread.is_alive():
        return

    import inspect
    import ctypes
    import zmq

    def _async_raise(tid, exctype):
        # Schedule `exctype` to be raised inside the thread whose id is `tid`.
        tid = ctypes.c_long(tid)
        if not inspect.isclass(exctype):
            exctype = type(exctype)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
        # The C API returns the number of thread states it modified.
        if res == 0:
            invalidInputError(False, "invalid thread id")
        elif res != 1:
            # More than one thread state was touched: passing None clears the
            # pending exception again before the failure is reported.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
            invalidOperationError(False, "PyThreadState_SetAsyncExc failed")

    def stop_thread(thread):
        _async_raise(thread.ident, SystemExit)

    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    try:
        socket.connect("tcp://{}:{}".format(ip, port))
        # NOTE(review): presumably this message wakes the server loop from a
        # blocking receive so the pending SystemExit can be delivered --
        # confirm against the server loop.
        socket.send_string("shutdown log server")
        stop_thread(thread)
    finally:
        # Always release the ZeroMQ resources. The bounded linger (in ms) gives
        # the queued shutdown message a chance to be flushed while keeping
        # context termination from blocking forever if the server is gone.
        socket.close(linger=1000)
        context.term()
10 changes: 8 additions & 2 deletions python/orca/src/bigdl/orca/learn/tf2/pyspark_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
convert_predict_xshards_to_dataframe, make_data_creator, load_model, \
save_model, process_xshards_of_pandas_dataframe
from bigdl.orca.learn.log_monitor import start_log_server
from bigdl.orca.learn.log_monitor import start_log_server, stop_log_server
from bigdl.orca.data.shard import SparkXShards
from bigdl.orca import OrcaContext
from bigdl.dllib.utils.log4Error import invalidInputError
Expand Down Expand Up @@ -94,7 +94,7 @@ def __init__(self,
is_local = sc.master.startswith("local")
self.need_to_log_to_driver = (not is_local) and log_to_driver
if self.need_to_log_to_driver:
start_log_server(self.ip, self.port)
self.log_server_thread = start_log_server(self.ip, self.port)

def _get_cluster_info(self, sc):
cluster_info = self.workerRDD.barrier().mapPartitions(find_ip_and_free_port).collect()
Expand Down Expand Up @@ -528,3 +528,9 @@ def get_model(self):
@property
def _model_saved_path(self):
    """
    Path of the ``<application_id>_model.h5`` file located under ``model_dir``.
    """
    file_name = "{}_model.h5".format(self.application_id)
    return os.path.join(self.model_dir, file_name)

def shutdown(self):
    """
    Shutdown estimator and release resources.

    Stops the log server thread if one was started for this estimator.
    It is safe to call when no log server was started.
    """
    # `log_server_thread` is only assigned when logging to the driver is
    # enabled, so the attribute may not exist on this instance.
    log_server_thread = getattr(self, "log_server_thread", None)
    if log_server_thread is not None:
        stop_log_server(log_server_thread, self.ip, self.port)
        # Drop the reference so a repeated shutdown() does nothing.
        self.log_server_thread = None

0 comments on commit dbaab1e

Please sign in to comment.