From fdaa5d769f55544fa83ddc170e489c57a6be53db Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Thu, 18 Mar 2021 19:26:10 +0000
Subject: [PATCH] Fix all_gather for tpu_cores=8

---
 CHANGELOG.md                          | 3 +++
 pytorch_lightning/accelerators/tpu.py | 8 ++++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4d397079bb072..348f720fa1059 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -163,6 +163,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511))
 
 
+- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8`
+
+
 ## [1.2.3] - 2021-03-09
 
 ### Fixed
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 57e65a62f6783..5c4fb2815aa6d 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -46,12 +46,12 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         Function to gather a tensor from several distributed processes
         Args:
             tensor: tensor of shape (batch, ...)
-            group: the process group to gather results from. Defaults to all processes (world)
-            sync_grads: flag that allows users to synchronize gradients for all_gather op
+            group: not available with TPUs
+            sync_grads: not available with TPUs
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         # todo: Add support for backward with all_gather
-        if torch.distributed.is_initialized():
-            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
+            return xm.all_gather(tensor).view(-1, *tensor.shape)
         return tensor
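
Illustrative note on the new return shape: xm.all_gather concatenates the
per-process tensors along dim 0, so the added .view(-1, *tensor.shape)
recovers the documented (world_size, batch, ...) layout. Below is a minimal
sketch of that reshape in plain PyTorch (no TPU required); the world size and
tensor sizes are illustrative only, not taken from the patch.

    import torch

    # Simulate the flat result xm.all_gather produces: each process contributes
    # a tensor of shape (batch, ...) and the results are concatenated along dim 0.
    world_size, batch, feat = 8, 4, 16    # illustrative sizes only
    tensor = torch.randn(batch, feat)     # local tensor on one process
    gathered_flat = torch.cat([tensor] * world_size, dim=0)  # (world_size * batch, feat)

    # The patched code reshapes that back to (world_size, batch, ...).
    gathered = gathered_flat.view(-1, *tensor.shape)
    assert gathered.shape == (world_size, batch, feat)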