fix mp in DataLoader (wenet-e2e#2506)
MengqingCao committed Apr 28, 2024
1 parent f42ddb2 · commit e1157e7
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions wenet/utils/train_utils.py
@@ -26,6 +26,7 @@
 
 import torch.optim as optim
 import torch.distributed as dist
+import torch.multiprocessing as mp
 
 from tensorboardX import SummaryWriter
 from torch.utils.data import DataLoader
@@ -345,20 +346,26 @@ def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):
 
     # NOTE(xcsong): Why we prefer persistent_workers=True ?
     # https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110
-    train_data_loader = DataLoader(train_dataset,
-                                   batch_size=None,
-                                   pin_memory=args.pin_memory,
-                                   num_workers=args.num_workers,
-                                   persistent_workers=True,
-                                   generator=generator,
-                                   prefetch_factor=args.prefetch)
-    cv_data_loader = DataLoader(cv_dataset,
-                                batch_size=None,
-                                pin_memory=args.pin_memory,
-                                num_workers=args.num_workers,
-                                persistent_workers=True,
-                                generator=generator,
-                                prefetch_factor=args.prefetch)
+    train_data_loader = DataLoader(
+        train_dataset,
+        batch_size=None,
+        pin_memory=args.pin_memory,
+        num_workers=args.num_workers,
+        persistent_workers=True,
+        generator=generator,
+        prefetch_factor=args.prefetch,
+        multiprocessing_context=mp.get_context("spawn"),
+    )
+    cv_data_loader = DataLoader(
+        cv_dataset,
+        batch_size=None,
+        pin_memory=args.pin_memory,
+        num_workers=args.num_workers,
+        persistent_workers=True,
+        generator=generator,
+        prefetch_factor=args.prefetch,
+        multiprocessing_context=mp.get_context("spawn"),
+    )
     return train_dataset, cv_dataset, train_data_loader, cv_data_loader
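
For readers skimming the diff: the change passes multiprocessing_context=mp.get_context("spawn") so that DataLoader launches its worker processes with the "spawn" start method rather than the platform default ("fork" on Linux), which can hang or misbehave when the parent process already holds threads or device handles. Below is a minimal, self-contained sketch of the same pattern, assuming nothing from wenet; ToyDataset and the numbers are hypothetical placeholders.

import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    # Hypothetical stand-in for wenet's dataset pipeline.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor([float(idx)])


# "spawn" workers re-import the entry module, so runnable code must sit
# behind a __main__ guard or the workers will re-execute it at import time.
if __name__ == "__main__":
    loader = DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=2,
        persistent_workers=True,
        multiprocessing_context=mp.get_context("spawn"),
    )
    for batch in loader:
        print(batch)

One trade-off worth noting: "spawn" pickles the dataset and re-imports the entry module for each worker, so startup is slower than "fork"; persistent_workers=True, already used in this file, amortizes that cost across epochs.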


