forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Refactor] add unittest for
HeatmapHead
(open-mmlab#1503)
* add unittest for HeatmapHead * add unittest * fix comments typo Co-authored-by: Tau <[email protected]>
- Loading branch information
Showing
15 changed files
with
448 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,58 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Any, Optional, Sequence, Union | ||
|
||
import numpy as np | ||
import torch | ||
from mmengine.utils import is_seq_of | ||
from torch import Tensor | ||
|
||
|
||
def _to_numpy(x: Tensor) -> np.ndarray: | ||
"""Convert a torch tensor to numpy.ndarray. | ||
def to_numpy(x: Union[Tensor, Sequence[Tensor]], | ||
return_device: bool = False) -> Union[np.ndarray, tuple]: | ||
"""Convert torch tensor to numpy.ndarray. | ||
Args: | ||
x (Tensor | Sequence[Tensor]): A single tensor or a sequence of | ||
tensors | ||
return_device (bool): Whether return the tensor device. Defaults to | ||
``False`` | ||
Returns: | ||
np.ndarray | tuple: If ``return_device`` is ``True``, return a tuple | ||
of converted numpy array(s) and the device indicator; otherwise only | ||
return the numpy array(s) | ||
""" | ||
|
||
if isinstance(x, Tensor): | ||
arrays = x.detach().cpu().numpy() | ||
device = x.device | ||
elif is_seq_of(x, Tensor): | ||
arrays = [to_numpy(_x)[0] for _x in x] | ||
device = x[0].device | ||
else: | ||
raise ValueError(f'Invalid input type {type(x)}') | ||
|
||
if return_device: | ||
return arrays, device | ||
else: | ||
return arrays | ||
|
||
|
||
def to_tensor(x: Union[np.ndarray, Sequence[np.ndarray]], | ||
device: Optional[Any] = None) -> Union[Tensor, Sequence[Tensor]]: | ||
"""Convert numpy.ndarray to torch tensor. | ||
Args: | ||
x (Tensor): A torch tensor | ||
x (np.ndarray | Sequence[np.ndarray]): A single np.ndarray or a | ||
sequence of tensors | ||
tensor (Any, optional): The device indicator. Defaults to ``None`` | ||
Returns: | ||
np.ndarray: The converted numpy array | ||
tuple: | ||
- Tensor | Sequence[Tensor]: The converted Tensor or Tensor sequence | ||
""" | ||
return x.detach().cpu().numpy() | ||
if isinstance(x, np.ndarray): | ||
return torch.tensor(x, device=device) | ||
elif is_seq_of(x, np.ndarray): | ||
return [to_tensor(_x, device=device) for _x in x] | ||
else: | ||
raise ValueError(f'Invalid input type {type(x)}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .base_head import BaseHead | ||
from .heatmap_heads import HeatmapHead | ||
|
||
__all__ = ['BaseHead', 'HeatmapHead'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .heatmap_head import HeatmapHead | ||
|
||
__all__ = ['HeatmapHead'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.