Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chunked model application #2133

Merged
merged 6 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ctapipe/io/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@


CONVERSIONS = {
Time: lambda t: t.utc.iso,
list: lambda l: ",".join([convert(elem) for elem in l]),
DataLevel: lambda d: d.name,
Time: lambda value: value.utc.iso,
list: lambda value: ",".join([convert(elem) for elem in value]),
DataLevel: lambda value: value.name,
}


Expand Down
222 changes: 149 additions & 73 deletions ctapipe/io/tableloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def __init__(
func,
n_total,
chunk_size,
*args,
**kwargs,
args,
kwargs,
):
self.func = func
self.n_total = n_total
Expand All @@ -74,6 +74,8 @@ def __init__(
self.n_chunks = int(np.ceil(self.n_total / self.chunk_size))
self.args = args
self.kwargs = kwargs
self.start = None
self.end = None

def __len__(self):
return self.n_chunks
Expand All @@ -87,11 +89,11 @@ def __next__(self):
raise StopIteration

chunk = self._current_chunk
start = chunk * self.chunk_size
stop = min(self.n_total, (chunk + 1) * self.chunk_size)
self.start = chunk * self.chunk_size
self.stop = min(self.n_total, (chunk + 1) * self.chunk_size)

self._current_chunk += 1
return self.func(*self.args, start=start, stop=stop, **self.kwargs)
return self.func(*self.args, start=self.start, stop=self.stop, **self.kwargs)


def _empty_telescope_events_table():
Expand Down Expand Up @@ -365,7 +367,7 @@ def read_subarray_events(self, start=None, stop=None, keep_order=True):
self._sort_to_original_order(table)
return table

def read_subarray_events_chunked(self, chunk_size):
def read_subarray_events_chunked(self, chunk_size, *args, **kwargs):
"""
Iterate over chunks of subarray events.

Expand All @@ -378,6 +380,8 @@ def read_subarray_events_chunked(self, chunk_size):
self.read_subarray_events,
n_total=len(self),
chunk_size=chunk_size,
args=args,
kwargs=kwargs,
StFroese marked this conversation as resolved.
Show resolved Hide resolved
)

def _read_telescope_events_for_id(self, tel_id, start=None, stop=None):
Expand All @@ -390,9 +394,9 @@ def _read_telescope_events_for_id(self, tel_id, start=None, stop=None):
tel_id: int
Telescope identification number.
start: int
First row to read
First subarray event index to read
stop: int
Last row to read (non-inclusive)
Last subarray event index to read

Returns
-------
Expand All @@ -402,19 +406,37 @@ def _read_telescope_events_for_id(self, tel_id, start=None, stop=None):
if tel_id is None:
raise ValueError("Please, specify a telescope ID.")

table = read_table(self.h5file, "/dl1/event/telescope/trigger")
table = table[table["tel_id"] == tel_id]
table = table[slice(start, stop)]
# trigger is stored in a single table for all telescopes, we need to
# calculate the range to read from the stereo trigger info
trigger_start = trigger_stop = None
tel_start = tel_stop = None

if start is not None or stop is not None:
tel_start, tel_stop = self._get_tel_start_stop(tel_id, start, stop)

if start is not None:
trigger_start = self._n_total_telescope_events[start]

if stop is not None:
trigger_stop = self._n_total_telescope_events[stop]

table = read_table(
self.h5file,
"/dl1/event/telescope/trigger",
condition=f"tel_id == {tel_id}",
start=trigger_start,
stop=trigger_stop,
)

if self.load_dl1_parameters:
parameters = self._read_telescope_table(
PARAMETERS_GROUP, tel_id, start=start, stop=stop
PARAMETERS_GROUP, tel_id, start=tel_start, stop=tel_stop
)
table = _merge_telescope_tables(table, parameters)

if self.load_dl1_images:
images = self._read_telescope_table(
IMAGES_GROUP, tel_id, start=start, stop=stop
IMAGES_GROUP, tel_id, start=tel_start, stop=tel_stop
)
table = _merge_telescope_tables(table, images)

Expand All @@ -428,7 +450,7 @@ def _read_telescope_events_for_id(self, tel_id, start=None, stop=None):
for algorithm in group._v_children:
path = f"{group_path}/{algorithm}"
dl2 = self._read_telescope_table(
path, tel_id, start=start, stop=stop
path, tel_id, start=tel_start, stop=tel_stop
)
if len(dl2) == 0:
continue
Expand All @@ -437,13 +459,13 @@ def _read_telescope_events_for_id(self, tel_id, start=None, stop=None):

if self.load_true_images:
true_images = self._read_telescope_table(
TRUE_IMAGES_GROUP, tel_id, start=start, stop=stop
TRUE_IMAGES_GROUP, tel_id, start=tel_start, stop=tel_stop
)
table = _merge_telescope_tables(table, true_images)

if self.load_true_parameters:
true_parameters = self._read_telescope_table(
TRUE_PARAMETERS_GROUP, tel_id, start=start, stop=stop
TRUE_PARAMETERS_GROUP, tel_id, start=tel_start, stop=tel_stop
)
table = _join_telescope_events(table, true_parameters)

Expand All @@ -456,35 +478,27 @@ def _read_telescope_events_for_id(self, tel_id, start=None, stop=None):
impacts = self._read_telescope_table(
TRUE_IMPACT_GROUP,
tel_id,
start=start,
stop=stop,
start=tel_start,
stop=tel_stop,
)
table = _join_telescope_events(table, impacts)

return table

def _read_telescope_events_for_ids(self, tel_ids, tel_start=None, tel_stop=None):
tel_start = tel_start if tel_start is not None else [None] * len(tel_ids)
tel_stop = tel_stop if tel_stop is not None else [None] * len(tel_ids)

tables = []
for tel_id, start, stop in zip(tel_ids, tel_start, tel_stop):
# no events for this telescope in chunk
if start is not None and stop is not None and (stop - start) == 0:
continue

tables.append(
self._read_telescope_events_for_id(tel_id, start=start, stop=stop)
)

def _read_telescope_events_for_ids(self, tel_ids, start=None, stop=None):
tables = [
self._read_telescope_events_for_id(tel_id, start=start, stop=stop)
for tel_id in tel_ids
]
return vstack(tables)

def _join_subarray_info(self, table, start=None, stop=None):
subarray_events = self.read_subarray_events(
start=start,
stop=stop,
keep_order=False,
)
def _join_subarray_info(self, table, start=None, stop=None, subarray_events=None):
if subarray_events is None:
subarray_events = self.read_subarray_events(
start=start,
stop=stop,
keep_order=False,
)
table = join_allow_empty(
table,
subarray_events,
Expand All @@ -496,22 +510,19 @@ def _join_subarray_info(self, table, start=None, stop=None):
)
return table

def _get_tel_start_stop(self, tel_ids, start, stop):
def _get_tel_start_stop(self, tel_id, start, stop):
tel_start = None
tel_stop = None
if start is not None or stop is not None:

indices = self.subarray.tel_ids_to_indices(tel_ids)
index = self.subarray.tel_ids_to_indices(tel_id)[0]

# find first/last row for each telescope
if start is not None:
tel_start = self._n_telescope_events[start][indices]
# find first/last row for each telescope
if start is not None:
tel_start = self._n_telescope_events[start, index]

if stop is not None:
if stop >= len(self._n_telescope_events):
tel_stop = None
else:
tel_stop = self._n_telescope_events[stop][indices]
if stop is None or stop >= len(self._n_telescope_events):
tel_stop = None
else:
tel_stop = self._n_telescope_events[stop, index]

return tel_start, tel_stop

Expand Down Expand Up @@ -549,8 +560,7 @@ def read_telescope_events(self, telescopes=None, start=None, stop=None):
else:
tel_ids = self.subarray.get_tel_ids(telescopes)

tel_start, tel_stop = self._get_tel_start_stop(tel_ids, start, stop)
table = self._read_telescope_events_for_ids(tel_ids, tel_start, tel_stop)
table = self._read_telescope_events_for_ids(tel_ids, start, stop)
table = self._join_subarray_info(table, start=start, stop=stop)

# sort back to order in the file
Expand All @@ -561,7 +571,7 @@ def read_telescope_events(self, telescopes=None, start=None, stop=None):

return table

def read_telescope_events_chunked(self, chunk_size, **kwargs):
def read_telescope_events_chunked(self, chunk_size, *args, **kwargs):
"""
Iterate over chunks of telescope events.

Expand All @@ -578,7 +588,8 @@ def read_telescope_events_chunked(self, chunk_size, **kwargs):
self.read_telescope_events,
n_total=len(self),
chunk_size=chunk_size,
**kwargs,
args=args,
kwargs=kwargs,
)

