[uio] fix the repeat read dataset problem in the evaluation process (#…
czy97 authored Jul 12, 2023
1 parent c20d765 commit 326b871
Showing 2 changed files with 21 additions and 12 deletions.
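Background: `DataList.__iter__` previously cycled through the data list forever (`while True`), which suits step-driven training but means an evaluation/extraction pass re-reads utterances and never stops on its own. A minimal sketch of the failure mode (toy code, not the wespeaker class):

```python
# Minimal sketch of why an endlessly repeating IterableDataset breaks
# evaluation: the loader never raises StopIteration, so a plain `for`
# loop over it re-reads the data forever unless stopped externally.
from torch.utils.data import IterableDataset, DataLoader

class RepeatingList(IterableDataset):
    def __init__(self, items):
        self.items = items

    def __iter__(self):
        while True:              # training-style infinite cycling
            yield from self.items

loader = DataLoader(RepeatingList(['utt1', 'utt2']), batch_size=None)
for i, utt in enumerate(loader):
    print(utt)                   # utt1, utt2, utt1, utt2, ...
    if i == 5:                   # without an external stop, this never ends
        break
```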
3 changes: 2 additions & 1 deletion wespeaker/bin/extract.py
```diff
@@ -62,7 +62,8 @@ def extract(config='conf/config.yaml', **kwargs):
                       spk2id_dict={},
                       whole_utt=(batch_size == 1),
                       reverb_lmdb_file=None,
-                      noise_lmdb_file=None)
+                      noise_lmdb_file=None,
+                      repeat_dataset=False)
     dataloader = DataLoader(dataset,
                             shuffle=False,
                             batch_size=batch_size,
```
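With `repeat_dataset=False`, extraction can simply exhaust the loader: it stops after exactly one pass over the utterance list. A rough, self-contained sketch of the resulting pattern (the dataset, "model", and batch layout below are illustrative stand-ins, not the actual `extract.py` code):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

# Toy stand-ins: a one-pass dataset (the behavior repeat_dataset=False
# gives you) and a trivial "speaker model".
class OnePassDataset(IterableDataset):
    def __iter__(self):
        for utt in ['spk1-utt1', 'spk1-utt2']:
            yield utt, torch.randn(80)        # (key, fake features)

model = torch.nn.Linear(80, 4)                # stands in for the real model
dataloader = DataLoader(OnePassDataset(), batch_size=None)

embeddings = {}
model.eval()
with torch.no_grad():
    for utt, feats in dataloader:             # terminates after one pass
        embeddings[utt] = model(feats).numpy()
print(sorted(embeddings))                     # ['spk1-utt1', 'spk1-utt2']
```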
30 changes: 19 additions & 11 deletions wespeaker/dataset/dataset.py
```diff
@@ -103,8 +103,9 @@ def sample(self, data):
 
 class DataList(IterableDataset):
 
-    def __init__(self, lists, shuffle=True, partition=True):
+    def __init__(self, lists, shuffle=True, partition=True, repeat_dataset=True):
         self.lists = lists
+        self.repeat_dataset = repeat_dataset
         self.sampler = DistributedSampler(shuffle, partition)
 
     def set_epoch(self, epoch):
@@ -113,14 +114,20 @@ def set_epoch(self, epoch):
     def __iter__(self):
         sampler_info = self.sampler.update()
         indexes = self.sampler.sample(self.lists)
-        indexes_len = len(indexes)
-        counter = 0
-        while True:
-            index = indexes[counter % indexes_len]
-            counter += 1
-            data = dict(src=self.lists[index])
-            data.update(sampler_info)
-            yield data
+        if not self.repeat_dataset:
+            for index in indexes:
+                data = dict(src=self.lists[index])
+                data.update(sampler_info)
+                yield data
+        else:
+            indexes_len = len(indexes)
+            counter = 0
+            while True:
+                index = indexes[counter % indexes_len]
+                counter += 1
+                data = dict(src=self.lists[index])
+                data.update(sampler_info)
+                yield data
 
 
 def Dataset(data_type,
@@ -129,7 +136,8 @@ def Dataset(data_type,
             spk2id_dict,
             whole_utt=False,
             reverb_lmdb_file=None,
-            noise_lmdb_file=None):
+            noise_lmdb_file=None,
+            repeat_dataset=True):
     """ Construct dataset from arguments
 
         We have two shuffle stage in the Dataset. The first is global
@@ -149,7 +157,7 @@ def Dataset(data_type,
     lists = read_lists(data_list_file)
     shuffle = configs.get('shuffle', False)
     # Global shuffle
-    dataset = DataList(lists, shuffle=shuffle)
+    dataset = DataList(lists, shuffle=shuffle, repeat_dataset=repeat_dataset)
     if data_type == 'shard':
         dataset = Processor(dataset, processor.url_opener)
         dataset = Processor(dataset, processor.tar_file_and_group)
```
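For reference, a self-contained sketch of the two iteration modes the new `__iter__` implements (the sampler and distributed bookkeeping are omitted; this is a simplification, not the wespeaker class itself):

```python
# Sketch of DataList's two modes: one bounded pass for evaluation,
# infinite cycling for training.
def iterate(lists, repeat_dataset):
    indexes = list(range(len(lists)))
    if not repeat_dataset:
        # evaluation/extraction: one pass, then the generator is exhausted
        for index in indexes:
            yield dict(src=lists[index])
    else:
        # training: cycle through the (shuffled) list forever
        counter = 0
        while True:
            index = indexes[counter % len(indexes)]
            counter += 1
            yield dict(src=lists[index])

lists = ['shard-000.tar', 'shard-001.tar']
print([d['src'] for d in iterate(lists, repeat_dataset=False)])
# ['shard-000.tar', 'shard-001.tar'] -- exactly one pass, then it stops
```

Note that the default stays `repeat_dataset=True`, so existing training code is unaffected; only the extraction path opts out.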
