Skip to content

Commit

Permalink
Fix mypi for multiclass classification for lgbm classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
otaviocv committed Nov 8, 2023
1 parent 2730c07 commit c6bbca6
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/fklearn/training/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,8 @@ def lgbm_classification_learner(

def p(new_df: pd.DataFrame, apply_shap: bool = False) -> pd.DataFrame:
predictions = bst.predict(new_df[features].values)
if isinstance(predictions, List):
predictions = np.ndarray(predictions)
if is_multiclass_classification:
col_dict = {prediction_column + "_" + str(key): value
for (key, value) in enumerate(predictions.T)}
Expand Down

0 comments on commit c6bbca6

Please sign in to comment.