Commit

part of #926 (#1368)
Signed-off-by: Wenqi Li <[email protected]>

Co-authored-by: Nic Ma <[email protected]>
Co-authored-by: Isaac Yang <[email protected]>
3 people authored Jan 4, 2021
1 parent 0c2411e commit 1924b43
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions tests/test_distributed_sampler.py
@@ -9,39 +9,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import unittest
+
 import numpy as np
 import torch
 import torch.distributed as dist
 
 from monai.data import DistributedSampler
+from tests.utils import DistCall, DistTestCase
 
 
-def test(expected, **kwargs):
-    dist.init_process_group(backend="nccl", init_method="env://")
-
-    torch.cuda.set_device(dist.get_rank())
-    data = [1, 2, 3, 4, 5]
-    sampler = DistributedSampler(dataset=data, **kwargs)
-    samples = np.array([data[i] for i in list(sampler)])
-    if dist.get_rank() == 0:
-        np.testing.assert_allclose(samples, np.array(expected[0]))
-
-    if dist.get_rank() == 1:
-        np.testing.assert_allclose(samples, np.array(expected[1]))
-
-    dist.destroy_process_group()
+class DistributedSamplerTest(DistTestCase):
+    @DistCall(nnodes=1, nproc_per_node=2)
+    def test_even(self):
+        data = [1, 2, 3, 4, 5]
+        sampler = DistributedSampler(dataset=data, shuffle=False)
+        samples = np.array([data[i] for i in list(sampler)])
+        if dist.get_rank() == 0:
+            np.testing.assert_allclose(samples, np.array([1, 3, 5]))
 
+        if dist.get_rank() == 1:
+            np.testing.assert_allclose(samples, np.array([2, 4, 1]))
 
-def main():
-    test(shuffle=False, expected=[[1, 3, 5], [2, 4, 1]])
-    test(shuffle=False, even_divisible=False, expected=[[1, 3, 5], [2, 4]])
+    @DistCall(nnodes=1, nproc_per_node=2)
+    def test_uneven(self):
+        data = [1, 2, 3, 4, 5]
+        sampler = DistributedSampler(dataset=data, shuffle=False, even_divisible=False)
+        samples = np.array([data[i] for i in list(sampler)])
+        if dist.get_rank() == 0:
+            np.testing.assert_allclose(samples, np.array([1, 3, 5]))
 
+        if dist.get_rank() == 1:
+            np.testing.assert_allclose(samples, np.array([2, 4]))
 
-# suppose to execute on 2 rank processes
-# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
-#        --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
-#        --master_addr="localhost" --master_port=1234
-#        test_distributed_sampler.py
 
 if __name__ == "__main__":
-    main()
+    unittest.main()
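
The expected per-rank outputs in these tests follow from how 5 samples are split across 2 ranks: without shuffling, the ranks take interleaved elements, and with even_divisible=True the shorter rank is padded by wrapping around to the start of the data. A minimal sketch that reproduces the expected values (the partition helper below is illustrative only, not MONAI's actual DistributedSampler implementation):

import math


def partition(data, rank, num_ranks, even_divisible=True):
    # Hypothetical helper, for illustration: interleave elements across
    # ranks the way the expected test outputs suggest.
    if even_divisible:
        # Pad by wrapping to the start so every rank gets the same count;
        # data * 2 suffices here because padding is less than one full copy.
        per_rank = math.ceil(len(data) / num_ranks)
        padded = (data * 2)[: per_rank * num_ranks]
        return padded[rank::num_ranks]
    return data[rank::num_ranks]


data = [1, 2, 3, 4, 5]
assert partition(data, 0, 2) == [1, 3, 5]
assert partition(data, 1, 2) == [2, 4, 1]  # padded with the wrapped-around 1
assert partition(data, 1, 2, even_divisible=False) == [2, 4]  # no padding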

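The refactor moves the init_process_group / destroy_process_group boilerplate, and the need for an external torch.distributed.launch command, out of the test body and into the @DistCall decorator from tests.utils, so the tests run under plain unittest. A rough sketch of what such a decorator can look like, assuming it spawns the worker processes itself (dist_call, _run_rank, and the gloo backend are illustrative assumptions, not the actual tests.utils code):

import functools
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def dist_call(nnodes=1, nproc_per_node=2):
    # Illustrative stand-in for tests.utils.DistCall; nnodes is accepted
    # only for signature parity, the sketch handles a single node.
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(self):
            # Run the wrapped test once per rank in spawned processes.
            mp.spawn(_run_rank, args=(nproc_per_node, test_fn, self), nprocs=nproc_per_node)
        return wrapper
    return decorator


def _run_rank(rank, world_size, test_fn, self):
    # Each spawned process joins the same process group, runs the test
    # body, and tears the group down afterwards.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)  # "gloo" so the sketch runs without GPUs
    try:
        test_fn(self)
    finally:
        dist.destroy_process_group()

With this pattern, python -m unittest tests.test_distributed_sampler is enough to exercise both ranks; no external launcher is needed.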