Skip to content

Commit

Permalink
Merge pull request #2479 from cta-observatory/disp_score
Browse files Browse the repository at this point in the history
Let the DispReconstructor also compute a score for the sign prediction
  • Loading branch information
maxnoe authored Dec 7, 2023
2 parents fbe8f05 + 6b7fbfd commit f86411d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 12 deletions.
5 changes: 5 additions & 0 deletions ctapipe/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,11 @@ class DispContainer(Container):
parameter = Field(
nan * u.deg, "reconstructed value for disp (= sign * norm)", unit=u.deg
)
sign_score = Field(
nan,
"Score for how certain the disp sign classification was."
" 0 means completely uncertain, 1 means very certain.",
)


class ReconstructedContainer(Container):
Expand Down
51 changes: 39 additions & 12 deletions ctapipe/reco/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ def _predict(self, key, table):
)
X, valid = table_to_X(table, self.features, self.log)
prediction = np.full(len(table), np.nan)
score = np.full(len(table), np.nan)

if np.any(valid):
valid_norms = self._models[key][0].predict(X)
Expand All @@ -681,12 +682,17 @@ def _predict(self, key, table):
else:
prediction[valid] = valid_norms

prediction[valid] *= self._models[key][1].predict(X)
sign_proba = self._models[key][1].predict_proba(X)[:, 0]
# proba is [0 and 1] where 0 => very certain -1, 1 => very certain 1
# and 0.5 means random guessing either. So we transform to a score
# where 0 means "guessing" and 1 means "very certain"
score[valid] = np.abs(2 * sign_proba - 1.0)
prediction[valid] *= np.where(sign_proba >= 0.5, 1.0, -1.0)

if self.unit is not None:
prediction = u.Quantity(prediction, self.unit, copy=False)

return prediction, valid
return prediction, score, valid

def __call__(self, event: ArrayEventContainer) -> None:
"""Event-wise prediction for the EventSource-Loop.
Expand All @@ -705,10 +711,15 @@ def __call__(self, event: ArrayEventContainer) -> None:
passes_quality_checks = self.quality_query.get_table_mask(table)[0]

if passes_quality_checks:
disp, valid = self._predict(self.subarray.tel[tel_id], table)
disp, sign_score, valid = self._predict(
self.subarray.tel[tel_id], table
)

if valid:
disp_container = DispContainer(parameter=disp[0])
disp_container = DispContainer(
parameter=disp[0],
sign_score=sign_score[0],
)

hillas = event.dl1.tel[tel_id].parameters.hillas
psi = hillas.psi.to_value(u.rad)
Expand Down Expand Up @@ -775,11 +786,19 @@ def predict_table(self, key, table: Table) -> Dict[ReconstructionProperty, Table
n_rows = len(table)
disp = u.Quantity(np.full(n_rows, np.nan), self.unit, copy=False)
is_valid = np.full(n_rows, False)
sign_score = np.full(n_rows, np.nan)

valid = self.quality_query.get_table_mask(table)
disp[valid], is_valid[valid] = self._predict(key, table[valid])
disp[valid], sign_score[valid], is_valid[valid] = self._predict(
key, table[valid]
)

disp_result = Table({f"{self.prefix}_tel_parameter": disp})
disp_result = Table(
{
f"{self.prefix}_tel_parameter": disp,
f"{self.prefix}_tel_sign_score": sign_score,
}
)
add_defaults_and_meta(
disp_result,
DispContainer,
Expand Down Expand Up @@ -917,10 +936,10 @@ def __call__(self, telescope_type, table):
{
"cv_fold": np.full(len(truth), fold, dtype=np.uint8),
"tel_type": [str(telescope_type)] * len(truth),
"prediction": cv_prediction,
"truth": truth,
"true_energy": test["true_energy"],
"true_impact_distance": test["true_impact_distance"],
**cv_prediction,
}
)
)
Expand All @@ -945,7 +964,7 @@ def _cross_validate_regressor(self, telescope_type, train, test):
prediction, _ = regressor._predict(telescope_type, test)
truth = test[regressor.target]
r2 = r2_score(truth, prediction)
return prediction, truth, {"R^2": r2}
return {f"{regressor.prefix}_energy": prediction}, truth, {"R^2": r2}

def _cross_validate_classification(self, telescope_type, train, test):
classifier = self.model_component
Expand All @@ -957,15 +976,23 @@ def _cross_validate_classification(self, telescope_type, train, test):
0,
)
roc_auc = roc_auc_score(truth, prediction)
return prediction, truth, {"ROC AUC": roc_auc}
return (
{f"{classifier.prefix}_prediction": prediction},
truth,
{"ROC AUC": roc_auc},
)

def _cross_validate_disp(self, telescope_type, train, test):
models = self.model_component
models.fit(telescope_type, train)
prediction, _ = models._predict(telescope_type, test)
disp, sign_score, _ = models._predict(telescope_type, test)
truth = test[models.target]
r2 = r2_score(np.abs(truth), np.abs(prediction))
accuracy = accuracy_score(np.sign(truth), np.sign(prediction))
r2 = r2_score(np.abs(truth), np.abs(disp))
accuracy = accuracy_score(np.sign(truth), np.sign(disp))
prediction = {
f"{models.prefix}_parameter": disp,
f"{models.prefix}_sign_score": sign_score,
}
return prediction, truth, {"R^2": r2, "accuracy": accuracy}

def write(self, overwrite=False):
Expand Down
1 change: 1 addition & 0 deletions docs/changes/2479.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The ``DispReconstructor`` now computes a score for how certain the prediction of the disp sign is.

0 comments on commit f86411d

Please sign in to comment.