Skip to content

Commit

Permalink
Merge pull request #2243 from cta-observatory/crossval_overwrite
Browse files Browse the repository at this point in the history
Fix wrong overwrite config in train_* tools
  • Loading branch information
kosack authored Feb 3, 2023
2 parents c8940b5 + 262e68e commit 3ed2ded
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 30 deletions.
2 changes: 2 additions & 0 deletions ctapipe/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ def energy_regressor_path(model_tmp_path):
f"--output={out_file}",
f"--config={config}",
"--log-level=INFO",
"--overwrite",
],
)
assert ret == 0
Expand Down Expand Up @@ -586,6 +587,7 @@ def particle_classifier_path(model_tmp_path, gamma_train_clf, proton_train_clf):
f"--output={out_file}",
f"--config={config}",
"--log-level=INFO",
"--overwrite",
],
)
assert ret == 0
Expand Down
10 changes: 0 additions & 10 deletions ctapipe/tools/train_disp_reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,6 @@ class TrainDispReconstructor(Tool):
),
).tag(config=True)

flags = {
"overwrite": (
{
"TrainDispReconstructor": {"overwrite": True},
"CrossValidator": {"overwrite": True},
},
"Overwrite existing output",
),
}

aliases = {
("i", "input"): "TableLoader.input_url",
("o", "output"): "TrainDispReconstructor.output_path",
Expand Down
10 changes: 0 additions & 10 deletions ctapipe/tools/train_energy_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,6 @@ class TrainEnergyRegressor(Tool):
default_value=0, help="Random seed for sampling and cross validation"
).tag(config=True)

flags = {
"overwrite": (
{
"TrainEnergyReconstructor": {"overwrite": True},
"CrossValidator": {"overwrite": True},
},
"Overwrite existing output",
),
}

aliases = {
("i", "input"): "TableLoader.input_url",
("o", "output"): "TrainEnergyRegressor.output_path",
Expand Down
10 changes: 0 additions & 10 deletions ctapipe/tools/train_particle_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,6 @@ class TrainParticleClassifier(Tool):
help="Random number seed for sampling and the cross validation splitting",
).tag(config=True)

flags = {
"overwrite": (
{
"TrainParticleClassifier": {"overwrite": True},
"CrossValidator": {"overwrite": True},
},
"Overwrite existing output",
),
}

aliases = {
"signal": "TrainParticleClassifier.input_url_signal",
"background": "TrainParticleClassifier.input_url_background",
Expand Down

0 comments on commit 3ed2ded

Please sign in to comment.