Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chunked loading of training data #2423

Merged
merged 13 commits into from
Nov 17, 2023
Merged
87 changes: 31 additions & 56 deletions ctapipe/tools/train_disp_reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import astropy.units as u
import numpy as np

from ctapipe.containers import CoordinateFrameType
from ctapipe.core import Tool
from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path
from ctapipe.exceptions import TooFewEvents
from ctapipe.io import TableLoader
from ctapipe.reco import CrossValidator, DispReconstructor
from ctapipe.reco.preprocessing import check_valid_rows, horizontal_to_telescope
from ctapipe.reco.preprocessing import horizontal_to_telescope

from .utils import read_training_events

__all__ = [
"TrainDispReconstructor",
Expand Down Expand Up @@ -56,6 +56,12 @@ class TrainDispReconstructor(Tool):
),
).tag(config=True)

chunk_size = Int(
default_value=100000,
allow_none=True,
help="How many subarray events to load at once before training on n_events.",
).tag(config=True)

random_seed = Int(
default_value=0, help="Random seed for sampling and cross validation"
).tag(config=True)
Expand Down Expand Up @@ -111,7 +117,28 @@ def start(self):
self.log.info("Training models for %d types", len(types))
for tel_type in types:
self.log.info("Loading events for %s", tel_type)
table = self._read_table(tel_type)
feature_names = self.models.features + [
"true_energy",
"subarray_pointing_lat",
"subarray_pointing_lon",
"true_alt",
"true_az",
"hillas_fov_lat",
"hillas_fov_lon",
"hillas_psi",
]
table = read_training_events(
loader=self.loader,
chunk_size=self.chunk_size,
telescope_type=tel_type,
reconstructor=self.models,
feature_names=feature_names,
rng=self.rng,
log=self.log,
n_events=self.n_events.tel[tel_type],
)
table[self.models.target] = self._get_true_disp(table)
table = table[self.models.features + [self.models.target, "true_energy"]]

self.log.info("Train models on %s events", len(table))
self.cross_validate(tel_type, table)
Expand All @@ -120,58 +147,6 @@ def start(self):
self.models.fit(tel_type, table)
self.log.info("done")

def _read_table(self, telescope_type):
table = self.loader.read_telescope_events([telescope_type])
self.log.info("Events read from input: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"Input file does not contain any events for telescope type {telescope_type}"
)

mask = self.models.quality_query.get_table_mask(table)
table = table[mask]
self.log.info("Events after applying quality query: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"No events after quality query for telescope type {telescope_type}"
)

if not np.all(
table["subarray_pointing_frame"] == CoordinateFrameType.ALTAZ.value
):
raise ValueError(
"Pointing information for training data has to be provided in horizontal coordinates"
)

table = self.models.feature_generator(table, subarray=self.loader.subarray)

table[self.models.target] = self._get_true_disp(table)

# Add true energy for energy-dependent performance plots
columns = self.models.features + [self.models.target, "true_energy"]
table = table[columns]

valid = check_valid_rows(table)
if np.any(~valid):
self.log.warning("Dropping non-predicable events.")
table = table[valid]

n_events = self.n_events.tel[telescope_type]
if n_events is not None:
if n_events > len(table):
self.log.warning(
"Number of events in table (%d) is less than requested number of events %d",
len(table),
n_events,
)
else:
self.log.info("Sampling %d events", n_events)
idx = self.rng.choice(len(table), n_events, replace=False)
idx.sort()
table = table[idx]

return table

def _get_true_disp(self, table):
fov_lon, fov_lat = horizontal_to_telescope(
alt=table["true_alt"],
Expand Down
65 changes: 20 additions & 45 deletions ctapipe/tools/train_energy_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from ctapipe.core import Tool
from ctapipe.core.traits import Int, IntTelescopeParameter, Path
from ctapipe.exceptions import TooFewEvents
from ctapipe.io import TableLoader
from ctapipe.reco import CrossValidator, EnergyRegressor
from ctapipe.reco.preprocessing import check_valid_rows

from .utils import read_training_events

__all__ = [
"TrainEnergyRegressor",
Expand Down Expand Up @@ -53,6 +53,12 @@ class TrainEnergyRegressor(Tool):
),
).tag(config=True)

chunk_size = Int(
default_value=100000,
allow_none=True,
help="How many subarray events to load at once before training on n_events.",
).tag(config=True)

random_seed = Int(
default_value=0, help="Random seed for sampling and cross validation"
).tag(config=True)
Expand All @@ -61,6 +67,7 @@ class TrainEnergyRegressor(Tool):
("i", "input"): "TableLoader.input_url",
("o", "output"): "TrainEnergyRegressor.output_path",
"n-events": "TrainEnergyRegressor.n_events",
"chunk-size": "TrainEnergyRegressor.chunk_size",
"cv-output": "CrossValidator.output_path",
}

