Skip to content

Commit

Permalink
Merge pull request #2295 from cta-observatory/better_errors
Browse files Browse the repository at this point in the history
Better errors in case of no events in train tools
  • Loading branch information
maxnoe authored Mar 30, 2023
2 parents 53ee8cd + d8b025a commit 9b8646e
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 4 deletions.
6 changes: 6 additions & 0 deletions ctapipe/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class CTAPipeException(Exception):
pass


class TooFewEvents(CTAPipeException):
"""Raised if something that needs a minimum number of event gets fewer"""
4 changes: 3 additions & 1 deletion ctapipe/reco/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from tqdm import tqdm
from traitlets import TraitError

from ctapipe.exceptions import TooFewEvents

from ..containers import (
ArrayEventContainer,
DispContainer,
Expand Down Expand Up @@ -847,7 +849,7 @@ def __init__(self, model_component, **kwargs):

def __call__(self, telescope_type, table):
if len(table) <= self.n_cross_validations:
raise ValueError(f"Too few events for {telescope_type}.")
raise TooFewEvents(f"Too few events for {telescope_type}.")

self.log.info(
"Starting cross-validation with %d folds for type %s.",
Expand Down
3 changes: 2 additions & 1 deletion ctapipe/tools/tests/test_train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from ctapipe.core import run_tool
from ctapipe.exceptions import TooFewEvents
from ctapipe.utils.datasets import resource_file


Expand Down Expand Up @@ -29,7 +30,7 @@ def test_too_few_events(tmp_path, dl2_shower_geometry_file):
config = resource_file("train_energy_regressor.yaml")
out_file = tmp_path / "energy.pkl"

with pytest.raises(ValueError, match="Too few events"):
with pytest.raises(TooFewEvents, match="No events after quality query"):
run_tool(
tool,
argv=[
Expand Down
9 changes: 9 additions & 0 deletions ctapipe/tools/train_disp_reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from ctapipe.core import Tool
from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path
from ctapipe.exceptions import TooFewEvents
from ctapipe.io import TableLoader
from ctapipe.reco import CrossValidator, DispReconstructor
from ctapipe.reco.preprocessing import check_valid_rows, horizontal_to_telescope
Expand Down Expand Up @@ -113,10 +114,18 @@ def start(self):
def _read_table(self, telescope_type):
table = self.loader.read_telescope_events([telescope_type])
self.log.info("Events read from input: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"Input file does not contain any events for telescope type {telescope_type}"
)

mask = self.models.quality_query.get_table_mask(table)
table = table[mask]
self.log.info("Events after applying quality query: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"No events after quality query for telescope type {telescope_type}"
)

table = self.models.feature_generator(table)

Expand Down
11 changes: 10 additions & 1 deletion ctapipe/tools/train_energy_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ctapipe.core import Tool
from ctapipe.core.traits import Int, IntTelescopeParameter, Path
from ctapipe.exceptions import TooFewEvents
from ctapipe.io import TableLoader
from ctapipe.reco import CrossValidator, EnergyRegressor
from ctapipe.reco.preprocessing import check_valid_rows
Expand Down Expand Up @@ -113,11 +114,19 @@ def start(self):

def _read_table(self, telescope_type):
table = self.loader.read_telescope_events([telescope_type])

self.log.info("Events read from input: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"Input file does not contain any events for telescope type {telescope_type}"
)

mask = self.regressor.quality_query.get_table_mask(table)
table = table[mask]
self.log.info("Events after applying quality query: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"No events after quality query for telescope type {telescope_type}"
)

table = self.regressor.feature_generator(table)

Expand Down
11 changes: 10 additions & 1 deletion ctapipe/tools/train_particle_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ctapipe.core.tool import Tool
from ctapipe.core.traits import Int, IntTelescopeParameter, Path
from ctapipe.exceptions import TooFewEvents
from ctapipe.io import TableLoader
from ctapipe.reco import CrossValidator, ParticleClassifier
from ctapipe.reco.preprocessing import check_valid_rows
Expand Down Expand Up @@ -164,11 +165,19 @@ def start(self):

def _read_table(self, telescope_type, loader, n_events=None):
table = loader.read_telescope_events([telescope_type])

self.log.info("Events read from input: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"Input file does not contain any events for telescope type {telescope_type}"
)

mask = self.classifier.quality_query.get_table_mask(table)
table = table[mask]
self.log.info("Events after applying quality query: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"No events after quality query for telescope type {telescope_type}"
)

table = self.classifier.feature_generator(table)

Expand Down
2 changes: 2 additions & 0 deletions docs/changes/2295.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
The tools to train ml models now provide better error messages in case
the input files did not contain any events for specific telescope types.

0 comments on commit 9b8646e

Please sign in to comment.