Skip to content

Commit

Permalink
Also use check_output for the disp training
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasNickel committed Jan 20, 2023
1 parent 5dd8141 commit f0f0d89
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 20 deletions.
28 changes: 10 additions & 18 deletions ctapipe/tools/train_disp_reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from ctapipe.core import Tool
from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path, TraitError, flag
from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path
from ctapipe.io import TableLoader
from ctapipe.reco import CrossValidator, DispReconstructor
from ctapipe.reco.preprocessing import check_valid_rows, horizontal_to_telescope
Expand Down Expand Up @@ -46,8 +46,6 @@ class TrainDispReconstructor(Tool):
),
).tag(config=True)

overwrite = Bool(help="overwrite existing output files").tag(config=True)

random_seed = Int(
default_value=0, help="Random seed for sampling and cross validation"
).tag(config=True)
Expand All @@ -63,11 +61,9 @@ class TrainDispReconstructor(Tool):
).tag(config=True)

flags = {
**flag(
"overwrite",
"TrainDispReconstructor.overwrite",
"Overwrite output existing output files",
"Don't overwrite existing output files",
"overwrite": (
{"ApplyModels": {"overwrite": True}, "CrossValidator": {"overwrite": True}},
"Overwrite existing output",
),
}

Expand Down Expand Up @@ -99,16 +95,12 @@ def setup(self):
self.cross_validate = CrossValidator(parent=self, model_component=self.models)
self.rng = np.random.default_rng(self.random_seed)

if self.output_path.suffix != ".pkl":
self.log.warning(
"Expected .pkl extension for output_path, got %s",
self.output_path.suffix,
)

if self.output_path.exists() and not self.overwrite:
raise TraitError(
f"output_path '{self.output_path}' exists and overwrite=False"
)
output_files = [
self.output_path,
]
if self.cross_validate.output_path:
output_files.append(self.cross_validate.output_path)
self.check_output(output_files)

def start(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion ctapipe/tools/train_energy_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class TrainEnergyRegressor(Tool):
flags = {
"overwrite": (
{"ApplyModels": {"overwrite": True}, "CrossValidator": {"overwrite": True}},
"Overwrite output file if it exists",
"Overwrite existing output",
),
}

Expand Down
2 changes: 1 addition & 1 deletion ctapipe/tools/train_particle_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class TrainParticleClassifier(Tool):
flags = {
"overwrite": (
{"ApplyModels": {"overwrite": True}, "CrossValidator": {"overwrite": True}},
"Overwrite output file if it exists",
"Overwrite existing output",
),
}

Expand Down

0 comments on commit f0f0d89

Please sign in to comment.