926 distributed training tests (#1295)
* fixes #926

Signed-off-by: Wenqi Li <[email protected]>
wyli authored Nov 27, 2020
1 parent 95ec73d commit dcc0a38
Showing 5 changed files with 221 additions and 51 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/pythonapp.yml
@@ -49,6 +49,13 @@ jobs:
os: [windows-latest, macOS-latest, ubuntu-latest]
timeout-minutes: 60
steps:
+    - if: runner.os == 'windows'
+      name: Config pagefile (Windows only)
+      uses: al-cheb/[email protected]
+      with:
+        minimum-size: 8
+        maximum-size: 16
+        disk-root: "D:"
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v2
@@ -73,7 +80,7 @@ jobs:
- if: runner.os == 'windows'
name: Install torch cpu from pytorch.org (Windows only)
run: |
-        python -m pip install torch==1.7.0 torchvision==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
+        python -m pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
# min. requirements for windows instances
python -c "f=open('requirements-dev.txt', 'r'); txt=f.readlines(); f.close(); print(txt); f=open('requirements-dev.txt', 'w'); f.writelines(txt[1:12]); f.close()"
- name: Install the dependencies
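The Windows job now pins the CPU-only wheels explicitly. A quick sanity check of the install (a sketch, not part of the workflow; the expected version strings follow from the pins in the diff above):

    import torch
    import torchvision

    print(torch.__version__)          # expected "1.7.0+cpu" with the pin above
    print(torchvision.__version__)    # expected "0.8.1+cpu"
    print(torch.cuda.is_available())  # False on the CPU-only build
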
16 changes: 8 additions & 8 deletions monai/handlers/confusion_matrix.py
@@ -67,19 +67,19 @@ def __init__(
self._num_examples = 0
self.compute_sample = compute_sample
self.metric_name = metric_name
-        self._total_tp = 0
-        self._total_fp = 0
-        self._total_tn = 0
-        self._total_fn = 0
+        self._total_tp = 0.0
+        self._total_fp = 0.0
+        self._total_tn = 0.0
+        self._total_fn = 0.0

@reinit__is_reduced
def reset(self) -> None:
self._sum = 0.0
self._num_examples = 0
-        self._total_tp = 0
-        self._total_fp = 0
-        self._total_tn = 0
-        self._total_fn = 0
+        self._total_tp = 0.0
+        self._total_fp = 0.0
+        self._total_tn = 0.0
+        self._total_fn = 0.0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
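The four running totals now start from float literals rather than ints. A minimal sketch of the accumulation pattern this keeps type-consistent (hypothetical values, not the handler's own update logic):

    import torch

    total_tp = 0.0                       # float from the start, as in the change above
    batch_tp = torch.tensor([2.0, 1.0])  # hypothetical per-batch true-positive counts
    total_tp += batch_tp.sum().item()    # the running total stays a float throughout
    print(total_tp)                      # 3.0
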
35 changes: 18 additions & 17 deletions tests/test_handler_confusion_matrix_dist.py
@@ -10,18 +10,27 @@
# limitations under the License.


+import unittest

import numpy as np
import torch
import torch.distributed as dist

from monai.handlers import ConfusionMatrix
+from tests.utils import DistCall, DistTestCase


-def main():
-    for compute_sample in [True, False]:
-        dist.init_process_group(backend="nccl", init_method="env://")
-
-        torch.cuda.set_device(dist.get_rank())
+class DistributedConfusionMatrix(DistTestCase):
+    @DistCall(nnodes=1, nproc_per_node=2)
+    def test_compute_sample(self):
+        self._compute(True)
+
+    @DistCall(nnodes=1, nproc_per_node=2)
+    def test_compute(self):
+        self._compute(False)
+
+    def _compute(self, compute_sample=True):
+        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
        metric = ConfusionMatrix(include_background=True, metric_name="tpr", compute_sample=compute_sample)

if dist.get_rank() == 0:
@@ -30,25 +39,25 @@ def main():
[[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],
],
device=torch.device("cuda:0"),
device=device,
)
y = torch.tensor(
[
[[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],
],
device=torch.device("cuda:0"),
device=device,
)
metric.update([y_pred, y])

if dist.get_rank() == 1:
y_pred = torch.tensor(
[[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]],
device=torch.device("cuda:1"),
device=device,
)
y = torch.tensor(
[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]],
device=torch.device("cuda:1"),
device=device,
)
metric.update([y_pred, y])

@@ -59,14 +68,6 @@ def main():
else:
np.testing.assert_allclose(avg_metric, 0.8333, rtol=1e-04, atol=1e-04)

-        dist.destroy_process_group()
-
-# suppose to execute on 2 rank processes
-# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
-# --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
-# --master_addr="192.168.1.1" --master_port=1234
-# test_handler_confusion_matrix_dist.py
-

if __name__ == "__main__":
-    main()
+    unittest.main()
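With the rewrite above, the test no longer has to be launched through torch.distributed.launch: the DistCall decorator spawns the worker processes itself, so `python -m unittest tests.test_handler_confusion_matrix_dist` from the repository root is enough. A sketch of running both converted tests programmatically (assuming the repository root is on the path):

    import unittest

    # DistCall picks nccl when CUDA is available, otherwise gloo (see tests/utils.py below)
    suite = unittest.defaultTestLoader.loadTestsFromNames(
        ["tests.test_handler_confusion_matrix_dist", "tests.test_handler_rocauc_dist"]
    )
    unittest.TextTestRunner(verbosity=2).run(suite)
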
44 changes: 19 additions & 25 deletions tests/test_handler_rocauc_dist.py
@@ -10,40 +10,34 @@
# limitations under the License.


+import unittest

import numpy as np
import torch
import torch.distributed as dist

from monai.handlers import ROCAUC
+from tests.utils import DistCall, DistTestCase


-def main():
-    dist.init_process_group(backend="nccl", init_method="env://")
-
-    torch.cuda.set_device(dist.get_rank())
-    auc_metric = ROCAUC(to_onehot_y=True, softmax=True)
-
-    if dist.get_rank() == 0:
-        y_pred = torch.tensor([[0.1, 0.9], [0.3, 1.4]], device=torch.device("cuda:0"))
-        y = torch.tensor([[0], [1]], device=torch.device("cuda:0"))
-        auc_metric.update([y_pred, y])
-
-    if dist.get_rank() == 1:
-        y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5]], device=torch.device("cuda:1"))
-        y = torch.tensor([[0], [1]], device=torch.device("cuda:1"))
-        auc_metric.update([y_pred, y])
-
-    result = auc_metric.compute()
-    np.testing.assert_allclose(0.75, result)
-    dist.destroy_process_group()
+class DistributedROCAUC(DistTestCase):
+    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
+    def test_compute(self):
+        auc_metric = ROCAUC(to_onehot_y=True, softmax=True)
+        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
+        if dist.get_rank() == 0:
+            y_pred = torch.tensor([[0.1, 0.9], [0.3, 1.4]], device=device)
+            y = torch.tensor([[0], [1]], device=device)
+            auc_metric.update([y_pred, y])
+
+        if dist.get_rank() == 1:
+            y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5]], device=device)
+            y = torch.tensor([[0], [1]], device=device)
+            auc_metric.update([y_pred, y])
+
+        result = auc_metric.compute()
+        np.testing.assert_allclose(0.75, result)

-# suppose to execute on 2 rank processes
-# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
-# --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
-# --master_addr="192.168.1.1" --master_port=1234
-# test_handler_rocauc_dist.py
-

if __name__ == "__main__":
-    main()
+    unittest.main()
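A sanity check of the expected 0.75: pool the four logit pairs contributed by the two ranks, apply softmax, and score the class-1 probabilities. A sketch assuming scikit-learn is available (it is not required by the test itself):

    import torch
    from sklearn.metrics import roc_auc_score

    logits = torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]])  # rank 0 then rank 1
    labels = [0, 1, 0, 1]
    scores = torch.softmax(logits, dim=1)[:, 1]   # probability of class 1
    print(roc_auc_score(labels, scores.numpy()))  # 0.75
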
168 changes: 168 additions & 0 deletions tests/utils.py
@@ -9,16 +9,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import functools
import importlib
import os
import sys
import tempfile
import unittest
from io import BytesIO
from subprocess import PIPE, Popen
from typing import Optional
from urllib.error import ContentTooShortError, HTTPError, URLError

import numpy as np
import torch
import torch.distributed as dist

from monai.data import create_test_image_2d, create_test_image_3d
from monai.utils import optional_import, set_determinism
@@ -87,6 +92,169 @@ def make_nifti_image(array, affine=None):
return image_name


class DistTestCase(unittest.TestCase):
"""testcase without _outcome, so that it's picklable."""

original_mp = None

def setUp(self) -> None:
self.original_mp = torch.multiprocessing.get_start_method(allow_none=True)
try:
torch.multiprocessing.set_start_method("spawn", force=True)
except RuntimeError:
pass

def tearDown(self) -> None:
try:
torch.multiprocessing.set_start_method(str(self.original_mp), force=True)
except RuntimeError:
pass

def __getstate__(self):
self_dict = self.__dict__.copy()
del self_dict["_outcome"]
return self_dict


class DistCall:
"""
Wrap a test case so that it will run in multiple processes on a single machine using `torch.distributed`.
Usage:
decorate a unittest testcase method with a `DistCall` instance::
class MyTests(unittest.TestCase):
@DistCall(nnodes=1, nproc_per_node=3, master_addr="localhost")
def test_compute(self):
...
the `test_compute` method should trigger different worker logic according to `dist.get_rank()`.
Multi-node tests require a fixed master_addr:master_port, with node_rank set manually in multiple scripts
or from environment variable "NODE_RANK".
"""

def __init__(
self,
nnodes: int = 1,
nproc_per_node: int = 1,
master_addr: str = "localhost",
master_port: Optional[int] = None,
node_rank: Optional[int] = None,
timeout=60,
init_method=None,
backend: Optional[str] = None,
verbose: bool = False,
):
"""
Args:
nnodes: The number of nodes to use for distributed call.
nproc_per_node: The number of processes to call on each node.
master_addr: Master node (rank 0)'s address, should be either the IP address or the hostname of node 0.
master_port: Master node (rank 0)'s free port.
node_rank: The rank of the node, this could be set via environment variable "NODE_RANK".
timeout: Timeout for operations executed against the process group.
init_method: URL specifying how to initialize the process group. Default is "env://" if unspecified.
backend: The backend to use. Depending on build-time configurations,
valid values include ``mpi``, ``gloo``, and ``nccl``.
verbose: whether to print NCCL debug info.
"""
self.nnodes = int(nnodes)
self.nproc_per_node = int(nproc_per_node)
self.node_rank = int(os.environ.get("NODE_RANK", "0")) if node_rank is None else node_rank
self.master_addr = master_addr
self.master_port = np.random.randint(10000, 20000) if master_port is None else master_port

if backend is None:
self.backend = "nccl" if torch.distributed.is_nccl_available() and torch.cuda.is_available() else "gloo"
else:
self.backend = backend
self.init_method = init_method
if self.init_method is None and sys.platform == "win32":
self.init_method = "file:///d:/a_temp"
self.timeout = datetime.timedelta(0, timeout)
self.verbose = verbose

def run_process(self, func, local_rank, args, kwargs, results):
_env = os.environ.copy() # keep the original system env
try:
os.environ["MASTER_ADDR"] = self.master_addr
os.environ["MASTER_PORT"] = str(self.master_port)
os.environ["LOCAL_RANK"] = str(local_rank)
if self.verbose:
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"
os.environ["NCCL_BLOCKING_WAIT"] = str(1)
os.environ["OMP_NUM_THREADS"] = str(1)
os.environ["WORLD_SIZE"] = str(self.nproc_per_node * self.nnodes)
os.environ["RANK"] = str(self.nproc_per_node * self.node_rank + local_rank)

if torch.cuda.is_available():
torch.cuda.set_device(int(local_rank))

dist.init_process_group(
backend=self.backend,
init_method=self.init_method,
timeout=self.timeout,
world_size=int(os.environ["WORLD_SIZE"]),
rank=int(os.environ["RANK"]),
)
func(*args, **kwargs)
results.put(True)
except Exception as e:
results.put(False)
raise e
finally:
os.environ.clear()
os.environ.update(_env)
dist.destroy_process_group()

def __call__(self, obj):
if not torch.distributed.is_available():
return unittest.skipIf(True, "Skipping distributed tests because not torch.distributed.is_available()")(obj)

_cache_original_func(obj)

@functools.wraps(obj)
def _wrapper(*args, **kwargs):
processes = []
results = torch.multiprocessing.Queue()
func = _call_original_func
args = [obj.__name__, obj.__module__] + list(args)
for proc_rank in range(self.nproc_per_node):
p = torch.multiprocessing.Process(
target=self.run_process, args=(func, proc_rank, args, kwargs, results)
)
p.start()
processes.append(p)
for p in processes:
p.join()
assert results.get(), "Distributed call failed."

return _wrapper


_original_funcs = {}


def _cache_original_func(obj) -> None:
"""cache the original function by name, so that the decorator doesn't shadow it."""
global _original_funcs
_original_funcs[obj.__name__] = obj


def _call_original_func(name, module, *args, **kwargs):
if name not in _original_funcs:
_original_module = importlib.import_module(module) # reimport, refresh _original_funcs
if not hasattr(_original_module, name):
# refresh module doesn't work
raise RuntimeError(f"Could not recover the original {name} from {module}: {_original_funcs}.")
f = _original_funcs[name]
return f(*args, **kwargs)


class NumpyImageTestCase2D(unittest.TestCase):
im_shape = (128, 64)
input_channels = 1
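Taken together, the two new helpers let an ordinary unittest method become a multi-process distributed test: DistTestCase keeps the test object picklable for the spawned workers, and DistCall launches nproc_per_node processes, sets the rendezvous environment variables, and calls init_process_group before the test body runs. A minimal sketch of using them (the class and assertions below are illustrative, not part of the commit):

    import torch.distributed as dist

    from tests.utils import DistCall, DistTestCase


    class DistributedSanityCheck(DistTestCase):  # hypothetical example test
        @DistCall(nnodes=1, nproc_per_node=2)
        def test_world(self):
            # this body runs once per spawned worker, after init_process_group
            assert dist.get_world_size() == 2
            assert dist.get_rank() in (0, 1)
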
