Skip to content

Commit

Permalink
Merge pull request #2213 from cta-observatory/overwrite
Browse files Browse the repository at this point in the history
Consistent overwrite
  • Loading branch information
maxnoe authored Jan 24, 2023
2 parents 10b6668 + b01a420 commit 14ea3ba
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 174 deletions.
14 changes: 14 additions & 0 deletions ctapipe/core/tests/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,20 @@ def test_tool_logging_quiet(capsys):
assert len(log) == 0


def test_tool_overwrite_output(capsys, tmp_path):
    """Check that ``Tool.check_output`` honours the ``overwrite`` setting."""
    target = tmp_path / "overwrite_dummy"
    checker = Tool()

    # A path that is not there yet is always acceptable.
    checker.check_output(target)

    # Once the path exists and overwrite is left untouched, the check must fail.
    target.touch()
    with pytest.raises(ToolConfigurationError):
        checker.check_output(target)

    # Enabling overwrite makes the same, existing path acceptable again.
    checker.overwrite = True
    checker.check_output(target)


def test_invalid_traits(tmp_path, caplog):
caplog.set_level(logging.INFO, logger="ctapipe")

Expand Down
33 changes: 32 additions & 1 deletion ctapipe/core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def main():
).tag(config=True)

quiet = Bool(default_value=False).tag(config=True)
overwrite = Bool(default_value=False).tag(config=True)

_log_formatter_cls = ColoredFormatter

Expand Down Expand Up @@ -203,8 +204,12 @@ def __init__(self, **kwargs):
{"Tool": {"log_level": "DEBUG"}},
"Set log level to DEBUG",
),
"overwrite": (
{"Tool": {"overwrite": True}},
"Overwrite existing output files without asking",
),
}
self.flags.update(flags)
self.flags = {**flags, **self.flags}

self.is_setup = False
self.version = version
Expand Down Expand Up @@ -324,6 +329,32 @@ def add_component(self, component_instance):
self._registered_components.append(component_instance)
return component_instance

def check_output(self, *output_paths):
    """
    Test if output files exist and if they do, throw an error
    unless ``self.overwrite`` is set to True.

    This should be checked during tool setup to avoid having a tool only
    realize the output can not be written after some long-running calculations
    (e.g. training of ML-models).

    Because we currently do not collect all created output files in the tool
    (they can be attached to some component), the output files need
    to be given and can not easily be derived from ``self``.

    This method only checks and logs; it does not remove or truncate
    any existing file.

    Parameters
    ----------
    *output_paths : pathlib.Path, str, os.PathLike or None
        One or more output paths to check. ``None`` entries are skipped.

    Raises
    ------
    ToolConfigurationError
        If a given path exists and ``self.overwrite`` is False.
    """
    # Imported locally so the method does not rely on module-level imports.
    import pathlib

    for output in output_paths:
        if output is None:
            continue

        # Coerce so plain strings / os.PathLike objects work as well as Paths.
        if not pathlib.Path(output).exists():
            continue

        if not self.overwrite:
            raise ToolConfigurationError(
                f"Output path {output} exists, but overwrite=False"
            )

        self.log.warning("Overwriting %s", output)

@abstractmethod
def setup(self):
"""Set up the tool.
Expand Down
21 changes: 2 additions & 19 deletions ctapipe/tools/apply_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tqdm.auto import tqdm

from ctapipe.core.tool import Tool
from ctapipe.core.traits import Bool, Integer, Path, flag
from ctapipe.core.traits import Integer, Path
from ctapipe.io import TableLoader, write_table
from ctapipe.io.astropy_helpers import read_table
from ctapipe.io.tableio import TelListToMaskTransform
Expand Down Expand Up @@ -48,8 +48,6 @@ class ApplyModels(Tool):
--output gamma_applied.dl2.h5
"""

overwrite = Bool(default_value=False).tag(config=True)

input_url = Path(
default_value=None,
allow_none=False,
Expand Down Expand Up @@ -104,19 +102,6 @@ class ApplyModels(Tool):
"chunk-size": "ApplyModels.chunk_size",
}

flags = {
**flag(
"overwrite",
"ApplyModels.overwrite",
"Overwrite tables in output file if it exists",
"Don't overwrite tables in output file if it exists",
),
"f": (
{"ApplyModels": {"overwrite": True}},
"Overwrite output file if it exists",
),
}

