From df6f812d20c17b9d2d289635bac72a8321a7e8f5 Mon Sep 17 00:00:00 2001 From: Maximilian Linhoff Date: Wed, 4 Sep 2024 16:25:31 +0200 Subject: [PATCH 1/3] Add test for duplicated input files in merge tool --- src/ctapipe/tools/tests/test_merge.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/ctapipe/tools/tests/test_merge.py b/src/ctapipe/tools/tests/test_merge.py index 09b3bca217e..e5eb70faec0 100644 --- a/src/ctapipe/tools/tests/test_merge.py +++ b/src/ctapipe/tools/tests/test_merge.py @@ -5,11 +5,12 @@ from pathlib import Path import numpy as np +import pytest import tables from astropy.table import vstack from astropy.utils.diff import report_diff_values -from ctapipe.core import run_tool +from ctapipe.core import ToolConfigurationError, run_tool from ctapipe.io import TableLoader from ctapipe.io.astropy_helpers import read_table from ctapipe.io.tests.test_astropy_helpers import assert_table_equal @@ -188,3 +189,22 @@ def test_muon(tmp_path, dl1_muon_output_file): assert len(table) == 2 * n_input assert_table_equal(table[:n_input], input_table) assert_table_equal(table[n_input:], input_table) + + +def test_duplicated(tmp_path, dl1_file, dl1_proton_file): + from ctapipe.tools.merge import MergeTool + + output = tmp_path / "invalid.dl1.h5" + with pytest.raises(ToolConfigurationError, match="Same file given multiple times"): + run_tool( + MergeTool(), + argv=[ + str(dl1_file), + str(dl1_proton_file), + str(dl1_file), + f"--output={output}", + "--overwrite", + ], + cwd=tmp_path, + raises=True, + ) From 028553d42c45bad7ce612bf2486982ec68a371ae Mon Sep 17 00:00:00 2001 From: Maximilian Linhoff Date: Wed, 4 Sep 2024 16:34:39 +0200 Subject: [PATCH 2/3] Check for uniqueness of input files in merge tool --- docs/changes/2611.feature.rst | 2 ++ src/ctapipe/tools/merge.py | 8 ++++++++ 2 files changed, 10 insertions(+) create mode 100644 docs/changes/2611.feature.rst diff --git a/docs/changes/2611.feature.rst 
b/docs/changes/2611.feature.rst new file mode 100644 index 00000000000..41040223d88 --- /dev/null +++ b/docs/changes/2611.feature.rst @@ -0,0 +1,2 @@ +The ``ctapipe-merge`` tool no checks for duplicated input files and +raises an error in that case. diff --git a/src/ctapipe/tools/merge.py b/src/ctapipe/tools/merge.py index 3665d504cd4..11285ad9caf 100644 --- a/src/ctapipe/tools/merge.py +++ b/src/ctapipe/tools/merge.py @@ -3,6 +3,7 @@ """ import sys from argparse import ArgumentParser +from collections import Counter from pathlib import Path from tqdm.auto import tqdm @@ -161,6 +162,13 @@ def setup(self): ) sys.exit(1) + counts = Counter(self.input_files) + duplicated = [p for p, c in counts.items() if c > 1] + if len(duplicated) > 0: + raise ToolConfigurationError( + f"Same file given multiple times. Duplicated files are: {duplicated}" + ) + self.merger = self.enter_context(HDF5Merger(parent=self)) if self.merger.output_path in self.input_files: raise ToolConfigurationError( From 00dddbdc6ef29bb6f312f2f686cb55a1c00d554a Mon Sep 17 00:00:00 2001 From: Maximilian Linhoff Date: Wed, 4 Sep 2024 17:14:05 +0200 Subject: [PATCH 3/3] Check for duplicated obs_ids in HDF5Merge, fixes #2610 --- docs/changes/2611.feature.rst | 6 +++++- src/ctapipe/conftest.py | 8 +++++++ src/ctapipe/io/hdf5merger.py | 30 +++++++++++++++++++++++++++ src/ctapipe/io/tests/test_merge.py | 27 +++++++++++++++++++++++- src/ctapipe/tools/tests/test_merge.py | 6 ++---- 5 files changed, 71 insertions(+), 6 deletions(-) diff --git a/docs/changes/2611.feature.rst b/docs/changes/2611.feature.rst index 41040223d88..37bd102ab23 100644 --- a/docs/changes/2611.feature.rst +++ b/docs/changes/2611.feature.rst @@ -1,2 +1,6 @@ -The ``ctapipe-merge`` tool no checks for duplicated input files and +The ``ctapipe-merge`` tool now checks for duplicated input files and raises an error in that case. 
+ +The ``HDF5Merger`` class, and thus also the ``ctapipe-merge`` tool, +now checks for duplicated obs_ids during merging, to prevent +invalid output files. diff --git a/src/ctapipe/conftest.py b/src/ctapipe/conftest.py index 4d743e5d217..c6599dfe86d 100644 --- a/src/ctapipe/conftest.py +++ b/src/ctapipe/conftest.py @@ -592,6 +592,14 @@ def proton_train_clf(model_tmp_path, energy_regressor_path): ], raises=True, ) + + # modify obs_ids by adding a constant, this enables merging gamma and proton files + # which is used in the merge tool tests. + with tables.open_file(outpath, mode="r+") as f: + for table in f.walk_nodes("/", "Table"): + if "obs_id" in table.colnames: + obs_id = table.col("obs_id") + table.modify_column(colname="obs_id", column=obs_id + 1_000_000_000) return outpath diff --git a/src/ctapipe/io/hdf5merger.py b/src/ctapipe/io/hdf5merger.py index 95d0a8b6909..b6d59ae7364 100644 --- a/src/ctapipe/io/hdf5merger.py +++ b/src/ctapipe/io/hdf5merger.py @@ -188,6 +188,8 @@ def __init__(self, output_path=None, **kwargs): self.data_model_version = None self.subarray = None self.meta = None + self._merged_obs_ids = set() + # output file existed, so read subarray and data model version to make sure # any file given matches what we already have if appending: @@ -202,6 +204,9 @@ def __init__(self, output_path=None, **kwargs): ) self.required_nodes = _get_required_nodes(self.h5file) + # this will update _merged_obs_ids from the existing output file + self._check_obs_ids(self.h5file) + def __call__(self, other: str | Path | tables.File): """ Append file ``other`` to the output file @@ -267,7 +272,32 @@ def _check_can_merge(self, other): f"Required node {node_path} not found in {other.filename}" ) + def _check_obs_ids(self, other): + keys = [ + "/configuration/observation/observation_block", + "/dl1/event/subarray/trigger", + ] + + for key in keys: + if key in other.root: + obs_ids = other.root[key].col("obs_id") + break + else: + raise CannotMerge( + f"Input file 
{other.filename} is missing keys required to" + f" check for duplicated obs_ids. Tried: {keys}" + ) + + duplicated = self._merged_obs_ids.intersection(obs_ids) + if len(duplicated) > 0: + msg = f"Input file {other.filename} contains obs_ids already included in output file: {duplicated}" + raise CannotMerge(msg) + + self._merged_obs_ids.update(obs_ids) + def _append(self, other): + self._check_obs_ids(other) + # Configuration self._append_subarray(other) diff --git a/src/ctapipe/io/tests/test_merge.py b/src/ctapipe/io/tests/test_merge.py index ec37b6b5dda..6376da6566a 100644 --- a/src/ctapipe/io/tests/test_merge.py +++ b/src/ctapipe/io/tests/test_merge.py @@ -68,7 +68,7 @@ def test_simple(tmp_path, gamma_train_clf, proton_train_clf): merger(proton_train_clf) subarray = SubarrayDescription.from_hdf(gamma_train_clf) - assert subarray == SubarrayDescription.from_hdf(output), "Subarays do not match" + assert subarray == SubarrayDescription.from_hdf(output), "Subarrays do not match" tel_groups = [ "/dl1/event/telescope/parameters", @@ -164,3 +164,28 @@ def test_muon(tmp_path, dl1_muon_output_file): n_input = len(input_table) assert len(table) == n_input assert_table_equal(table, input_table) + + +def test_duplicated_obs_ids(tmp_path, dl2_shower_geometry_file): + from ctapipe.io.hdf5merger import CannotMerge, HDF5Merger + + output = tmp_path / "invalid.dl1.h5" + + # check for fresh file + with HDF5Merger(output) as merger: + merger(dl2_shower_geometry_file) + + with pytest.raises( + CannotMerge, match="Input file .* contains obs_ids already included" + ): + merger(dl2_shower_geometry_file) + + # check for appending + with HDF5Merger(output, overwrite=True) as merger: + merger(dl2_shower_geometry_file) + + with HDF5Merger(output, append=True) as merger: + with pytest.raises( + CannotMerge, match="Input file .* contains obs_ids already included" + ): + merger(dl2_shower_geometry_file) diff --git a/src/ctapipe/tools/tests/test_merge.py b/src/ctapipe/tools/tests/test_merge.py 
index e5eb70faec0..0879d27f1df 100644 --- a/src/ctapipe/tools/tests/test_merge.py +++ b/src/ctapipe/tools/tests/test_merge.py @@ -177,7 +177,6 @@ def test_muon(tmp_path, dl1_muon_output_file): argv=[ f"--output={output}", str(dl1_muon_output_file), - str(dl1_muon_output_file), ], raises=True, ) @@ -186,9 +185,8 @@ def test_muon(tmp_path, dl1_muon_output_file): input_table = read_table(dl1_muon_output_file, "/dl1/event/telescope/muon/tel_001") n_input = len(input_table) - assert len(table) == 2 * n_input - assert_table_equal(table[:n_input], input_table) - assert_table_equal(table[n_input:], input_table) + assert len(table) == n_input + assert_table_equal(table, input_table) def test_duplicated(tmp_path, dl1_file, dl1_proton_file):