Skip to content

Commit

Permalink
Merge pull request #2126 from cta-observatory/hstack_loader
Browse files Browse the repository at this point in the history
Speed up loader by using hstack instead of merge where possible
  • Loading branch information
maxnoe authored Nov 21, 2022
2 parents 248dfd0 + 799925c commit 56fce14
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 8 deletions.
52 changes: 44 additions & 8 deletions ctapipe/io/tableloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
Class and related functions to read DL1 (a,b) and/or DL2 (a) data
from an HDF5 file produced with ctapipe-process.
"""
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict

import numpy as np
import tables
from astropy.table import Table, vstack
from astropy.table import Table, hstack, vstack
from astropy.utils.decorators import lazyproperty

from ctapipe.instrument.optics import FocalLengthKind
Expand Down Expand Up @@ -37,6 +38,10 @@
TELESCOPE_EVENT_KEYS = ["obs_id", "event_id", "tel_id"]


class IndexNotMatching(UserWarning):
"""Warning that is raised if the order of two tables is not matching as expected"""


class ChunkIterator:
"""An iterator that calls a function on advancemnt
Expand Down Expand Up @@ -113,6 +118,32 @@ def _join_telescope_events(table1, table2):
return join_allow_empty(table1, table2, TELESCOPE_EVENT_KEYS, how)


def _merge_table_same_index(table1, table2, index_keys, fallback_join_type="left"):
"""Merge two tables assuming their primary keys are identical"""
if len(table1) != len(table2):
raise ValueError("Tables must have identical length")

if len(table1) == 0:
return table1

if not np.all(table1[index_keys] == table2[index_keys]):
warnings.warn(
"Table order does not match, falling back to join", IndexNotMatching
)
return join_allow_empty(table1, table2, index_keys, fallback_join_type)

columns = [col for col in table2.columns if col not in index_keys]
return hstack((table1, table2[columns]), join_type="exact")


def _merge_subarray_tables(table1, table2):
return _merge_table_same_index(table1, table2, SUBARRAY_EVENT_KEYS)


def _merge_telescope_tables(table1, table2):
return _merge_table_same_index(table1, table2, TELESCOPE_EVENT_KEYS)


class TableLoader(Component):
"""
Load telescope-event or subarray-event data from ctapipe HDF5 files
Expand Down Expand Up @@ -310,7 +341,7 @@ def read_subarray_events(self, start=None, stop=None, keep_order=True):

if self.load_simulated and SHOWER_TABLE in self.h5file:
showers = read_table(self.h5file, SHOWER_TABLE, start=start, stop=stop)
table = _join_subarray_events(table, showers)
table = _merge_subarray_tables(table, showers)

if self.load_dl2:
if DL2_SUBARRAY_GROUP in self.h5file:
Expand All @@ -325,7 +356,7 @@ def read_subarray_events(self, start=None, stop=None, keep_order=True):
start=start,
stop=stop,
)
table = _join_subarray_events(table, dl2)
table = _merge_subarray_tables(table, dl2)

if self.load_observation_info:
table = self._join_observation_info(table, start=start, stop=stop)
Expand Down Expand Up @@ -371,19 +402,21 @@ def _read_telescope_events_for_id(self, tel_id, start=None, stop=None):
if tel_id is None:
raise ValueError("Please, specify a telescope ID.")

table = _empty_telescope_events_table()
table = read_table(self.h5file, "/dl1/event/telescope/trigger")
table = table[table["tel_id"] == tel_id]
table = table[slice(start, stop)]

if self.load_dl1_parameters:
parameters = self._read_telescope_table(
PARAMETERS_GROUP, tel_id, start=start, stop=stop
)
table = _join_telescope_events(table, parameters)
table = _merge_telescope_tables(table, parameters)

if self.load_dl1_images:
images = self._read_telescope_table(
IMAGES_GROUP, tel_id, start=start, stop=stop
)
table = _join_telescope_events(table, images)
table = _merge_telescope_tables(table, images)

if self.load_dl2:
if DL2_TELESCOPE_GROUP in self.h5file:
Expand All @@ -397,13 +430,16 @@ def _read_telescope_events_for_id(self, tel_id, start=None, stop=None):
dl2 = self._read_telescope_table(
path, tel_id, start=start, stop=stop
)
table = _join_telescope_events(table, dl2)
if len(dl2) == 0:
continue

table = _merge_telescope_tables(table, dl2)

if self.load_true_images:
true_images = self._read_telescope_table(
TRUE_IMAGES_GROUP, tel_id, start=start, stop=stop
)
table = _join_telescope_events(table, true_images)
table = _merge_telescope_tables(table, true_images)

if self.load_true_parameters:
true_parameters = self._read_telescope_table(
Expand Down
11 changes: 11 additions & 0 deletions ctapipe/tools/apply_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from ctapipe.core.tool import Tool
from ctapipe.core.traits import Bool, Path, flag
from ctapipe.io import TableLoader, write_table
from ctapipe.io.astropy_helpers import read_table
from ctapipe.io.tableio import TelListToMaskTransform
from ctapipe.io.tableloader import _join_subarray_events
from ctapipe.reco import EnergyRegressor, ParticleClassifier, StereoCombiner

__all__ = [
Expand Down Expand Up @@ -203,6 +205,15 @@ def _combine(self, combiner, mono_predictions):
stereo_predictions[c.name] = np.array([trafo(r) for r in c])
stereo_predictions[c.name].description = c.description

# to ensure events are stored in the correct order,
# we resort to trigger table order
trigger = read_table(self.h5file, "/dl1/event/subarray/trigger")[
["obs_id", "event_id"]
]
trigger["__sort_index__"] = np.arange(len(trigger))
stereo_predictions = _join_subarray_events(trigger, stereo_predictions)
stereo_predictions.sort("__sort_index__")

write_table(
stereo_predictions,
self.output_path,
Expand Down
14 changes: 14 additions & 0 deletions ctapipe/tools/tests/test_apply_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def test_apply_energy_regressor(
assert f"{prefix}_tel_energy" in events.colnames
assert f"{prefix}_tel_is_valid" in events.colnames

from ctapipe.io.tests.test_table_loader import check_equal_array_event_order

trigger = read_table(output_path, "/dl1/event/subarray/trigger")
energy = read_table(output_path, "/dl2/event/subarray/energy/ExtraTreesRegressor")
check_equal_array_event_order(trigger, energy)


def test_apply_particle_classifier(
particle_classifier_path,
Expand Down Expand Up @@ -141,3 +147,11 @@ def test_apply_both(
events = loader.read_telescope_events()
assert "ExtraTreesClassifier_prediction" in events.colnames
assert "ExtraTreesRegressor_energy" in events.colnames

from ctapipe.io.tests.test_table_loader import check_equal_array_event_order

trigger = read_table(output_path, "/dl1/event/subarray/trigger")
particle_clf = read_table(
output_path, "/dl2/event/subarray/classification/ExtraTreesClassifier"
)
check_equal_array_event_order(trigger, particle_clf)

0 comments on commit 56fce14

Please sign in to comment.