Skip to content

Commit

Permalink
Pass enabled down to _BackwardSyncControl (#19577)
Browse files (browse the repository at this point in the history)
  • Loading branch information
carmocca authored Mar 8, 2024
1 parent 3740546 commit 06eb3cc
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 19 deletions.
3 changes: 2 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fabric now raises an error if you forget to call `fabric.backward()` when it is needed by the strategy or precision selection ([#19447](https://github.com/Lightning-AI/lightning/pull/19447), [#19493](https://github.com/Lightning-AI/lightning/pull/19493))


-
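- `_BackwardSyncControl` can now control what to do when gradient accumulation is disabled: `Fabric.no_backward_sync()` now passes its `enabled` flag down to `_BackwardSyncControl.no_backward_sync()` ([#19577](https://github.com/Lightning-AI/lightning/pull/19577))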


### Deprecated

Expand Down
4 changes: 2 additions & 2 deletions src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
"You need to set up the model first before you can call `fabric.no_backward_sync()`:"
" `model = fabric.setup(model, ...)`"
)
if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
if isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
return nullcontext()
if self._strategy._backward_sync_control is None:
rank_zero_warn(
Expand All @@ -683,7 +683,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
return nullcontext()

forward_module, _ = _unwrap_compiled(module._forward_module)
return self._strategy._backward_sync_control.no_backward_sync(forward_module)
return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled)

def sharded_model(self) -> ContextManager:
r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.
Expand Down
5 changes: 4 additions & 1 deletion src/lightning/fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,12 @@ def _determine_ddp_device_ids(self) -> Optional[List[int]]:

class _DDPBackwardSyncControl(_BackwardSyncControl):
@override
def no_backward_sync(self, module: Module) -> ContextManager:
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
"""Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel`
wrapper."""
if not enabled:
return nullcontext()

if not isinstance(module, DistributedDataParallel):
raise TypeError(
"Blocking backward sync is only possible if the module passed to"
Expand Down
6 changes: 4 additions & 2 deletions src/lightning/fabric/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from contextlib import ExitStack
from contextlib import ExitStack, nullcontext
from datetime import timedelta
from functools import partial
from pathlib import Path
Expand Down Expand Up @@ -768,9 +768,11 @@ def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwa

class _FSDPBackwardSyncControl(_BackwardSyncControl):
@override
def no_backward_sync(self, module: Module) -> ContextManager:
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
"""Blocks gradient synchronization inside the :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
wrapper."""
if not enabled:
return nullcontext()
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

if not isinstance(module, FullyShardedDataParallel):
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ class _BackwardSyncControl(ABC):
"""

@abstractmethod
def no_backward_sync(self, module: Module) -> ContextManager:
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
"""Blocks the synchronization of gradients during the backward pass.
This is a context manager. It is only effective if it wraps a call to `.backward()`.
Expand Down
4 changes: 3 additions & 1 deletion src/lightning/fabric/strategies/xla_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,9 +679,11 @@ def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: Dict

class _XLAFSDPBackwardSyncControl(_BackwardSyncControl):
@override
def no_backward_sync(self, module: Module) -> ContextManager:
def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
"""Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel`
wrapper."""
if not enabled:
return nullcontext()
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP

if not isinstance(module, XLAFSDP):
Expand Down
8 changes: 5 additions & 3 deletions tests/tests_fabric/strategies/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@ def test_ddp_no_backward_sync():

with pytest.raises(
TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
), strategy._backward_sync_control.no_backward_sync(Mock()):
), strategy._backward_sync_control.no_backward_sync(Mock(), True):
pass

module = MagicMock(spec=DistributedDataParallel)
with strategy._backward_sync_control.no_backward_sync(module):
with strategy._backward_sync_control.no_backward_sync(module, False):
pass
module.no_sync.assert_not_called()
with strategy._backward_sync_control.no_backward_sync(module, True):
pass

module.no_sync.assert_called_once()


Expand Down
8 changes: 5 additions & 3 deletions tests/tests_fabric/strategies/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,15 @@ def test_fsdp_no_backward_sync():

with pytest.raises(
TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
), strategy._backward_sync_control.no_backward_sync(Mock()):
), strategy._backward_sync_control.no_backward_sync(Mock(), True):
pass

module = MagicMock(spec=FullyShardedDataParallel)
with strategy._backward_sync_control.no_backward_sync(module):
with strategy._backward_sync_control.no_backward_sync(module, False):
pass
module.no_sync.assert_not_called()
with strategy._backward_sync_control.no_backward_sync(module, True):
pass

module.no_sync.assert_called_once()


Expand Down
8 changes: 6 additions & 2 deletions tests/tests_fabric/strategies/test_xla_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,17 @@ def test_xla_fsdp_no_backward_sync():

with pytest.raises(
TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`"
), strategy._backward_sync_control.no_backward_sync(object()):
), strategy._backward_sync_control.no_backward_sync(object(), True):
pass

module = MagicMock(spec=XlaFullyShardedDataParallel)
with strategy._backward_sync_control.no_backward_sync(module):

with strategy._backward_sync_control.no_backward_sync(module, False):
pass
module.no_sync.assert_not_called()

with strategy._backward_sync_control.no_backward_sync(module, True):
pass
module.no_sync.assert_called_once()


Expand Down
6 changes: 3 additions & 3 deletions tests/tests_fabric/test_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,11 +767,11 @@ def test_no_backward_sync():
# disabling the context manager makes it a no-op
with fabric.no_backward_sync(model, enabled=False):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# when enabled, the wrapped module gets passed down
fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module, False)
fabric._strategy._backward_sync_control.reset_mock()
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module)
fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module, True)


def test_launch_without_function():
Expand Down

0 comments on commit 06eb3cc

Please sign in to comment.