Skip to content

Commit

Permalink
Added tests for torch_to_matplotlib in visualization.py
Browse files Browse the repository at this point in the history
Signed-off-by: Martin <[email protected]>
  • Loading branch information
bmmtstb committed Jan 19, 2024
1 parent 0a213c0 commit e505b5e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dgs/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def torch_show_image(
plt.show()


def torch_to_matplotlib(img: TVImage) -> np.ndarray: # pragma: no cover
def torch_to_matplotlib(img: Union[TVImage, torch.Tensor]) -> np.ndarray:
"""Convert a given single or batched torch image Tensor to a numpy.ndarray on the cpu.
The dimensions are switched from ``[B x C x H x W]`` -> ``[B x H x W x C]``
Expand Down
23 changes: 23 additions & 0 deletions tests/utils/test__visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import unittest

import torch
from torchvision import tv_tensors

from dgs.utils.visualization import torch_to_matplotlib


class TestVisualization(unittest.TestCase):
def test_torch_to_matplotlib(self):
B, C, H, W = 8, 3, 64, 64
for tensor, out_shape in [
(torch.ones(B, C, H, W), [B, H, W, C]),
(tv_tensors.Image(torch.ones(B, C, H, W)), [B, H, W, C]),
(torch.ones(C, H, W), [H, W, C]),
]:
with self.subTest(msg="tensor: {}, out_shape: {}".format(tensor, out_shape)):
m = torch_to_matplotlib(tensor)
self.assertEqual(list(m.shape), list(out_shape))


if __name__ == "__main__":
unittest.main()

0 comments on commit e505b5e

Please sign in to comment.