Skip to content

Commit

Permalink
Cherrypick #8521 into r2.6 (#8563)
Browse files Browse the repository at this point in the history
Co-authored-by: Chengji Yao <[email protected]>
  • Loading branch information
tengyifei and Chengji Yao authored Jan 14, 2025
1 parent c23cd63 commit 31b2b3b
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 9 deletions.
7 changes: 3 additions & 4 deletions docs/source/perf/ddp.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ repo](https://github.com/pytorch/xla/). For those who are interested in
the native xla data parallel approach, here is the
[tutorial](../API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing).

Here are some of the known issues that are under investigation: \*
`gradient_as_bucket_view=False` needs to be enforced. \* There are some
issues while being used with `torch.utils.data.DataLoader`.
`test_train_mp_mnist.py` with real data crashes before exiting.
Here are some of the known issues that are under investigation: \* There are some
issues while being used with `torch.utils.data.DataLoader`. `test_train_mp_mnist.py`
with real data crashes before exiting.
9 changes: 6 additions & 3 deletions test/distributed_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def train_step(model, inputs, labels, optimizer, loss_fn):

def ddp_correctness(init_method: str = 'env://',
use_large_net: bool = False,
debug: bool = False):
debug: bool = False,
gradient_as_bucket_view: bool = False):
if init_method == 'env://':
rank = xr.global_ordinal()
world_size = xr.world_size()
Expand All @@ -111,11 +112,13 @@ def ddp_correctness(init_method: str = 'env://',
steps = 5 # To save test time.
cpu_model = LargeNet()

# TODO: There're issues in the captured graph when gradient_as_bucket_view is True
# bucket_cap_mb is set to 1 mb such that we can still have multiple all_reduces while avoiding
# using models that are too larger (25 mb).
# To be noted, DDP currently uses one bucket for the first iteration. See pytorch#73732.
ddp_model = DDP(copy.deepcopy(cpu_model).to(device), bucket_cap_mb=1)
ddp_model = DDP(
copy.deepcopy(cpu_model).to(device),
gradient_as_bucket_view=gradient_as_bucket_view,
bucket_cap_mb=1)
# ddp_model.register_comm_hook(state=None, hook=comp_hook)

cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-1)
Expand Down
13 changes: 11 additions & 2 deletions test/torch_distributed/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
class TestXrtDistributedDataParallel(parameterized.TestCase):

@staticmethod
def _ddp_correctness(rank, use_large_net: bool, debug: bool):
def _ddp_correctness(rank,
use_large_net: bool,
debug: bool,
gradient_as_bucket_view: bool = False):
# We cannot run this guard before XMP,
# see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing.
device = xm.xla_device()
Expand All @@ -27,11 +30,17 @@ def _ddp_correctness(rank, use_large_net: bool, debug: bool):
file=sys.stderr)
return
util.ddp_correctness(
init_method="xla://", use_large_net=use_large_net, debug=debug)
init_method="xla://",
use_large_net=use_large_net,
debug=debug,
gradient_as_bucket_view=gradient_as_bucket_view)

def test_ddp_correctness(self):
torch_xla.launch(self._ddp_correctness, args=(False, FLAGS.debug))

def test_ddp_correctness_with_gradient_as_bucket_view(self):
torch_xla.launch(self._ddp_correctness, args=(False, FLAGS.debug, True))

def test_ddp_correctness_large_net(self):
torch_xla.launch(self._ddp_correctness, args=(True, FLAGS.debug))

Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) {
impl->set_tensor(std::move(new_xla_tensor));
}

void ReplaceXlaTensor(const std::vector<at::Tensor>& tensors,
const std::vector<XLATensorPtr> new_xla_tensors) {
XLA_CHECK(tensors.size() == new_xla_tensors.size())
<< "The size of tensors and new_xla_tensors are not equal: "
<< tensors.size() << " vs. " << new_xla_tensors.size();
for (size_t i = 0; i < tensors.size(); ++i) {
ReplaceXlaTensor(tensors[i], new_xla_tensors[i]);
}
}

std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors) {
std::vector<XLATensorPtr> xla_tensors;
xla_tensors.reserve(tensors.size());
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ XLATensorPtr GetXlaTensor(const at::Tensor& tensor);
// version.
void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor);

void ReplaceXlaTensor(const std::vector<at::Tensor>& tensor,
const std::vector<XLATensorPtr> new_xla_tensor);

// Same as above, applied to a list of tensors.
std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors);

Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@ void AllReduceInPlace(const std::string& reduce_type,
GetXlaTensors(tensors, /*want_all=*/true);
tensor_methods::all_reduce(xtensors, GetReduceType(reduce_type), scale,
replica_groups, pin_layout);
std::vector<XLATensorPtr> new_xtensors =
GetXlaTensors(tensors, /*want_all=*/true);
bridge::ReplaceXlaTensor(tensors, new_xtensors);
}

at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input,
Expand Down

0 comments on commit 31b2b3b

Please sign in to comment.