@lazyproperty
Expand All @@ -600,6 +611,14 @@ def _n_telescope_events(self):
np.cumsum(tels_with_trigger, out=tels_with_trigger, axis=0)
return tels_with_trigger

@lazyproperty
def _n_total_telescope_events(self):
    """
    Number of telescope events in the file, summed over all telescopes,
    previous to the nth subarray event.

    This is ``self._n_telescope_events`` reduced along axis 1 (the
    telescope axis), so the result has one entry per row of that array
    rather than one entry per (row, telescope) pair.
    """
    counts_per_telescope = self._n_telescope_events
    return counts_per_telescope.sum(axis=1)

def read_telescope_events_by_type(
self, telescopes=None, start=None, stop=None
) -> Dict[str, Table]:
Expand All @@ -622,37 +641,30 @@ def read_telescope_events_by_type(
else:
tel_ids = self.subarray.get_tel_ids(telescopes)

tel_start, tel_stop = self._get_tel_start_stop(tel_ids, start, stop)
tel_start = tel_start if tel_start is not None else [None] * len(tel_ids)
tel_stop = tel_stop if tel_stop is not None else [None] * len(tel_ids)
subarray_events = self.read_subarray_events(
start=start, stop=stop, keep_order=False
)
self._add_index_if_needed(subarray_events)

by_type = defaultdict(list)
sort_index = self._get_sort_index(start=start, stop=stop)

for tel_id, start, stop in zip(tel_ids, tel_start, tel_stop):
# no events for this telescope in range start/stop
if start is not None and stop is not None and (stop - start) == 0:
continue

for tel_id in tel_ids:
key = str(self.subarray.tel[tel_id])
by_type[key].append(
self._read_telescope_events_for_id(tel_id, start=start, stop=stop)
)
table = self._read_telescope_events_for_id(tel_id, start=start, stop=stop)
if len(table) > 0:
by_type[key].append(table)

by_type = {k: vstack(ts) for k, ts in by_type.items()}

for key in by_type.keys():
by_type[key] = self._join_subarray_info(
by_type[key], start=start, stop=stop
by_type[key], subarray_events=subarray_events
)
by_type[key] = _join_subarray_events(by_type[key], sort_index)
self._sort_to_original_order(by_type[key], include_tel_id=True)

return by_type

def read_telescope_events_by_type_chunked(self, chunk_size, **kwargs):
def read_telescope_events_by_type_chunked(self, chunk_size, *args, **kwargs):
"""
Iterate over chunks of telescope events.
Iterate over chunks of telescope events as dicts of telescope type to tables.

Parameters
----------
Expand All @@ -667,5 +679,69 @@ def read_telescope_events_by_type_chunked(self, chunk_size, **kwargs):
self.read_telescope_events_by_type,
n_total=len(self),
chunk_size=chunk_size,
**kwargs,
args=args,
kwargs=kwargs,
)

def read_telescope_events_by_id(
    self, telescopes=None, start=None, stop=None
) -> Dict[int, Table]:
    """Read telescope-based event information, one table per telescope id.

    Parameters
    ----------
    telescopes: List[Union[int, str, TelescopeDescription]]
        Any list containing a combination of telescope IDs or telescope_descriptions.
        If None, every key of ``self.subarray.tel`` is used.
    start: int
        First subarray event index to read
    stop: int
        Last subarray event index to read

    Returns
    -------
    tables: dict(astropy.io.Table)
        Dictionary of tables organized by telescope ids
        Table with primary index columns "obs_id", "event_id" and "tel_id".
        Telescopes whose table is empty for the requested range are left out.
    """
    tel_ids = (
        tuple(self.subarray.tel.keys())
        if telescopes is None
        else self.subarray.get_tel_ids(telescopes)
    )

    # the subarray-level table is read once and reused for every telescope
    subarray_events = self.read_subarray_events(
        start=start, stop=stop, keep_order=False
    )
    self._add_index_if_needed(subarray_events)

    # first pass: read the per-telescope tables, keeping only non-empty ones
    non_empty = {}
    for tel_id in tel_ids:
        tel_table = self._read_telescope_events_for_id(
            tel_id, start=start, stop=stop
        )
        if len(tel_table) == 0:
            # no events for this telescope in range start/stop
            continue
        non_empty[tel_id] = tel_table

    # second pass: attach the subarray information and restore file order
    by_id = {}
    for tel_id, tel_table in non_empty.items():
        joined = self._join_subarray_info(
            tel_table, subarray_events=subarray_events
        )
        self._sort_to_original_order(joined, include_tel_id=True)
        by_id[tel_id] = joined

    return by_id

def read_telescope_events_by_id_chunked(self, chunk_size, *args, **kwargs):
    """
    Iterate over chunks of telescope events, yielding a dict of one table per telescope id.

    Parameters
    ----------
    chunk_size: int
        Number of subarray events to load per chunk.
        The telescope tables might be larger or smaller than chunk_size
        depending on the selected telescopes.

    *args, **kwargs are passed to `read_telescope_events_by_id`

    Returns
    -------
    ChunkIterator
        Iterator built over ``len(self)`` entries in steps of ``chunk_size``.
    """
    chunk_reader = self.read_telescope_events_by_id
    n_subarray_events = len(self)
    return ChunkIterator(
        chunk_reader,
        n_total=n_subarray_events,
        chunk_size=chunk_size,
        args=args,
        kwargs=kwargs,
    )
Loading