diff --git a/docs/source-fabric/advanced/distributed_communication.rst b/docs/source-fabric/advanced/distributed_communication.rst index 83aac5cb3f285..920b557d0d6e5 100644 --- a/docs/source-fabric/advanced/distributed_communication.rst +++ b/docs/source-fabric/advanced/distributed_communication.rst @@ -236,6 +236,11 @@ Full example: result = fabric.all_gather(data) print("Result of all-gather:", result) # tensor([ 0, 10, 20, 30]) +.. warning:: + + For the special case where ``world_size`` is 1, no additional dimension is added to the tensor(s). This inconsistency + is kept for backward compatibility, and you may need to handle this special case in your code to make it agnostic to the world size. + ---- diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index e53b424c42348..3bc56c3d05326 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -583,7 +583,8 @@ def all_gather( Return: A tensor of shape (world_size, batch, ...), or if the input was a collection - the output will also be a collection with tensors of this shape. + the output will also be a collection with tensors of this shape. For the special case where + world_size is 1, no additional dimension is added to the tensor(s). """ self._validate_launched() diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 0bdcd7b28b204..d8bc79c9d781d 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -668,7 +668,8 @@ def all_gather( Return: A tensor of shape (world_size, batch, ...), or if the input was a collection - the output will also be a collection with tensors of this shape. + the output will also be a collection with tensors of this shape. For the special case where + world_size is 1, no additional dimension is added to the tensor(s). """ group = group if group is not None else torch.distributed.group.WORLD