From e505b5e5a786749ac680c61c1d17576e34d47482 Mon Sep 17 00:00:00 2001 From: Martin <1500595+bmmtstb@users.noreply.github.com> Date: Fri, 19 Jan 2024 18:54:50 +0100 Subject: [PATCH] Added tests for torch_to_matplotlib in visualization.py Signed-off-by: Martin <1500595+bmmtstb@users.noreply.github.com> --- dgs/utils/visualization.py | 2 +- tests/utils/test__visualization.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test__visualization.py diff --git a/dgs/utils/visualization.py b/dgs/utils/visualization.py index 85a4f1e..9d9edc3 100644 --- a/dgs/utils/visualization.py +++ b/dgs/utils/visualization.py @@ -46,7 +46,7 @@ def torch_show_image( plt.show() -def torch_to_matplotlib(img: TVImage) -> np.ndarray: # pragma: no cover +def torch_to_matplotlib(img: Union[TVImage, torch.Tensor]) -> np.ndarray: """Convert a given single or batched torch image Tensor to a numpy.ndarray on the cpu. The dimensions are switched from ``[B x C x H x W]`` -> ``[B x H x W x C]`` diff --git a/tests/utils/test__visualization.py b/tests/utils/test__visualization.py new file mode 100644 index 0000000..358cd99 --- /dev/null +++ b/tests/utils/test__visualization.py @@ -0,0 +1,23 @@ +import unittest + +import torch +from torchvision import tv_tensors + +from dgs.utils.visualization import torch_to_matplotlib + + +class TestVisualization(unittest.TestCase): + def test_torch_to_matplotlib(self): + B, C, H, W = 8, 3, 64, 64 + for tensor, out_shape in [ + (torch.ones(B, C, H, W), [B, H, W, C]), + (tv_tensors.Image(torch.ones(B, C, H, W)), [B, H, W, C]), + (torch.ones(C, H, W), [H, W, C]), + ]: + with self.subTest(msg="tensor: {}, out_shape: {}".format(tensor, out_shape)): + m = torch_to_matplotlib(tensor) + self.assertEqual(list(m.shape), list(out_shape)) + + +if __name__ == "__main__": + unittest.main()