From 6b7fbfde6801b6468f5afc39cb2d21c6424483c7 Mon Sep 17 00:00:00 2001
From: Maximilian Linhoff
Date: Wed, 6 Dec 2023 12:57:11 +0100
Subject: [PATCH] Let the DispReconstructor also compute a score for the sign
 prediction

---
 ctapipe/containers.py         |  5 ++++
 ctapipe/reco/sklearn.py       | 51 ++++++++++++++++++++++++++---------
 docs/changes/2479.feature.rst |  1 +
 3 files changed, 45 insertions(+), 12 deletions(-)
 create mode 100644 docs/changes/2479.feature.rst

diff --git a/ctapipe/containers.py b/ctapipe/containers.py
index 5a3f7c54892..29e0d1b693f 100644
--- a/ctapipe/containers.py
+++ b/ctapipe/containers.py
@@ -1005,6 +1005,11 @@ class DispContainer(Container):
     parameter = Field(
         nan * u.deg, "reconstructed value for disp (= sign * norm)", unit=u.deg
     )
+    sign_score = Field(
+        nan,
+        "Score for how certain the disp sign classification was."
+        " 0 means completely uncertain, 1 means very certain.",
+    )
 
 
 class ReconstructedContainer(Container):
diff --git a/ctapipe/reco/sklearn.py b/ctapipe/reco/sklearn.py
index 230f6af36d4..2b08b891e54 100644
--- a/ctapipe/reco/sklearn.py
+++ b/ctapipe/reco/sklearn.py
@@ -672,6 +672,7 @@ def _predict(self, key, table):
         )
         X, valid = table_to_X(table, self.features, self.log)
         prediction = np.full(len(table), np.nan)
+        score = np.full(len(table), np.nan)
 
         if np.any(valid):
             valid_norms = self._models[key][0].predict(X)
@@ -681,12 +682,17 @@
             else:
                 prediction[valid] = valid_norms
 
-            prediction[valid] *= self._models[key][1].predict(X)
+            sign_proba = self._models[key][1].predict_proba(X)[:, 1]
+            # proba is in [0, 1], where 0 => very certain -1, 1 => very certain 1
+            # and 0.5 means random guessing. So we transform to a score
+            # where 0 means "guessing" and 1 means "very certain"
+            score[valid] = np.abs(2 * sign_proba - 1.0)
+            prediction[valid] *= np.where(sign_proba >= 0.5, 1.0, -1.0)
 
         if self.unit is not None:
             prediction = u.Quantity(prediction, self.unit, copy=False)
 
-        return prediction, valid
+        return prediction, score, valid
 
     def __call__(self, event: ArrayEventContainer) -> None:
         """Event-wise prediction for the EventSource-Loop.
@@ -705,10 +711,15 @@ def __call__(self, event: ArrayEventContainer) -> None:
             passes_quality_checks = self.quality_query.get_table_mask(table)[0]
 
             if passes_quality_checks:
-                disp, valid = self._predict(self.subarray.tel[tel_id], table)
+                disp, sign_score, valid = self._predict(
+                    self.subarray.tel[tel_id], table
+                )
 
                 if valid:
-                    disp_container = DispContainer(parameter=disp[0])
+                    disp_container = DispContainer(
+                        parameter=disp[0],
+                        sign_score=sign_score[0],
+                    )
 
                     hillas = event.dl1.tel[tel_id].parameters.hillas
                     psi = hillas.psi.to_value(u.rad)
@@ -775,11 +786,19 @@ def predict_table(self, key, table: Table) -> Dict[ReconstructionProperty, Table
         n_rows = len(table)
         disp = u.Quantity(np.full(n_rows, np.nan), self.unit, copy=False)
         is_valid = np.full(n_rows, False)
+        sign_score = np.full(n_rows, np.nan)
 
         valid = self.quality_query.get_table_mask(table)
-        disp[valid], is_valid[valid] = self._predict(key, table[valid])
+        disp[valid], sign_score[valid], is_valid[valid] = self._predict(
+            key, table[valid]
+        )
 
-        disp_result = Table({f"{self.prefix}_tel_parameter": disp})
+        disp_result = Table(
+            {
+                f"{self.prefix}_tel_parameter": disp,
+                f"{self.prefix}_tel_sign_score": sign_score,
+            }
+        )
         add_defaults_and_meta(
             disp_result,
             DispContainer,
@@ -917,10 +936,10 @@ def __call__(self, telescope_type, table):
                     {
                         "cv_fold": np.full(len(truth), fold, dtype=np.uint8),
                         "tel_type": [str(telescope_type)] * len(truth),
-                        "prediction": cv_prediction,
                         "truth": truth,
                         "true_energy": test["true_energy"],
                         "true_impact_distance": test["true_impact_distance"],
+                        **cv_prediction,
                     }
                 )
             )
@@ -945,7 +964,7 @@ def _cross_validate_regressor(self, telescope_type, train, test):
         prediction, _ = regressor._predict(telescope_type, test)
         truth = test[regressor.target]
         r2 = r2_score(truth, prediction)
-        return prediction, truth, {"R^2": r2}
+        return {f"{regressor.prefix}_energy": prediction}, truth, {"R^2": r2}
 
     def _cross_validate_classification(self, telescope_type, train, test):
         classifier = self.model_component
@@ -957,15 +976,23 @@ def _cross_validate_classification(self, telescope_type, train, test):
             0,
         )
         roc_auc = roc_auc_score(truth, prediction)
-        return prediction, truth, {"ROC AUC": roc_auc}
+        return (
+            {f"{classifier.prefix}_prediction": prediction},
+            truth,
+            {"ROC AUC": roc_auc},
+        )
 
     def _cross_validate_disp(self, telescope_type, train, test):
         models = self.model_component
         models.fit(telescope_type, train)
-        prediction, _ = models._predict(telescope_type, test)
+        disp, sign_score, _ = models._predict(telescope_type, test)
         truth = test[models.target]
-        r2 = r2_score(np.abs(truth), np.abs(prediction))
-        accuracy = accuracy_score(np.sign(truth), np.sign(prediction))
+        r2 = r2_score(np.abs(truth), np.abs(disp))
+        accuracy = accuracy_score(np.sign(truth), np.sign(disp))
+        prediction = {
+            f"{models.prefix}_parameter": disp,
+            f"{models.prefix}_sign_score": sign_score,
+        }
         return prediction, truth, {"R^2": r2, "accuracy": accuracy}
 
     def write(self, overwrite=False):
diff --git a/docs/changes/2479.feature.rst b/docs/changes/2479.feature.rst
new file mode 100644
index 00000000000..2a6b96a6013
--- /dev/null
+++ b/docs/changes/2479.feature.rst
@@ -0,0 +1 @@
+The ``DispReconstructor`` now computes a score for how certain the prediction of the disp sign is.
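
Note (not part of the patch): the score transform in `_predict` can be checked in
isolation. Below is a minimal standalone sketch, with a made-up
`RandomForestClassifier` on toy data standing in for the trained sign model; only
the `|2p - 1|` folding and the 0.5 thresholding mirror the patch. With labels
-1/+1, sklearn orders `classes_` as `[-1, 1]`, so column 1 of `predict_proba` is
the probability of sign +1.

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    # Toy stand-in for the trained sign model: labels are -1 / +1,
    # so classes_ is [-1, 1] and predict_proba column 1 is P(sign == +1).
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 3))
    y = np.where(X[:, 0] > 0, 1, -1)
    clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

    sign_proba = clf.predict_proba(X)[:, 1]        # in [0, 1]
    score = np.abs(2 * sign_proba - 1.0)           # 0 = guessing, 1 = certain
    sign = np.where(sign_proba >= 0.5, 1.0, -1.0)

    # Away from exact ties, thresholding the proba agrees with clf.predict().
    decided = sign_proba != 0.5
    assert np.array_equal(sign[decided], clf.predict(X)[decided])

The folding works because both 0 and 1 are maximally informative probabilities:
`|2 * 0 - 1| = |2 * 1 - 1| = 1`, while the uninformative 0.5 maps to 0.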
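
Similarly, the `CrossValidator` hunks switch each `_cross_validate_*` helper from
returning a bare prediction array to a dict of named columns, which
`**cv_prediction` splats into the per-fold results table, so different model
types can contribute different column sets. A minimal sketch of that pattern
with invented values (column names assume the default `disp` prefix):

    import numpy as np
    from astropy.table import Table

    # One dict per model type; the disp case carries two columns.
    cv_prediction = {
        "disp_parameter": np.array([0.52, -0.31]),
        "disp_sign_score": np.array([0.84, 0.12]),
    }
    fold_result = Table(
        {
            "cv_fold": np.zeros(2, dtype=np.uint8),
            "truth": np.array([0.47, 0.25]),
            **cv_prediction,  # one table column per prediction key
        }
    )
    print(fold_result.colnames)
    # ['cv_fold', 'truth', 'disp_parameter', 'disp_sign_score']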