diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py index f72ad837..e2cc0617 100644 --- a/src/so_vits_svc_fork/train.py +++ b/src/so_vits_svc_fork/train.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import warnings from logging import getLogger from multiprocessing import cpu_count @@ -10,6 +11,7 @@ import torch from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.strategies.ddp import DDPStrategy from lightning.pytorch.tuner import Tuner from torch.cuda.amp import autocast from torch.nn import functional as F @@ -75,7 +77,13 @@ def train( datamodule = VCDataModule(hparams) strategy = ( - "ddp_find_unused_parameters_true" if torch.cuda.device_count() > 1 else "auto" + ( + "ddp_find_unused_parameters_true" + if os.name != "nt" + else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo") + ) + if torch.cuda.device_count() > 1 + else "auto" ) LOG.info(f"Using strategy: {strategy}") trainer = pl.Trainer(