Skip to content

Commit

Permalink
[dask] Drop aliases of core network parameters (#3843)
Browse files Browse the repository at this point in the history
* Update dask.py

* Update basic.py

* hotfix pop
  • Loading branch information
StrikerRUS authored Jan 24, 2021
1 parent b7ccdaf commit 98a85a8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
6 changes: 6 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ class _ConfigAliases:
"local_listen_port": {"local_listen_port",
"local_port",
"port"},
"machine_list_filename": {"machine_list_filename",
"machine_list_file",
"machine_list",
"mlist"},
"machines": {"machines",
"workers",
"nodes"},
Expand All @@ -315,6 +319,8 @@ class _ConfigAliases:
"num_rounds",
"num_boost_round",
"n_estimators"},
"num_machines": {"num_machines",
"num_machine"},
"num_threads": {"num_threads",
"num_thread",
"nthread",
Expand Down
14 changes: 13 additions & 1 deletion python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
return part # trigger error locally

# Find locations of all parts and map them to particular Dask workers
key_to_part_dict = dict([(part.key, part) for part in parts])
key_to_part_dict = {part.key: part for part in parts}
who_has = client.who_has(parts)
worker_map = defaultdict(list)
for key, workers in who_has.items():
Expand Down Expand Up @@ -280,6 +280,18 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
for num_thread_alias in _ConfigAliases.get('num_threads'):
params.pop(num_thread_alias, None)

# machines is constructed manually, so remove it and all aliases of it from params
for machine_alias in _ConfigAliases.get('machines'):
params.pop(machine_alias, None)

# machines is constructed manually, so remove machine_list_filename and all aliases of it from params
for machine_list_filename_alias in _ConfigAliases.get('machine_list_filename'):
params.pop(machine_list_filename_alias, None)

# machines is constructed manually, so remove num_machines and all aliases of it from params
for num_machine_alias in _ConfigAliases.get('num_machines'):
params.pop(num_machine_alias, None)

# Tell each worker to train on the parts that it has locally
futures_classifiers = [
client.submit(
Expand Down

0 comments on commit 98a85a8

Please sign in to comment.