[Feature] add nlc2nchw2nlc and nchw2nlc2nchw (#1249)
* [Feature] add nlc2nchw2nlc and nchw2nlc2nchw
* add example
* add test, add **kwargs to make it more universal
1 parent 4250a5a · commit 17b500f
Showing 3 changed files with 170 additions and 2 deletions.
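Only the new unit tests are reproduced below; the helper implementations themselves live in the other changed files (under mmseg.models.utils). For orientation, here is a minimal sketch of what the four helpers exercised by these tests could look like, reconstructed purely from the shapes and flags the tests assert; treat it as an illustration under those assumptions, not the committed source.

# Sketch only: signatures and behavior inferred from the tests below.
def nchw_to_nlc(x):
    """Flatten a [N, C, H, W] tensor into [N, L, C] with L = H * W."""
    assert x.dim() == 4
    return x.flatten(2).transpose(1, 2).contiguous()


def nlc_to_nchw(x, hw_shape):
    """Reshape a [N, L, C] tensor back to [N, C, H, W] given (H, W)."""
    H, W = hw_shape
    B, L, C = x.shape
    assert L == H * W, 'The seq_len does not match H, W'
    return x.transpose(1, 2).reshape(B, C, H, W)


def nchw2nlc2nchw(module, x, contiguous=False, **kwargs):
    """Run ``module`` (which expects NLC input) on an NCHW tensor.

    Extra ``**kwargs`` are forwarded to ``module`` unchanged.
    """
    B, C, H, W = x.shape
    x = x.flatten(2).transpose(1, 2)
    if contiguous:
        x = x.contiguous()
    x = module(x, **kwargs)
    x = x.transpose(1, 2).reshape(B, C, H, W)
    return x.contiguous() if contiguous else x


def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):
    """Run ``module`` (which expects NCHW input) on an NLC tensor."""
    H, W = hw_shape
    B, L, C = x.shape
    assert L == H * W, 'The seq_len does not match H, W'
    x = x.transpose(1, 2).reshape(B, C, H, W)
    if contiguous:
        x = x.contiguous()
    x = module(x, **kwargs)
    x = x.flatten(2).transpose(1, 2)
    return x.contiguous() if contiguous else x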
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmseg.models.utils import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
                                nlc_to_nchw)


def test_nchw2nlc2nchw():
    # Test nchw2nlc2nchw function
    shape_nchw = (4, 2, 5, 5)
    shape_nlc = (4, 25, 2)

    def test_func(x):
        assert x.shape == torch.Size(shape_nlc)
        return x

    x = torch.rand(*shape_nchw)
    output = nchw2nlc2nchw(test_func, x)
    assert output.shape == torch.Size(shape_nchw)

    def test_func2(x, arg):
        assert x.shape == torch.Size(shape_nlc)
        assert arg == 100
        return x

    x = torch.rand(*shape_nchw)
    output = nchw2nlc2nchw(test_func2, x, arg=100)
    assert output.shape == torch.Size(shape_nchw)

    def test_func3(x):
        assert x.is_contiguous()
        assert x.shape == torch.Size(shape_nlc)
        return x

    x = torch.rand(*shape_nchw)
    output = nchw2nlc2nchw(test_func3, x, contiguous=True)
    assert output.shape == torch.Size(shape_nchw)
    assert output.is_contiguous()


def test_nlc2nchw2nlc():
    # Test nlc2nchw2nlc function
    shape_nchw = (4, 2, 5, 5)
    shape_nlc = (4, 25, 2)

    def test_func(x):
        assert x.shape == torch.Size(shape_nchw)
        return x

    x = torch.rand(*shape_nlc)
    output = nlc2nchw2nlc(test_func, x, shape_nchw[2:])
    assert output.shape == torch.Size(shape_nlc)

    def test_func2(x, arg):
        assert x.shape == torch.Size(shape_nchw)
        assert arg == 100
        return x

    x = torch.rand(*shape_nlc)
    output = nlc2nchw2nlc(test_func2, x, shape_nchw[2:], arg=100)
    assert output.shape == torch.Size(shape_nlc)

    def test_func3(x):
        assert x.is_contiguous()
        assert x.shape == torch.Size(shape_nchw)
        return x

    x = torch.rand(*shape_nlc)
    output = nlc2nchw2nlc(test_func3, x, shape_nchw[2:], contiguous=True)
    assert output.shape == torch.Size(shape_nlc)
    assert output.is_contiguous()


def test_nchw_to_nlc():
    # Test nchw_to_nlc function
    shape_nchw = (4, 2, 5, 5)
    shape_nlc = (4, 25, 2)
    x = torch.rand(*shape_nchw)
    y = nchw_to_nlc(x)
    assert y.shape == torch.Size(shape_nlc)


def test_nlc_to_nchw():
    # Test nlc_to_nchw function
    shape_nchw = (4, 2, 5, 5)
    shape_nlc = (4, 25, 2)
    x = torch.rand(*shape_nlc)
    y = nlc_to_nchw(x, (5, 5))
    assert y.shape == torch.Size(shape_nchw)
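As the contiguous and **kwargs cases above suggest, the intended use is to apply a sequence-style layer (for example torch.nn.LayerNorm, which expects channel-last input) to an NCHW feature map without hand-writing the reshapes. A hypothetical usage example:

import torch
from torch import nn

from mmseg.models.utils import nchw2nlc2nchw

x = torch.rand(4, 2, 5, 5)      # NCHW feature map
norm = nn.LayerNorm(2)          # LayerNorm operates on the channel-last (NLC) layout
out = nchw2nlc2nchw(norm, x)    # NCHW -> NLC -> norm -> back to NCHW
assert out.shape == x.shape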