Skip to content

Commit

Permalink
simplify nnunet ci test
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Nov 27, 2024
1 parent 81624aa commit 624f7d7
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tests/tests_nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,26 @@

def run_tests_and_exit_on_failure():

# Download weights
download_pretrained_weights(297) # total 3mm

# Set nnUNet_results env var
weights_dir = Path.home() / ".totalsegmentator" / "nnunet" / "results"
os.environ["nnUNet_results"] = str(weights_dir)
print(f"Using weights directory: {weights_dir}")

# Copy example file
os.makedirs("tests/nnunet_input_files", exist_ok=True)
shutil.copy("tests/reference_files/example_ct_sm.nii.gz", "tests/nnunet_input_files/example_ct_sm_0000.nii.gz")

# Run nnunet
subprocess.call(f"nnUNetv2_predict -i tests/nnunet_input_files -o tests/nnunet_input_files -d 297 -tr nnUNetTrainer_4000epochs_NoMirroring -c 3d_fullres -f 0 -device cpu", shell=True)

r = pytest.main(["-v", "tests/test_end_to_end.py::test_end_to_end::test_nnunet_prediction"])
# Check if output file exists
assert os.path.exists("tests/nnunet_input_files/example_ct_sm.nii.gz"), "A nnunet output file was not generated."

# Clean up
shutil.rmtree("tests/nnunet_input_files")
if r != 0: sys.exit("Test failed: test_nnunet_prediction")


if __name__ == "__main__":
Expand Down

0 comments on commit 624f7d7

Please sign in to comment.