diff --git a/Dockerfile b/Dockerfile index 545c0bf4a..8bf9a1705 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -ARG FROM_IMAGE=lcskrishna/rocm-pytorch:rocm3.3_ubuntu16.04_py3.6_pytorch_updated +ARG FROM_IMAGE=lcskrishna/rocm-pytorch:rocm3.3_ubuntu16.04_py3.6_pytorch_bfloat16_mgpu FROM ${FROM_IMAGE} RUN \ diff --git a/apex/parallel/distributed.py b/apex/parallel/distributed.py index 5267c834a..6aa6a6e8a 100644 --- a/apex/parallel/distributed.py +++ b/apex/parallel/distributed.py @@ -48,8 +48,8 @@ def apply_flat_dist_call(bucket, call, extra_args=None): for buf, synced in zip(bucket, unflatten(coalesced, bucket)): buf.copy_(synced) -def split_half_float_double(tensors): - dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor"] +def split_half_float_double_bfloat16(tensors): + dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"] buckets = [] for i, dtype in enumerate(dtypes): bucket = [t for t in tensors if t.type() == dtype] @@ -240,7 +240,8 @@ def __init__(self, self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0, "torch.cuda.FloatTensor" : 1, - "torch.cuda.DoubleTensor" : 2} + "torch.cuda.DoubleTensor" : 2, + "torch.cuda.BFloat16Tensor" : 3} if multi_tensor_applier.available: # TODO: I really need to centralize the C++ backed imports @@ -498,7 +499,7 @@ def allreduce_fallback(self): else: grads = [param.grad.data for param in self.module.parameters() if param.grad is not None] - split_buckets = split_half_float_double(grads) + split_buckets = split_half_float_double_bfloat16(grads) # If retain_allreduce_buffers is True and delay_allreduce is False, # this will only be done during the first backward pass, ignored by the @@ -578,8 +579,8 @@ def forward(self, *inputs, **kwargs): if self.needs_refresh: self.active_i_buckets = [] self.buckets = [] - self.tmp_buckets = [[], [], []] # [running half, float, double buckets] - self.tmp_numels = [0, 0, 0] + self.tmp_buckets = [[], [], [], []] # [running half, float, double, bfloat16 buckets] + self.tmp_numels = [0, 0, 0, 0] self.bucket_sizes = [] self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)} self.param_id_to_bucket = {} diff --git a/tests/distributed/amp_master_params/amp_master_params.py b/tests/distributed/amp_master_params/amp_master_params.py index 4af5092f7..4b3a80498 100644 --- a/tests/distributed/amp_master_params/amp_master_params.py +++ b/tests/distributed/amp_master_params/amp_master_params.py @@ -9,6 +9,7 @@ # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied # automatically by torch.distributed.launch. parser.add_argument("--local_rank", default=0, type=int) +parser.add_argument("--opt_level", default="O2", type=str) args = parser.parse_args() # FOR DISTRIBUTED: If we are running under torch.distributed.launch, @@ -42,7 +43,7 @@ model = torch.nn.Linear(D_in, D_out).cuda() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) -model, optimizer = amp.initialize(model, optimizer, opt_level="O2") +model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level) if args.distributed: # FOR DISTRIBUTED: After amp.initialize, wrap the model with diff --git a/tests/distributed/amp_master_params/compare.py b/tests/distributed/amp_master_params/compare.py index e5cbf20c1..b8047752a 100644 --- a/tests/distributed/amp_master_params/compare.py +++ b/tests/distributed/amp_master_params/compare.py @@ -14,6 +14,9 @@ model_params_rank1, master_params_rank0, master_params_rank1): + # converting model params to float is a hack since allclose doesn't support bfloat16 yet. + model_rank0 = model_rank0.float() + model_rank1 = model_rank1.float() assert torch.allclose(model_rank0, model_rank1), "Model param mismatch" assert torch.allclose(master_rank0, master_rank1), "Master param mismatch" # Some debugging/investigation assistance code: @@ -23,6 +26,6 @@ # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(), # offending_val_float.half().item()) # rtol needs to be > 2^-11 because of denormals... - assert torch.allclose(model_rank0, master_rank0.half(), rtol=.005), "Model-master mismatch" + assert torch.allclose(model_rank0, master_rank0, rtol=.005), "Model-master mismatch" print("OK: Model and master params match across ranks.") diff --git a/tests/distributed/amp_master_params/run_rocm_distributed.sh b/tests/distributed/amp_master_params/run_rocm_distributed.sh new file mode 100644 index 000000000..932466916 --- /dev/null +++ b/tests/distributed/amp_master_params/run_rocm_distributed.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -e + +# To run the test on 2 gpus +export WORLD_SIZE=2 + +# Test with opt_level="O2" +echo "running opt_level O2" +python3.6 -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py --opt_level "O2" +python3.6 compare.py + +# delete the model files +echo -e "O2 test completed. Deleting model files\n" +rm rank0model.pth +rm rank1model.pth +rm rank0master.pth +rm rank1master.pth + + +# Test with opt_level="O5" +echo "running opt_level O5" +python3.6 -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py --opt_level "O5" +python3.6 compare.py + +# delete the model files +echo "O5 test completed. Deleting model files" +rm rank0model.pth +rm rank1model.pth +rm rank0master.pth +rm rank1master.pth