Skip to content

Commit

Permalink
Merge branch 'master' into otaviocv/remove-scikit-0-25-constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
otaviocv authored Nov 8, 2023
2 parents d20dd44 + 054d319 commit 2730c07
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/fklearn/training/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,10 +643,14 @@ def lgbm_classification_learner(

import lightgbm as lgbm

LGBM_MULTICLASS_OBJECTIVES = {'multiclass', 'softmax', 'multiclassova', 'multiclass_ova', 'ova', 'ovr'}

params = extra_params if extra_params else {}
params = assoc(params, "eta", learning_rate)
params = params if "objective" in params else assoc(params, "objective", 'binary')

is_multiclass_classification = params["objective"] in LGBM_MULTICLASS_OBJECTIVES

weights = df[weight_column].values if weight_column else None

features = features if not encode_extra_cols else expand_features_encoded(df, features)
Expand Down Expand Up @@ -674,23 +678,20 @@ def lgbm_classification_learner(
)

def p(new_df: pd.DataFrame, apply_shap: bool = False) -> pd.DataFrame:
if params["objective"] == "multiclass":
predictions = bst.predict(new_df[features].values)
if not isinstance(predictions, List):
predictions_transposed = predictions.T
else:
predictions_transposed = predictions
col_dict = {prediction_column + "_" + str(key): value for (key, value) in enumerate(predictions_transposed)}
predictions = bst.predict(new_df[features].values)
if is_multiclass_classification:
col_dict = {prediction_column + "_" + str(key): value
for (key, value) in enumerate(predictions.T)}
else:
col_dict = {prediction_column: bst.predict(new_df[features].values)}
col_dict = {prediction_column: predictions}

if apply_shap:
import shap
explainer = shap.TreeExplainer(bst)
shap_values = explainer.shap_values(new_df[features])
shap_expected_value = explainer.expected_value

if params["objective"] == "multiclass":
if is_multiclass_classification:
shap_values_multiclass = {f"shap_values_{class_index}": list(value)
for (class_index, value) in enumerate(shap_values)}
shap_expected_value_multiclass = {
Expand Down

0 comments on commit 2730c07

Please sign in to comment.