From 919bbb2fd37dff66533cb89fdc411ada246a427a Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Fri, 9 Aug 2024 19:15:53 +0400 Subject: [PATCH] Fix chesapeake plot test --- tests/datasets/test_chesapeake.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index 33dbfd27978..c6c25c8f612 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -225,6 +225,9 @@ def test_plot(self, dataset: ChesapeakeCVPR) -> None: plt.close() dataset.plot(x, show_titles=False) plt.close() - x['prediction'] = x['mask'][:, :, 0].clone().unsqueeze(2) + if x['mask'].ndim == 2: + x['prediction'] = x['mask'].clone() + else: + x['prediction'] = x['mask'][0, :, :].clone() dataset.plot(x) plt.close()