Revert "fix mp in DataLoader (#2506) (#2507)" (#2521)
This reverts commit 5576e6f.
xingchensong authored May 8, 2024 · 1 parent bd22fae · commit 2258c72
Showing 5 changed files with 11 additions and 27 deletions.
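For context: the reverted commit had forced every affected DataLoader to use spawn-based worker processes, and this revert restores the platform default. Below is a minimal sketch of the two patterns, assuming nothing beyond stock PyTorch; the TensorDataset and variable names are illustrative placeholders, not code from this repository.

import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset, only to make the sketch self-contained.
dataset = TensorDataset(torch.arange(8))
num_workers = 2

# Pattern removed by this revert: force the "spawn" start method.
# DataLoader rejects a multiprocessing_context when num_workers == 0,
# which is why the reverted code guarded it with `if args.num_workers > 0`.
mp_context = mp.get_context("spawn") if num_workers > 0 else None
loader_spawn = DataLoader(dataset,
                          num_workers=num_workers,
                          multiprocessing_context=mp_context)

# Pattern restored by this revert: leave the context unset and use the
# platform's default start method (typically "fork" on Linux, "spawn" on
# macOS and Windows).
loader_default = DataLoader(dataset, num_workers=num_workers)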
9 changes: 3 additions & 6 deletions test/wenet/dataset/test_datapipes.py
@@ -2,7 +2,6 @@
 import torch
 from torch.utils.data import datapipes
 from torch.utils.data.datapipes.iter import IterableWrapper
-import torch.multiprocessing as mp
 from functools import partial
 
 from wenet.dataset.datapipes import (RepeatDatapipe, SortDataPipe,
@@ -109,11 +108,9 @@ def test_dynamic_batch_datapipe(data_list):
         window_class=DynamicBatchWindow(max_frames_in_batch),
         wrapper_class=padding)
 
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=None,
-        num_workers=2,
-        multiprocessing_context=mp.get_context("spawn"))
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=None,
+                                             num_workers=2)
     for d in dataloader:
         assert d['feats'].size(1) <= max_frames_in_batch
 
11 changes: 4 additions & 7 deletions test/wenet/dataset/test_dataset.py
@@ -1,6 +1,5 @@
 import pytest
 import torch
-import torch.multiprocessing as mp
 from wenet.dataset.dataset import Dataset
 from wenet.text.char_tokenizer import CharTokenizer
 
@@ -55,11 +54,9 @@ def test_dataset(params):
                       data_list,
                       tokenizer=tokenizer,
                       conf=dataset_conf)
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=None,
-        num_workers=4,
-        persistent_workers=True,
-        multiprocessing_context=mp.get_context("spawn"))
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=None,
+                                             num_workers=4,
+                                             persistent_workers=True)
     for d in dataloader:
         pass
5 changes: 1 addition & 4 deletions tools/compute_cmvn_stats.py
@@ -11,7 +11,6 @@
 import torchaudio
 import torchaudio.compliance.kaldi as kaldi
 from torch.utils.data import Dataset, DataLoader
-import torch.multiprocessing as mp
 
 
 class CollateFunc(object):
@@ -108,14 +107,12 @@ def __getitem__(self, idx):
     collate_func = CollateFunc(feat_dim, resample_rate)
     dataset = AudioDataset(args.in_scp)
     batch_size = 20
-    mp_context = mp.get_context("spawn") if args.num_workers > 0 else None
     data_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              sampler=None,
                              num_workers=args.num_workers,
-                             collate_fn=collate_func,
-                             multiprocessing_context=mp_context)
+                             collate_fn=collate_func)
 
     with torch.no_grad():
         all_number = 0
5 changes: 1 addition & 4 deletions wenet/bin/recognize.py
@@ -22,7 +22,6 @@
 import torch
 import yaml
 from torch.utils.data import DataLoader
-import torch.multiprocessing as mp
 
 from wenet.dataset.dataset import Dataset
 from wenet.utils.config import override_config
@@ -223,11 +222,9 @@ def main():
                            test_conf,
                            partition=False)
 
-    mp_context = mp.get_context("spawn") if args.num_workers > 0 else None
     test_data_loader = DataLoader(test_dataset,
                                   batch_size=None,
-                                  num_workers=args.num_workers,
-                                  multiprocessing_context=mp_context)
+                                  num_workers=args.num_workers)
 
     # Init asr model from configs
     args.jit = False
8 changes: 2 additions & 6 deletions wenet/utils/train_utils.py
@@ -26,7 +26,6 @@
 
 import torch.optim as optim
 import torch.distributed as dist
-import torch.multiprocessing as mp
 
 from tensorboardX import SummaryWriter
 from torch.utils.data import DataLoader
@@ -346,23 +345,20 @@ def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):
 
     # NOTE(xcsong): Why we prefer persistent_workers=True ?
     # https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110
-    mp_context = mp.get_context("spawn") if args.num_workers > 0 else None
     train_data_loader = DataLoader(train_dataset,
                                    batch_size=None,
                                    pin_memory=args.pin_memory,
                                    num_workers=args.num_workers,
                                    persistent_workers=True,
                                    generator=generator,
-                                   prefetch_factor=args.prefetch,
-                                   multiprocessing_context=mp_context)
+                                   prefetch_factor=args.prefetch)
     cv_data_loader = DataLoader(cv_dataset,
                                 batch_size=None,
                                 pin_memory=args.pin_memory,
                                 num_workers=args.num_workers,
                                 persistent_workers=True,
                                 generator=generator,
-                                prefetch_factor=args.prefetch,
-                                multiprocessing_context=mp_context)
+                                prefetch_factor=args.prefetch)
     return train_dataset, cv_dataset, train_data_loader, cv_data_loader
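The NOTE(xcsong) comment preserved in the last hunk links to a discussion of why persistent_workers=True is preferred: worker processes are created once and reused across epochs rather than torn down and restarted each time the loader is iterated, which matters most when worker startup is expensive (as with "spawn"). A small illustrative sketch, again with placeholder names rather than WeNet code:

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    dataset = TensorDataset(torch.arange(100))
    # Workers are created on the first epoch and kept alive for later
    # epochs instead of being recreated per epoch.
    loader = DataLoader(dataset,
                        num_workers=4,
                        persistent_workers=True)
    for epoch in range(3):
        for _ in loader:
            pass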
