[dask] use more specific method names on _DaskLGBMModel #4004

Merged: 1 commit, Feb 20, 2021
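This PR renames the private helper methods on the _DaskLGBMModel mixin so their names are specific to the Dask wrapper: _lgb_getstate becomes _lgb_dask_getstate, _fit becomes _lgb_dask_fit, _to_local becomes _lgb_dask_to_local, and _copy_extra_params becomes _lgb_dask_copy_extra_params. Since the mixin is combined with the scikit-learn estimators (LGBMClassifier, LGBMRegressor, LGBMRanker), prefixed names are presumably less likely to collide with attributes already defined on those base classes. Below is a minimal sketch of that collision risk, using hypothetical Base / GenericMixin / PrefixedMixin classes rather than LightGBM's actual code:

# Hypothetical sketch, not LightGBM code: why specific, prefixed helper names
# on a mixin avoid accidental clashes when the mixin shares a namespace with
# another base class.

class Base:
    def _fit(self, X):
        # generic private helper name the base class already uses
        return "Base._fit"


class GenericMixin:
    def _fit(self, X):
        # same generic name: one of the two definitions silently wins by MRO
        return "GenericMixin._fit"


class PrefixedMixin:
    def _lgb_dask_fit(self, X):
        # prefixed name is effectively namespaced, so it cannot shadow Base
        return "PrefixedMixin._lgb_dask_fit"


class Risky(Base, GenericMixin):
    pass


class Safe(Base, PrefixedMixin):
    pass


print(Risky()._fit(None))           # "Base._fit" -- the mixin's helper is hidden
print(Safe()._lgb_dask_fit(None))   # "PrefixedMixin._lgb_dask_fit" -- unambiguous
print(Safe()._fit(None))            # "Base._fit" -- still reachable, untouched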
30 changes: 15 additions & 15 deletions python-package/lightgbm/dask.py
@@ -465,7 +465,7 @@ def client_(self) -> Client:
 
         return _get_dask_client(client=self.client)
 
-    def _lgb_getstate(self) -> Dict[Any, Any]:
+    def _lgb_dask_getstate(self) -> Dict[Any, Any]:
         """Remove un-picklable attributes before serialization."""
         client = self.__dict__.pop("client", None)
         self._other_params.pop("client", None)
@@ -474,7 +474,7 @@ def _lgb_getstate(self) -> Dict[Any, Any]:
         self.client = client
         return out
 
-    def _fit(
+    def _lgb_dask_fit(
         self,
         model_factory: Type[LGBMModel],
         X: _DaskMatrixLike,
@@ -501,20 +501,20 @@ def _fit(
         )
 
         self.set_params(**model.get_params())
-        self._copy_extra_params(model, self)
+        self._lgb_dask_copy_extra_params(model, self)
 
         return self
 
-    def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
+    def _lgb_dask_to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
         params = self.get_params()
         params.pop("client", None)
         model = model_factory(**params)
-        self._copy_extra_params(self, model)
+        self._lgb_dask_copy_extra_params(self, model)
         model._other_params.pop("client", None)
         return model
 
     @staticmethod
-    def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
+    def _lgb_dask_copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
         params = source.get_params()
         attributes = source.__dict__
         extra_param_names = set(attributes.keys()).difference(params.keys())
@@ -590,7 +590,7 @@ def __init__(
     __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]
 
     def __getstate__(self) -> Dict[Any, Any]:
-        return self._lgb_getstate()
+        return self._lgb_dask_getstate()
 
     def fit(
         self,
@@ -600,7 +600,7 @@ def fit(
         **kwargs: Any
     ) -> "DaskLGBMClassifier":
         """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
-        return self._fit(
+        return self._lgb_dask_fit(
             model_factory=LGBMClassifier,
             X=X,
             y=y,
@@ -670,7 +670,7 @@ def to_local(self) -> LGBMClassifier:
         model : lightgbm.LGBMClassifier
             Local underlying model.
         """
-        return self._to_local(LGBMClassifier)
+        return self._lgb_dask_to_local(LGBMClassifier)
 
 
 class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
@@ -741,7 +741,7 @@ def __init__(
     __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]
 
     def __getstate__(self) -> Dict[Any, Any]:
-        return self._lgb_getstate()
+        return self._lgb_dask_getstate()
 
     def fit(
         self,
@@ -751,7 +751,7 @@ def fit(
         **kwargs: Any
     ) -> "DaskLGBMRegressor":
         """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
-        return self._fit(
+        return self._lgb_dask_fit(
             model_factory=LGBMRegressor,
             X=X,
             y=y,
@@ -802,7 +802,7 @@ def to_local(self) -> LGBMRegressor:
         model : lightgbm.LGBMRegressor
             Local underlying model.
         """
-        return self._to_local(LGBMRegressor)
+        return self._lgb_dask_to_local(LGBMRegressor)
 
 
 class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
@@ -873,7 +873,7 @@ def __init__(
     __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]
 
     def __getstate__(self) -> Dict[Any, Any]:
-        return self._lgb_getstate()
+        return self._lgb_dask_getstate()
 
     def fit(
         self,
@@ -888,7 +888,7 @@ def fit(
         if init_score is not None:
             raise RuntimeError('init_score is not currently supported in lightgbm.dask')
 
-        return self._fit(
+        return self._lgb_dask_fit(
             model_factory=LGBMRanker,
             X=X,
             y=y,
@@ -939,4 +939,4 @@ def to_local(self) -> LGBMRanker:
         model : lightgbm.LGBMRanker
             Local underlying model.
         """
-        return self._to_local(LGBMRanker)
+        return self._lgb_dask_to_local(LGBMRanker)
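
Because only private helpers are renamed, the public Dask API is unchanged: fit, to_local, and pickling through __getstate__ behave as before and simply delegate to the new names. A rough usage sketch, assuming a local Dask cluster and dask.array inputs (illustrative values, not taken from this PR):

# Rough usage sketch: public API unaffected by the internal rename.
# Assumes lightgbm is installed with its Dask dependencies (dask, distributed).
import dask.array as da
from distributed import Client, LocalCluster

from lightgbm.dask import DaskLGBMRegressor

cluster = LocalCluster(n_workers=2)
client = Client(cluster)

# Illustrative random data, chunked so partitions line up row-wise.
X = da.random.random((1_000, 10), chunks=(100, 10))
y = da.random.random((1_000,), chunks=(100,))

dask_model = DaskLGBMRegressor(n_estimators=10)
dask_model.fit(X, y)                  # delegates to _lgb_dask_fit() internally
local_model = dask_model.to_local()   # delegates to _lgb_dask_to_local()
preds = local_model.predict(X.compute())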