From c29877852d3c237d7a43505262c3cfbedfc38ff1 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 19 Jun 2022 22:39:24 -0500 Subject: [PATCH] [python-package] add more hints on Booster --- python-package/lightgbm/basic.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 92d482163ee4..6c38ee3a609a 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -2781,7 +2781,7 @@ def free_network(self) -> "Booster": self.network = False return self - def trees_to_dataframe(self): + def trees_to_dataframe(self) -> pd_DataFrame: """Parse the fitted model and return in an easy-to-read pandas DataFrame. The returned DataFrame has the following columns. @@ -2917,7 +2917,7 @@ def tree_dict_to_node_list(tree, node_depth=1, tree_index=None, return pd_DataFrame(model_list, columns=model_list[0].keys()) - def set_train_data_name(self, name): + def set_train_data_name(self, name: str) -> "Booster": """Set the name to the training Dataset. Parameters @@ -2933,7 +2933,7 @@ def set_train_data_name(self, name): self._train_data_name = name return self - def add_valid(self, data, name): + def add_valid(self, data: Dataset, name: str) -> "Booster": """Add validation data. Parameters @@ -2963,7 +2963,7 @@ def add_valid(self, data, name): self.__is_predicted_cur_iter.append(False) return self - def reset_parameter(self, params): + def reset_parameter(self, params: Dict[str, Any]) -> "Booster": """Reset parameters of Booster. Parameters @@ -3100,7 +3100,7 @@ def __boost(self, grad, hess): self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)] return is_finished.value == 1 - def rollback_one_iter(self): + def rollback_one_iter(self) -> "Booster": """Rollback one iteration. Returns @@ -3113,7 +3113,7 @@ def rollback_one_iter(self): self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)] return self - def current_iteration(self): + def current_iteration(self) -> int: """Get the index of the current iteration. Returns @@ -3127,7 +3127,7 @@ def current_iteration(self): ctypes.byref(out_cur_iter))) return out_cur_iter.value - def num_model_per_iteration(self): + def num_model_per_iteration(self) -> int: """Get number of models per iteration. Returns @@ -3141,7 +3141,7 @@ def num_model_per_iteration(self): ctypes.byref(model_per_iter))) return model_per_iter.value - def num_trees(self): + def num_trees(self) -> int: """Get number of weak sub-models. Returns @@ -3155,7 +3155,7 @@ def num_trees(self): ctypes.byref(num_trees))) return num_trees.value - def upper_bound(self): + def upper_bound(self) -> float: """Get upper bound value of a model. Returns @@ -3169,7 +3169,7 @@ def upper_bound(self): ctypes.byref(ret))) return ret.value - def lower_bound(self): + def lower_bound(self) -> float: """Get lower bound value of a model. Returns @@ -3353,7 +3353,7 @@ def shuffle_models(self, start_iteration=0, end_iteration=-1): ctypes.c_int(end_iteration))) return self - def model_from_string(self, model_str): + def model_from_string(self, model_str: str) -> "Booster": """Load Booster from a string. Parameters @@ -3665,7 +3665,7 @@ def refit( new_booster.network = self.network return new_booster - def get_leaf_output(self, tree_id, leaf_id): + def get_leaf_output(self, tree_id: int, leaf_id: int) -> float: """Get the output of a leaf. Parameters