diff --git a/tests/conftest.py b/tests/conftest.py index b6e9e68c..38ccaab2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,5 +65,5 @@ def small_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path: def labels_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path: """Provides a small, not preprocessed HCS OME-Zarr dataset.""" dataset_path = tmp_path_factory.mktemp("labels.zarr") - _build_hcs(dataset_path, ["DAPI"], (2, 16, 16), np.uint16, 3) + _build_hcs(dataset_path, ["DAPI", "GFP"], (2, 16, 16), np.uint16, 3) return dataset_path diff --git a/tests/translation/test_evaluation.py b/tests/translation/test_evaluation.py new file mode 100644 index 00000000..40cd8429 --- /dev/null +++ b/tests/translation/test_evaluation.py @@ -0,0 +1,33 @@ +import numpy as np +import pandas as pd +import pytest +from lightning.pytorch.loggers import CSVLogger +from numpy.testing import assert_array_equal + +from viscy.data.segmentation import SegmentationDataModule +from viscy.trainer import Trainer +from viscy.translation.evaluation import SegmentationMetrics2D + + +@pytest.mark.parametrize("pred_channel", ["DAPI", "GFP"]) +def test_segmentation_metrics_2d(pred_channel, labels_hcs_dataset, tmp_path) -> None: + dm = SegmentationDataModule( + pred_dataset=labels_hcs_dataset, + target_dataset=labels_hcs_dataset, + target_channel="DAPI", + pred_channel=pred_channel, + pred_z_slice=0, + target_z_slice=0, + batch_size=1, + num_workers=0, + ) + lm = SegmentationMetrics2D() + trainer = Trainer(logger=CSVLogger(tmp_path, name="", version="")) + trainer.test(lm, datamodule=dm) + metrics = pd.read_csv(tmp_path / "metrics.csv") + if pred_channel == "DAPI": + assert_array_equal( + metrics["accuracy"].to_numpy(), np.ones_like(metrics["accuracy"]) + ) + else: + assert 0 < metrics["accuracy"].mean() < 1