Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] add more type hints on Booster #5309

Merged
merged 1 commit into from
Jun 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2781,7 +2781,7 @@ def free_network(self) -> "Booster":
self.network = False
return self

def trees_to_dataframe(self):
def trees_to_dataframe(self) -> pd_DataFrame:
"""Parse the fitted model and return in an easy-to-read pandas DataFrame.

The returned DataFrame has the following columns.
Expand Down Expand Up @@ -2917,7 +2917,7 @@ def tree_dict_to_node_list(tree, node_depth=1, tree_index=None,

return pd_DataFrame(model_list, columns=model_list[0].keys())

def set_train_data_name(self, name):
def set_train_data_name(self, name: str) -> "Booster":
"""Set the name to the training Dataset.

Parameters
Expand All @@ -2933,7 +2933,7 @@ def set_train_data_name(self, name):
self._train_data_name = name
return self

def add_valid(self, data, name):
def add_valid(self, data: Dataset, name: str) -> "Booster":
"""Add validation data.

Parameters
Expand Down Expand Up @@ -2963,7 +2963,7 @@ def add_valid(self, data, name):
self.__is_predicted_cur_iter.append(False)
return self

def reset_parameter(self, params):
def reset_parameter(self, params: Dict[str, Any]) -> "Booster":
"""Reset parameters of Booster.

Parameters
Expand Down Expand Up @@ -3100,7 +3100,7 @@ def __boost(self, grad, hess):
self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
return is_finished.value == 1

def rollback_one_iter(self):
def rollback_one_iter(self) -> "Booster":
"""Rollback one iteration.

Returns
Expand All @@ -3113,7 +3113,7 @@ def rollback_one_iter(self):
self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
return self

def current_iteration(self):
def current_iteration(self) -> int:
"""Get the index of the current iteration.

Returns
Expand All @@ -3127,7 +3127,7 @@ def current_iteration(self):
ctypes.byref(out_cur_iter)))
return out_cur_iter.value

def num_model_per_iteration(self):
def num_model_per_iteration(self) -> int:
"""Get number of models per iteration.

Returns
Expand All @@ -3141,7 +3141,7 @@ def num_model_per_iteration(self):
ctypes.byref(model_per_iter)))
return model_per_iter.value

def num_trees(self):
def num_trees(self) -> int:
"""Get number of weak sub-models.

Returns
Expand All @@ -3155,7 +3155,7 @@ def num_trees(self):
ctypes.byref(num_trees)))
return num_trees.value

def upper_bound(self):
def upper_bound(self) -> float:
"""Get upper bound value of a model.

Returns
Expand All @@ -3169,7 +3169,7 @@ def upper_bound(self):
ctypes.byref(ret)))
return ret.value

def lower_bound(self):
def lower_bound(self) -> float:
"""Get lower bound value of a model.

Returns
Expand Down Expand Up @@ -3353,7 +3353,7 @@ def shuffle_models(self, start_iteration=0, end_iteration=-1):
ctypes.c_int(end_iteration)))
return self

def model_from_string(self, model_str):
def model_from_string(self, model_str: str) -> "Booster":
"""Load Booster from a string.

Parameters
Expand Down Expand Up @@ -3665,7 +3665,7 @@ def refit(
new_booster.network = self.network
return new_booster

def get_leaf_output(self, tree_id, leaf_id):
def get_leaf_output(self, tree_id: int, leaf_id: int) -> float:
"""Get the output of a leaf.

Parameters
Expand Down