diff --git a/test/wenet/dataset/test_datapipes.py b/test/wenet/dataset/test_datapipes.py index f269788c9e..a3bfaff948 100644 --- a/test/wenet/dataset/test_datapipes.py +++ b/test/wenet/dataset/test_datapipes.py @@ -2,6 +2,7 @@ import torch from torch.utils.data import datapipes from torch.utils.data.datapipes.iter import IterableWrapper +import torch.multiprocessing as mp from functools import partial from wenet.dataset.datapipes import (RepeatDatapipe, SortDataPipe, @@ -108,9 +109,11 @@ def test_dynamic_batch_datapipe(data_list): window_class=DynamicBatchWindow(max_frames_in_batch), wrapper_class=padding) - dataloader = torch.utils.data.DataLoader(dataset, - batch_size=None, - num_workers=2) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=None, + num_workers=2, + multiprocessing_context=mp.get_context("spawn")) for d in dataloader: assert d['feats'].size(1) <= max_frames_in_batch diff --git a/test/wenet/dataset/test_dataset.py b/test/wenet/dataset/test_dataset.py index 86bf22b9b8..5d36264277 100644 --- a/test/wenet/dataset/test_dataset.py +++ b/test/wenet/dataset/test_dataset.py @@ -1,5 +1,6 @@ import pytest import torch +import torch.multiprocessing as mp from wenet.dataset.dataset import Dataset from wenet.text.char_tokenizer import CharTokenizer @@ -54,9 +55,11 @@ def test_dataset(params): data_list, tokenizer=tokenizer, conf=dataset_conf) - dataloader = torch.utils.data.DataLoader(dataset, - batch_size=None, - num_workers=4, - persistent_workers=True) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=None, + num_workers=4, + persistent_workers=True, + multiprocessing_context=mp.get_context("spawn")) for d in dataloader: pass diff --git a/tools/compute_cmvn_stats.py b/tools/compute_cmvn_stats.py index c68929436c..5ae3a8bea5 100755 --- a/tools/compute_cmvn_stats.py +++ b/tools/compute_cmvn_stats.py @@ -11,6 +11,7 @@ import torchaudio import torchaudio.compliance.kaldi as kaldi from torch.utils.data import Dataset, DataLoader 
+import torch.multiprocessing as mp class CollateFunc(object): @@ -107,12 +108,14 @@ def __getitem__(self, idx): collate_func = CollateFunc(feat_dim, resample_rate) dataset = AudioDataset(args.in_scp) batch_size = 20 + mp_context = mp.get_context("spawn") if args.num_workers > 0 else None data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, sampler=None, num_workers=args.num_workers, - collate_fn=collate_func) + collate_fn=collate_func, + multiprocessing_context=mp_context) with torch.no_grad(): all_number = 0 diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py index 3779b74eca..8cddd7185b 100644 --- a/wenet/bin/recognize.py +++ b/wenet/bin/recognize.py @@ -22,6 +22,7 @@ import torch import yaml from torch.utils.data import DataLoader +import torch.multiprocessing as mp from wenet.dataset.dataset import Dataset from wenet.utils.config import override_config @@ -222,9 +223,11 @@ def main(): test_conf, partition=False) + mp_context = mp.get_context("spawn") if args.num_workers > 0 else None test_data_loader = DataLoader(test_dataset, batch_size=None, - num_workers=args.num_workers) + num_workers=args.num_workers, + multiprocessing_context=mp_context) # Init asr model from configs args.jit = False diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 5131a13ac6..4d3a809612 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -37,6 +37,21 @@ logging.getLogger('langid').setLevel(logging.INFO) +import os +try: + cpu_info = os.popen("lscpu | grep 'Vendor ID'").read() + # 0x48 --> HiSilicon + if (cpu_info.rstrip().split(" ")[-1] == "0x48"): + # NOTE (MengqingCao): set number of threads in the subprocesses to 1 + # Why? There may be some operators utilizing multi-threads in processor, + # causing a possible deadlock in Kunpeng. 
+ # Similar issue in PyTorch: https://github.com/pytorch/pytorch/issues/45198 + torch.set_num_threads(1) +except Exception as ex: + logging.warning('Failed to set number of thread in Kunpeng, \ + this may cause segmentfault while dataloading, \ + ignore this warning if you are not using Kunpeng') + class UrlOpenError(Exception): diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index cdf6da2b3a..7bf822a8ef 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -26,6 +26,7 @@ import torch.optim as optim import torch.distributed as dist +import torch.multiprocessing as mp from tensorboardX import SummaryWriter from torch.utils.data import DataLoader @@ -345,20 +346,23 @@ def init_dataset_and_dataloader(args, configs, tokenizer, seed=777): # NOTE(xcsong): Why we prefer persistent_workers=True ? # https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110 + mp_context = mp.get_context("spawn") if args.num_workers > 0 else None train_data_loader = DataLoader(train_dataset, batch_size=None, pin_memory=args.pin_memory, num_workers=args.num_workers, persistent_workers=True, generator=generator, - prefetch_factor=args.prefetch) + prefetch_factor=args.prefetch, + multiprocessing_context=mp_context) cv_data_loader = DataLoader(cv_dataset, batch_size=None, pin_memory=args.pin_memory, num_workers=args.num_workers, persistent_workers=True, generator=generator, - prefetch_factor=args.prefetch) + prefetch_factor=args.prefetch, + multiprocessing_context=mp_context) return train_dataset, cv_dataset, train_data_loader, cv_data_loader