Expand Down Expand Up @@ -103,7 +110,17 @@ def start(self):
self.log.info("Training models for %d types", len(types))
for tel_type in types:
self.log.info("Loading events for %s", tel_type)
table = self._read_table(tel_type)
feature_names = self.regressor.features + [self.regressor.target]
table = read_training_events(
loader=self.loader,
chunk_size=self.chunk_size,
telescope_type=tel_type,
reconstructor=self.regressor,
feature_names=feature_names,
rng=self.rng,
log=self.log,
n_events=self.n_events.tel[tel_type],
)

self.log.info("Train on %s events", len(table))
self.cross_validate(tel_type, table)
Expand All @@ -112,48 +129,6 @@ def start(self):
self.regressor.fit(tel_type, table)
self.log.info("done")

def _read_table(self, telescope_type):
table = self.loader.read_telescope_events([telescope_type])
self.log.info("Events read from input: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"Input file does not contain any events for telescope type {telescope_type}"
)

mask = self.regressor.quality_query.get_table_mask(table)
table = table[mask]
self.log.info("Events after applying quality query: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"No events after quality query for telescope type {telescope_type}"
)

table = self.regressor.feature_generator(table, subarray=self.loader.subarray)

feature_names = self.regressor.features + [self.regressor.target]
table = table[feature_names]

valid = check_valid_rows(table)
if np.any(~valid):
self.log.warning("Dropping non-predictable events.")
table = table[valid]

n_events = self.n_events.tel[telescope_type]
if n_events is not None:
if n_events > len(table):
self.log.warning(
"Number of events in table (%d) is less than requested number of events %d",
len(table),
n_events,
)
else:
self.log.info("Sampling %d events", n_events)
idx = self.rng.choice(len(table), n_events, replace=False)
idx.sort()
table = table[idx]

return table

def finish(self):
"""
Write-out trained models and cross-validation results.
Expand Down
81 changes: 33 additions & 48 deletions ctapipe/tools/train_particle_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from ctapipe.core.tool import Tool
from ctapipe.core.traits import Int, IntTelescopeParameter, Path
from ctapipe.exceptions import TooFewEvents
from ctapipe.io import TableLoader
from ctapipe.reco import CrossValidator, ParticleClassifier
from ctapipe.reco.preprocessing import check_valid_rows

from .utils import read_training_events

__all__ = [
"TrainParticleClassifier",
Expand Down Expand Up @@ -78,6 +78,15 @@ class TrainParticleClassifier(Tool):
),
).tag(config=True)

chunk_size = Int(
default_value=100000,
allow_none=True,
help=(
"How many subarray events to load at once before training on"
" n_signal and n_background events."
),
).tag(config=True)

random_seed = Int(
default_value=0,
help="Random number seed for sampling and the cross validation splitting",
Expand Down Expand Up @@ -161,54 +170,30 @@ def start(self):
self.classifier.fit(tel_type, table)
self.log.info("done")

def _read_table(self, telescope_type, loader, n_events=None):
table = loader.read_telescope_events([telescope_type])
self.log.info("Events read from input: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"Input file does not contain any events for telescope type {telescope_type}"
)

mask = self.classifier.quality_query.get_table_mask(table)
table = table[mask]
self.log.info("Events after applying quality query: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"No events after quality query for telescope type {telescope_type}"
)

table = self.classifier.feature_generator(table, subarray=self.subarray)

# Add true energy for energy-dependent performance plots
columns = self.classifier.features + [self.classifier.target, "true_energy"]
table = table[columns]

valid = check_valid_rows(table)
if np.any(~valid):
self.log.warning("Dropping non-predictable events.")
table = table[valid]

if n_events is not None:
if n_events > len(table):
self.log.warning(
"Number of events in table (%d) is less than requested number of events %d",
len(table),
n_events,
)
else:
self.log.info("Sampling %d events", n_events)
idx = self.rng.choice(len(table), n_events, replace=False)
idx.sort()
table = table[idx]

return table

def _read_input_data(self, tel_type):
signal = self._read_table(
tel_type, self.signal_loader, self.n_signal.tel[tel_type]
feature_names = self.classifier.features + [
self.classifier.target,
"true_energy",
]
signal = read_training_events(
loader=self.signal_loader,
chunk_size=self.chunk_size,
telescope_type=tel_type,
reconstructor=self.classifier,
feature_names=feature_names,
rng=self.rng,
log=self.log,
n_events=self.n_signal.tel[tel_type],
)
background = self._read_table(
tel_type, self.background_loader, self.n_background.tel[tel_type]
background = read_training_events(
loader=self.background_loader,
chunk_size=self.chunk_size,
telescope_type=tel_type,
reconstructor=self.classifier,
feature_names=feature_names,
rng=self.rng,
log=self.log,
n_events=self.n_background.tel[tel_type],
)
table = vstack([signal, background])
self.log.info(
Expand Down
Loading