classes = [
TableLoader,
EnergyRegressor,
Expand All @@ -129,10 +114,8 @@ def setup(self):
"""
Initialize components from config
"""
self.check_output(self.output_path)
self.log.info("Copying to output destination.")
if self.output_path.exists() and not self.overwrite:
raise IOError(f"Output path {self.output_path} exists, but overwrite=False")

shutil.copy(self.input_url, self.output_path)

self.h5file = self.enter_context(tables.open_file(self.output_path, mode="r+"))
Expand Down
61 changes: 23 additions & 38 deletions ctapipe/tools/dump_triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
from astropy import units as u
from astropy.table import Table

from ..core import Provenance, Tool
from ..core.traits import Dict, Path, Unicode
from ..io import EventSource
from ..core import Provenance, ToolConfigurationError
from ..core.traits import Unicode, Dict, Bool, Path, flag
from ..core import Tool

MAX_TELS = 1000

Expand All @@ -23,42 +22,30 @@ class DumpTriggersTool(Tool):
# configuration parameters:
# =============================================

infile = Path(exists=True, directory_ok=False, help="input simtelarray file").tag(
config=True
)
input_path = Path(
exists=True, directory_ok=False, help="input simtelarray file", allow_none=False
).tag(config=True)

outfile = Path(
output_path = Path(
default_value="triggers.fits",
directory_ok=False,
help="output filename (*.fits, *.h5)",
).tag(config=True)

overwrite = Bool(False, help="overwrite existing output file").tag(config=True)

# =============================================
# map low-level options to high-level command-line options
# =============================================

aliases = Dict(
{"infile": "DumpTriggersTool.infile", "outfile": "DumpTriggersTool.outfile"}
{
"input": "DumpTriggersTool.input_path",
"output": "DumpTriggersTool.output_path",
}
)

flags = {
"f": (
{"DumpTriggersTool": {"overwrite": True}},
"Enable overwriting of output file",
),
**flag(
"overwrite"
"DumpTriggersTool.overwrite"
"Enable overwriting of output file.",
"Disable overwriting of output file.",
),
}

examples = (
"ctapipe-dump-triggers --infile gamma.simtel.gz "
"--outfile trig.fits --overwrite"
"ctapipe-dump-triggers --input gamma.simtel.gz "
"--output trig.fits --overwrite"
"\n\n"
"If you want to see more output, use --log_level=DEBUG"
)
Expand Down Expand Up @@ -103,11 +90,9 @@ def add_event_to_table(self, event):
)

def setup(self):
""" setup function, called before `start()` """

if self.infile == "":
raise ToolConfigurationError("No 'infile' parameter was specified. ")
"""setup function, called before `start()`"""

self.check_output(self.output_path)
self.events = Table(
names=["EVENT_ID", "T_REL", "DELTA_T", "N_TRIG", "TRIGGERED_TELS"],
dtype=[np.int64, np.float64, np.float64, np.int32, np.uint8],
Expand All @@ -117,15 +102,15 @@ def setup(self):
self.events["T_REL"].unit = u.s
self.events["T_REL"].description = "Time relative to first event"
self.events["DELTA_T"].unit = u.s
self.events.meta["INPUT"] = str(self.infile)
self.events.meta["INPUT"] = str(self.input_path)

self._current_trigpattern = np.zeros(MAX_TELS)
self._current_starttime = None
self._prev_time = None

def start(self):
""" main event loop """
with EventSource(self.infile) as source:
"""main event loop"""
with EventSource(self.input_path) as source:
for event in source:
self.add_event_to_table(event)

Expand All @@ -136,16 +121,16 @@ def finish(self):
"""
# write out the final table
try:
if ".fits" in self.outfile.suffixes:
self.events.write(self.outfile, overwrite=self.overwrite)
elif self.outfile.suffix in (".hdf5", ".h5", ".hdf"):
if ".fits" in self.output_path.suffixes:
self.events.write(self.output_path, overwrite=self.overwrite)
elif self.output_path.suffix in (".hdf5", ".h5", ".hdf"):
self.events.write(
self.outfile, path="/events", overwrite=self.overwrite
self.output_path, path="/events", overwrite=self.overwrite
)
else:
self.events.write(self.outfile)
self.events.write(self.output_path)

Provenance().add_output_file(self.outfile)
Provenance().add_output_file(self.output_path)
except IOError as err:
self.log.warning("Couldn't write output (%s)", err)

Expand Down
23 changes: 2 additions & 21 deletions ctapipe/tools/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class MergeTool(Tool):
help="Input dl1-files",
).tag(config=True)
output_path = traits.Path(
help="Merged-DL1 output filename", directory_ok=False
help="Merged-DL1 output filename", directory_ok=False, allow_none=False
).tag(config=True)
skip_images = Bool(
help="Skip DL1/Event/Telescope and Simulation/Event/Telescope images in output",
Expand All @@ -146,9 +146,6 @@ class MergeTool(Tool):
skip_broken_files = Bool(
help="Skip broken files instead of raising an error", default_value=False
).tag(config=True)
overwrite = Bool(
help="Overwrite output file if it exists", default_value=False
).tag(config=True)
progress_bar = Bool(
help="Show progress bar during processing", default_value=False
).tag(config=True)
Expand Down Expand Up @@ -177,13 +174,6 @@ class MergeTool(Tool):
}

flags = {
"f": ({"MergeTool": {"overwrite": True}}, "Overwrite output file if it exists"),
**flag(
"overwrite",
"MergeTool.overwrite",
"Overwrite output file if it exists",
"Don't overwrite output file if it exists",
),
"progress": (
{"MergeTool": {"progress_bar": True}},
"Show a progress bar for all given input files",
Expand Down Expand Up @@ -221,16 +211,7 @@ def setup(self):
sys.exit(1)

self.output_path = self.output_path.expanduser()
if self.output_path.exists():
if self.overwrite:
self.log.warning(f"Overwriting {self.output_path}")
self.output_path.unlink()
else:
self.log.critical(
f"Output file {self.output_path} exists, "
"use `--overwrite` to overwrite"
)
sys.exit(1)
self.check_output(self.output_path)

PROV.add_output_file(str(self.output_path))

Expand Down
32 changes: 5 additions & 27 deletions ctapipe/tools/muon_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from ..calib import CameraCalibrator
from ..containers import MuonParametersContainer, TelEventIndexContainer
from ..coordinates import CameraFrame, TelescopeFrame
from ..core import Provenance, Tool, ToolConfigurationError, traits
from ..core.traits import flag
from ..core import Provenance, Tool, traits
from ..image.cleaning import TailcutsImageCleaner
from ..instrument import CameraGeometry
from ..io import EventSource, HDF5TableWriter
Expand All @@ -33,9 +32,9 @@ class MuonAnalysis(Tool):
name = "ctapipe-reconstruct-muons"
description = traits.Unicode(__doc__)

output = traits.Path(directory_ok=False, help="HDF5 output file name").tag(
config=True
)
output = traits.Path(
directory_ok=False, allow_none=False, help="HDF5 output file name"
).tag(config=True)

completeness_threshold = traits.FloatTelescopeParameter(
default_value=30.0, help="Threshold for calculating the ``ring_completeness``"
Expand All @@ -49,10 +48,6 @@ class MuonAnalysis(Tool):
),
).tag(config=True)

overwrite = traits.Bool(
default_value=False, help="If true, overwrite outputfile without asking"
).tag(config=True)

min_pixels = traits.IntTelescopeParameter(
help=(
"Minimum number of pixels after cleaning and ring finding"
Expand Down Expand Up @@ -80,25 +75,8 @@ class MuonAnalysis(Tool):
("t", "allowed-tels"): "EventSource.allowed_tels",
}

flags = {
"f": ({"MuonAnalysis": {"overwrite": True}}, "Overwrite output file"),
**flag(
"overwrite",
"MuonAnalysis.overwrite",
"Overwrite output file",
"Don't overwrite output file",
),
}

def setup(self):
if self.output is None:
raise ToolConfigurationError("You need to provide an --output file")

if self.output.exists() and not self.overwrite:
raise ToolConfigurationError(
"Outputfile {self.output} already exists, use `--overwrite` to overwrite"
)

self.check_output(self.output)
self.source = EventSource(parent=self)
subarray = self.source.subarray

Expand Down
8 changes: 1 addition & 7 deletions ctapipe/tools/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,10 @@ class ProcessorTool(Tool):
}

flags = {
"f": (
"overwrite": (
{"DataWriter": {"overwrite": True}},
"Overwrite output file if it exists",
),
**flag(
"overwrite",
"DataWriter.overwrite",
"Overwrite output file if it exists",
"Don't overwrite output file if it exists",
),
**flag(
"progress",
"ProcessorTool.progress_bar",
Expand Down
6 changes: 3 additions & 3 deletions ctapipe/tools/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ def test_dump_triggers(tmp_path):
from ctapipe.tools.dump_triggers import DumpTriggersTool

sys.argv = ["dump_triggers"]
outfile = tmp_path / "triggers.fits"
tool = DumpTriggersTool(infile=PROD5B_PATH, outfile=str(outfile))
output_path = tmp_path / "triggers.fits"
tool = DumpTriggersTool(input_path=PROD5B_PATH, output_path=str(output_path))

assert run_tool(tool, cwd=tmp_path) == 0

assert outfile.exists()
assert output_path.exists()
assert run_tool(tool, ["--help-all"]) == 0


Expand Down
Loading

0 comments on commit 14ea3ba

Please sign in to comment.