From 5f04971169d8cbc63788ce26074252d396e0f070 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 14 Jun 2024 16:57:08 +0800 Subject: [PATCH 1/2] Allow blocking launch of federated tracker. --- plugin/example/custom_obj.cc | 2 +- python-package/xgboost/federated.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc index b996447a3cd6..5d61e812ac9b 100644 --- a/plugin/example/custom_obj.cc +++ b/plugin/example/custom_obj.cc @@ -69,7 +69,7 @@ class MyLogistic : public ObjFunction { void SaveConfig(Json* p_out) const override { auto& out = *p_out; - out["name"] = String("my_logistic"); + out["name"] = String("mylogistic"); out["my_logistic_param"] = ToJson(param_); } diff --git a/python-package/xgboost/federated.py b/python-package/xgboost/federated.py index dcba9ec81a68..2e42c03ac967 100644 --- a/python-package/xgboost/federated.py +++ b/python-package/xgboost/federated.py @@ -65,9 +65,19 @@ def run_federated_server( # pylint: disable=too-many-arguments server_key_path: Optional[str] = None, server_cert_path: Optional[str] = None, client_cert_path: Optional[str] = None, + blocking: bool = True, timeout: int = 300, -) -> Dict[str, Any]: - """See :py:class:`~xgboost.federated.FederatedTracker` for more info.""" +) -> Optional[Dict[str, Any]]: + """See :py:class:`~xgboost.federated.FederatedTracker` for more info. + + Parameters + ---------- + blocking : + Block the server until the training is finished. If set to False, the function + launches an additional thread and returns the worker arguments. The default is + True and a higher level framework is responsible for setting worker parameters. + + """ args: Dict[str, Any] = {"n_workers": n_workers} secure = all( path is not None @@ -78,6 +88,10 @@ def run_federated_server( # pylint: disable=too-many-arguments ) tracker.start() + if blocking: + tracker.wait_for() + return None + thread = Thread(target=tracker.wait_for) thread.daemon = True thread.start() From 450bc29d39790b7a4e4a3646ec7882f1a4369e80 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sun, 16 Jun 2024 00:36:10 +0800 Subject: [PATCH 2/2] Do not block for unittest. --- tests/python/test_collective.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index a3923e9df4e4..2beedf8a1caf 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -63,7 +63,7 @@ def test_federated_communicator(): world_size = 2 tracker = multiprocessing.Process( target=federated.run_federated_server, - kwargs={"port": port, "n_workers": world_size}, + kwargs={"port": port, "n_workers": world_size, "blocking": False}, ) tracker.start() if not tracker.is_alive():