Skip to content

Commit

Permalink
Merge pull request #2247 from cta-observatory/fix_sampling
Browse files Browse the repository at this point in the history
Add rng attribute back to TrainEnergyRegressor
  • Loading branch information
maxnoe authored Feb 3, 2023
2 parents 3ed2ded + d1cc77c commit 2a191e1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
20 changes: 20 additions & 0 deletions ctapipe/tools/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,26 @@ def test_too_few_events(tmp_path, dl2_shower_geometry_file):
)


def test_sampling(tmp_path, dl2_shower_geometry_file):
from ctapipe.tools.train_energy_regressor import TrainEnergyRegressor

tool = TrainEnergyRegressor()
config = resource_file("train_energy_regressor.yaml")
out_file = tmp_path / "energy.pkl"

run_tool(
tool,
argv=[
"--input=dataset://gamma_diffuse_dl2_train_small.dl2.h5",
f"--output={out_file}",
f"--config={config}",
"--log-level=INFO",
"--n-events=100",
],
raises=True,
)


def test_cross_validation_results(tmp_path, gamma_train_clf, proton_train_clf):
from ctapipe.tools.train_disp_reconstructor import TrainDispReconstructor
from ctapipe.tools.train_energy_regressor import TrainEnergyRegressor
Expand Down
1 change: 1 addition & 0 deletions ctapipe/tools/train_energy_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def setup(self):
self.cross_validate = CrossValidator(
parent=self, model_component=self.regressor
)
self.rng = np.random.default_rng(self.random_seed)
self.check_output(self.output_path, self.cross_validate.output_path)

def start(self):
Expand Down

0 comments on commit 2a191e1

Please sign in to comment.