From 5b2909774645a47ef068cf88d15ace6ea9e9b7ef Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 14 Feb 2024 04:25:50 +0100
Subject: [PATCH 1/2] update docs

---
 docs/source-fabric/advanced/distributed_communication.rst | 5 +++++
 src/lightning/fabric/fabric.py                             | 3 ++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/docs/source-fabric/advanced/distributed_communication.rst b/docs/source-fabric/advanced/distributed_communication.rst
index 83aac5cb3f285..920b557d0d6e5 100644
--- a/docs/source-fabric/advanced/distributed_communication.rst
+++ b/docs/source-fabric/advanced/distributed_communication.rst
@@ -236,6 +236,11 @@ Full example:
     result = fabric.all_gather(data)
     print("Result of all-gather:", result)  # tensor([ 0, 10, 20, 30])
 
+.. warning::
+
+    For the special case where ``world_size`` is 1, no additional dimension is added to the tensor(s). This inconsistency
+    is kept for backward compatibility, and you may need to handle this special case to keep your code agnostic to the
+    number of processes.
 
 ----
 
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index e53b424c42348..3bc56c3d05326 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -583,7 +583,8 @@ def all_gather(
 
         Return:
             A tensor of shape (world_size, batch, ...), or if the input was a collection
-            the output will also be a collection with tensors of this shape.
+            the output will also be a collection with tensors of this shape. For the special case where
+            world_size is 1, no additional dimension is added to the tensor(s).
 
         """
         self._validate_launched()

From ed5a197c52ac36abe80a2dbd1ece83f85b580369 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 14 Feb 2024 04:31:07 +0100
Subject: [PATCH 2/2] lightning module too

---
 src/lightning/pytorch/core/module.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 0bdcd7b28b204..d8bc79c9d781d 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -668,7 +668,8 @@ def all_gather(
 
         Return:
            A tensor of shape (world_size, batch, ...), or if the input was a collection
-            the output will also be a collection with tensors of this shape.
+            the output will also be a collection with tensors of this shape. For the special case where
+            world_size is 1, no additional dimension is added to the tensor(s).
 
         """
         group = group if group is not None else torch.distributed.group.WORLD
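
As a companion to the warning added above, a minimal sketch of how calling code could absorb the ``world_size == 1`` special case so that downstream code stays agnostic to the number of processes. The single-process CPU setup and the ``unsqueeze`` workaround are illustrative assumptions, not part of the patch series:

.. code-block:: python

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=1)  # single process, so world_size == 1
    fabric.launch()

    # Each process contributes one value; with 4 processes this would yield tensor([0, 10, 20, 30]).
    data = torch.tensor(10 * fabric.global_rank, device=fabric.device)
    result = fabric.all_gather(data)

    # With world_size == 1, all_gather returns the tensor without the leading
    # world_size dimension, so add it manually to stay process-count agnostic.
    if fabric.world_size == 1:
        result = result.unsqueeze(0)

    assert result.shape[0] == fabric.world_size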