Commit

Fix mypy errors attributed to `pytorch_lightning.utilities.distributed` (#13678)

Co-authored-by: Carlos Mocholí <[email protected]>
krishnakalyan3 and carmocca authored Jul 16, 2022
1 parent e23756b commit 2845e75
Showing 2 changed files with 3 additions and 3 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -84,7 +84,6 @@ module = [
     "pytorch_lightning.tuner.batch_size_scaling",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
-    "pytorch_lightning.utilities.distributed",
     "pytorch_lightning.utilities.meta",
 ]
 ignore_errors = "True"
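With `pytorch_lightning.utilities.distributed` dropped from the `ignore_errors` module list above, mypy now reports errors for that module, which is what the annotations in the second file address. As a rough local check (a minimal sketch assuming mypy is installed and the command is run from the repository root, not necessarily with the project's full mypy configuration):

```python
# Minimal sketch: re-check the module that was removed from the mypy
# ignore_errors list. Assumes mypy is installed and we run from the repo root.
from mypy import api

stdout, stderr, exit_status = api.run(["src/pytorch_lightning/utilities/distributed.py"])
print(stdout)
assert exit_status == 0, "mypy still reports errors for this module"
```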
5 changes: 3 additions & 2 deletions src/pytorch_lightning/utilities/distributed.py
@@ -145,6 +145,7 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un
     if group is None:
         group = torch.distributed.group.WORLD
 
+    op: Optional[ReduceOp]
     if isinstance(reduce_op, str):
         if reduce_op.lower() in ("avg", "mean"):
             op = ReduceOp.SUM
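The added `op: Optional[ReduceOp]` pre-declares the variable's type before the branches that assign it, so mypy checks every branch against one declared type instead of inferring a type from whichever assignment it sees first. A minimal sketch of the same pattern outside Lightning (the helper name is illustrative; it assumes a PyTorch build with `torch.distributed` available):

```python
# Minimal sketch of pre-declaring a variable's type so mypy accepts
# branch-dependent assignment; the helper below is illustrative only.
from typing import Optional

from torch.distributed import ReduceOp


def pick_op(reduce_op: Optional[str]) -> Optional[ReduceOp]:
    op: Optional[ReduceOp]  # declared up front, before any branch assigns it
    if reduce_op is None:
        op = None
    elif reduce_op.lower() in ("avg", "mean", "sum"):
        op = ReduceOp.SUM
    else:
        op = ReduceOp.PRODUCT
    return op
```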
@@ -174,7 +175,7 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un
 
 class AllGatherGrad(torch.autograd.Function):
     @staticmethod
-    def forward(
+    def forward(  # type: ignore[override]
         ctx: Any,
         tensor: Tensor,
         group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
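`torch.autograd.Function.forward` is declared with a broad `(ctx, *args, **kwargs)`-style signature in PyTorch's type hints, and the narrowed signature of `AllGatherGrad.forward` evidently trips mypy's override-compatibility check; the `# type: ignore[override]` suppresses only that error code on that one line. A minimal sketch of the same pattern (the `Square` class is illustrative, not part of Lightning):

```python
# Minimal sketch: a custom autograd Function whose forward narrows the
# base-class signature, silenced with a targeted ignore as in the patch above.
from typing import Any

import torch


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, tensor: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        ctx.save_for_backward(tensor)
        return tensor * tensor

    @staticmethod
    def backward(ctx: Any, *grad_outputs: torch.Tensor) -> torch.Tensor:
        (tensor,) = ctx.saved_tensors
        return 2 * tensor * grad_outputs[0]


# Usage: Square.apply(x) behaves like x ** 2 but routes through the custom backward.
```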
@@ -317,7 +318,7 @@ def register_ddp_comm_hook(
         ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
 
     new_rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
-    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
+    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)  # type: ignore[operator]
 
 
 def tpu_distributed() -> bool:
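Scoping the ignore to `[operator]` is narrower than a bare `# type: ignore`: only the `operator` error mypy raises for this call is silenced, while any other error on the same line would still be reported. A minimal, runnable sketch of such a scoped ignore (the function is illustrative, not from Lightning):

```python
# Minimal sketch of an error-code-scoped ignore, mirroring the patch above.
# The helper is illustrative; mypy flags 'int + str' with the [operator] code.
from typing import List, Union


def total(values: List[Union[int, str]]) -> int:
    result = 0
    for value in values:
        # Without the scoped ignore, mypy reports:
        #   Unsupported operand types for + ("int" and "str")  [operator]
        # Scoping the ignore keeps any other error code on this line visible.
        result += value  # type: ignore[operator]
    return result


print(total([1, 2, 3]))  # runs fine as long as only ints are actually passed
```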
