bfloat16 support for mgpu (NVIDIA#19)
* bfloat16 support for apex DDP

* enable mgpu tests for fp16 and bf16

* update Dockerfile
rohithkrn authored Jun 3, 2020
1 parent aea81c0 commit b0c7d09
Showing 5 changed files with 44 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
-ARG FROM_IMAGE=lcskrishna/rocm-pytorch:rocm3.3_ubuntu16.04_py3.6_pytorch_updated
+ARG FROM_IMAGE=lcskrishna/rocm-pytorch:rocm3.3_ubuntu16.04_py3.6_pytorch_bfloat16_mgpu

 FROM ${FROM_IMAGE}
 RUN \
13 changes: 7 additions & 6 deletions apex/parallel/distributed.py
@@ -48,8 +48,8 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
     for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
         buf.copy_(synced)

-def split_half_float_double(tensors):
-    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor"]
+def split_half_float_double_bfloat16(tensors):
+    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
     buckets = []
     for i, dtype in enumerate(dtypes):
         bucket = [t for t in tensors if t.type() == dtype]
@@ -240,7 +240,8 @@ def __init__(self,

         self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
                                     "torch.cuda.FloatTensor" : 1,
-                                    "torch.cuda.DoubleTensor" : 2}
+                                    "torch.cuda.DoubleTensor" : 2,
+                                    "torch.cuda.BFloat16Tensor" : 3}

         if multi_tensor_applier.available:
             # TODO: I really need to centralize the C++ backed imports
@@ -498,7 +499,7 @@ def allreduce_fallback(self):
         else:
             grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]

-        split_buckets = split_half_float_double(grads)
+        split_buckets = split_half_float_double_bfloat16(grads)

         # If retain_allreduce_buffers is True and delay_allreduce is False,
         # this will only be done during the first backward pass, ignored by the
@@ -578,8 +579,8 @@ def forward(self, *inputs, **kwargs):
                 if self.needs_refresh:
                     self.active_i_buckets = []
                     self.buckets = []
-                    self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
-                    self.tmp_numels = [0, 0, 0]
+                    self.tmp_buckets = [[], [], [], []] # [running half, float, double, bfloat16 buckets]
+                    self.tmp_numels = [0, 0, 0, 0]
                     self.bucket_sizes = []
                     self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
                     self.param_id_to_bucket = {}
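For reference, a minimal runnable sketch of the bucketing helper after this change. The hunk above shows the renamed function, the extended dtype list, and the loop header; the bucket-append and return lines are assumed here to follow apex's existing pattern. The parallel edits in __init__ and forward() (the new "torch.cuda.BFloat16Tensor" : 3 entry and the four-slot tmp_buckets/tmp_numels) give bfloat16 gradients their own running bucket in the same way.

# Sketch only: the append/return lines are assumptions, not part of the visible hunk.
import torch

def split_half_float_double_bfloat16(tensors):
    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor",
              "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
    buckets = []
    for i, dtype in enumerate(dtypes):
        # Group gradients of the same CUDA tensor type into one allreduce bucket.
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets

# Example: fp16 and bf16 grads land in separate buckets.
if torch.cuda.is_available():
    grads = [torch.zeros(4, device="cuda", dtype=torch.float16),
             torch.zeros(4, device="cuda", dtype=torch.bfloat16)]
    assert len(split_half_float_double_bfloat16(grads)) == 2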
3 changes: 2 additions & 1 deletion tests/distributed/amp_master_params/amp_master_params.py
@@ -9,6 +9,7 @@
 # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
 # automatically by torch.distributed.launch.
 parser.add_argument("--local_rank", default=0, type=int)
+parser.add_argument("--opt_level", default="O2", type=str)
 args = parser.parse_args()

 # FOR DISTRIBUTED: If we are running under torch.distributed.launch,
@@ -42,7 +43,7 @@
 model = torch.nn.Linear(D_in, D_out).cuda()
 optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

-model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
+model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)

 if args.distributed:
     # FOR DISTRIBUTED: After amp.initialize, wrap the model with
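The test now takes opt_level on the command line instead of hard-coding "O2", so one script covers both the fp16 ("O2") and bfloat16 ("O5", as driven by the new run script below; not an upstream Apex level) master-parameter paths. A condensed sketch of the path the test exercises: the argparse, model, optimizer, and amp.initialize lines mirror the diff, while the process-group setup and DDP wrap are assumed standard boilerplate from the unchanged parts of the file.

# Condensed sketch; D_in/D_out and the init lines outside the hunks are assumptions.
import argparse
import torch
from apex import amp
from apex.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--opt_level", default="O2", type=str)  # "O2" fp16, "O5" bf16 in this fork
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)
model = DistributedDataParallel(model)

It is launched the same way as in run_rocm_distributed.sh, via torch.distributed.launch with --nproc_per_node=2 and the desired --opt_level.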
5 changes: 4 additions & 1 deletion tests/distributed/amp_master_params/compare.py
@@ -14,6 +14,9 @@
                                                                 model_params_rank1,
                                                                 master_params_rank0,
                                                                 master_params_rank1):
+    # converting model params to float is a hack since allclose doesn't support bfloat16 yet.
+    model_rank0 = model_rank0.float()
+    model_rank1 = model_rank1.float()
     assert torch.allclose(model_rank0, model_rank1), "Model param mismatch"
     assert torch.allclose(master_rank0, master_rank1), "Master param mismatch"
     # Some debugging/investigation assistance code:
@@ -23,6 +26,6 @@
     # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(),
     #       offending_val_float.half().item())
     # rtol needs to be > 2^-11 because of denormals...
-    assert torch.allclose(model_rank0, master_rank0.half(), rtol=.005), "Model-master mismatch"
+    assert torch.allclose(model_rank0, master_rank0, rtol=.005), "Model-master mismatch"

 print("OK: Model and master params match across ranks.")
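The .float() upcasts above work around torch.allclose rejecting bfloat16 inputs in the targeted PyTorch build; since every bfloat16 value is exactly representable in float32, comparing the upcast tensors loses no precision. A standalone, CPU-only illustration of the same pattern (tensor names are illustrative):

# Upcast bfloat16 tensors to float32 before torch.allclose; the upcast is exact.
import torch

a = torch.randn(8).to(torch.bfloat16)
b = a.clone()

# rtol mirrors the relaxed .005 tolerance compare.py uses for reduced precision.
assert torch.allclose(a.float(), b.float(), rtol=5e-3)
print("bfloat16 tensors match after upcasting to float32")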
30 changes: 30 additions & 0 deletions tests/distributed/amp_master_params/run_rocm_distributed.sh
@@ -0,0 +1,30 @@
#!/bin/bash
set -e

# To run the test on 2 gpus
export WORLD_SIZE=2

# Test with opt_level="O2"
echo "running opt_level O2"
python3.6 -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py --opt_level "O2"
python3.6 compare.py

# delete the model files
echo -e "O2 test completed. Deleting model files\n"
rm rank0model.pth
rm rank1model.pth
rm rank0master.pth
rm rank1master.pth


# Test with opt_level="O5"
echo "running opt_level O5"
python3.6 -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py --opt_level "O5"
python3.6 compare.py

# delete the model files
echo "O5 test completed. Deleting model files"
rm rank0model.pth
rm rank1model.pth
rm rank0master.pth
rm rank1master.pth
