Skip to content

Commit

Permalink
Fix for concat map dataset (NVIDIA#5133)
Browse files Browse the repository at this point in the history
* change for concat map dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Exhaust longest dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: 1-800-BAD-CODE <>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>
  • Loading branch information
4 people authored and Jimmy Zhang committed Dec 14, 2022
1 parent 8f1d9de commit d315225
Showing 1 changed file with 63 additions and 66 deletions.
129 changes: 63 additions & 66 deletions nemo/collections/common/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List
from typing import Any, List, Optional, Tuple

import numpy as np
import torch.utils.data as pt_data
Expand Down Expand Up @@ -189,86 +189,83 @@ class ConcatMapDataset(Dataset):
sampling_temperature (int): Temperature value for sampling. Only used when sampling_technique = 'temperature'.
Defaults to 5.
sampling_probabilities (list): Probability values for sampling. Only used when sampling_technique = 'random'.
seed: Optional value to seed the numpy RNG.
"""

def __init__(
self,
datasets: List[Any],
sampling_technique: str = 'temperature',
sampling_temperature: int = 5,
sampling_probabilities: List[float] = None,
consumed_samples: int = 0,
sampling_probabilities: Optional[List[float]] = None,
seed: Optional[int] = None,
):
super().__init__()
self.datasets = datasets
self.sampling_kwargs = {}
self.size = 0
self.lengths = [len(x) for x in self.datasets]
self.sampling_technique = sampling_technique
self.sampling_temperature = sampling_temperature
self.sampling_probabilities = sampling_probabilities
self.consumed_samples = consumed_samples
self.np_rng = np.random.RandomState(consumed_samples)

for dataset in datasets:
self.size += len(dataset)

# Pointer into the next index to fetch from each dataset
self.dataset_index = np.zeros(len(self.datasets), dtype=np.uint8)
self.permuted_dataset_indices = []
for dataset in self.datasets:
permuted_indices = np.arange(len(dataset))
self.np_rng.shuffle(permuted_indices)
self.permuted_dataset_indices.append(permuted_indices)

if self.sampling_technique == 'temperature':
lengths = []
for dataset in datasets:
lengths.append(len(dataset))

p = np.array(lengths) / np.sum(lengths)
p = np.power(p, 1 / self.sampling_temperature)
self.np_rng = np.random.RandomState(seed)

# Build a list of size `len(self)`. Each tuple contains (dataset_id, dataset_index)
self.indices: List[Tuple[int, int]] = []
# Current position as we consume indices from each data set
dataset_positions = [0] * len(self.datasets)
# Random permutation of each dataset. Will be regenerated when exhausted.
shuffled_indices = [self.np_rng.permutation(len(x)) for x in self.datasets]
# Build the list of randomly-chosen datasets spanning the entire length, adhering to sampling technique
if self.sampling_technique == "round-robin":
# To exhaust longest dataset, need to draw `num_datasets * max_dataset_len` samples
total_length = max(self.lengths) * len(self.lengths)
# For round robin, iterate through each dataset
dataset_ids = np.arange(total_length) % len(self.datasets)
for dataset_id in dataset_ids:
position = dataset_positions[dataset_id]
index = shuffled_indices[dataset_id][position]
self.indices.append((dataset_id, index))
dataset_positions[dataset_id] += 1
if dataset_positions[dataset_id] == len(shuffled_indices[dataset_id]):
dataset_positions[dataset_id] = 0
shuffled_indices[dataset_id] = self.np_rng.permutation(len(self.datasets[dataset_id]))
else:
# Resolve probabilities of drawing from each data set
if self.sampling_technique == "random":
if sampling_probabilities is None or len(sampling_probabilities) != len(self.datasets):
raise ValueError(
f"Need {len(self.datasets)} probabilities; got "
f"{len(sampling_probabilities) if sampling_probabilities is not None else 'None'}"
)
p = np.array(self.sampling_probabilities)
elif self.sampling_technique == "temperature":
p = np.array([len(x) for x in self.datasets])
p = np.power(p, 1 / self.sampling_temperature)
else:
raise ValueError(f"Couldn't interpret sampling technique: {sampling_technique}")
# Normalize probabilities
p = p / np.sum(p)
self.p = p

elif self.sampling_technique == 'random':
if not self.sampling_probabilities:
raise ValueError(
"Random generator expects a 'sampling_probabilities' - a list of probability values corresponding to each dataset."
)

if len(self.sampling_probabilities) != len(self.datasets):
raise ValueError(
f"Length of probabilities list must be equal to the number of datasets. Found {len(sampling_probabilities)} probs and {len(self.datasets)} datasets."
)

p = np.array(self.sampling_probabilities)
self.p = p / np.sum(p) # Ensure probabilities sum to 1
# Will randomly choose from datasets
choices = np.arange(len(self.datasets))
# Keep going until largest dataset is exhausted.
exhausted_datasets = set()
while len(exhausted_datasets) < len(self.datasets):
# Randomly choose a dataset for each position in accordance with p
dataset_id = self.np_rng.choice(a=choices, p=p)
dataset = self.datasets[dataset_id]
# Pick next index from dataset
position = dataset_positions[dataset_id]
index = shuffled_indices[dataset_id][position]
self.indices.append((dataset_id, index))
# Maybe reset this dataset's permutation
dataset_positions[dataset_id] += 1
if dataset_positions[dataset_id] >= len(dataset):
shuffled_indices[dataset_id] = self.np_rng.permutation(len(dataset))
dataset_positions[dataset_id] = 0
exhausted_datasets.add(dataset_id)

def __len__(self):
return self.size

def _get_dataset_index(self, idx):
if self.sampling_technique == 'temperature' or self.sampling_technique == 'random':
return self.np_rng.choice(np.arange(len(self.datasets)), p=self.p)
elif self.sampling_technique == 'round-robin':
return idx % len(self.datasets)
return len(self.indices)

def __getitem__(self, idx):
# Get the dataset we want to sample from
dataset_index = self._get_dataset_index(idx)

# Get the index of the sample we want to fetch from the dataset
sample_idx = self.dataset_index[dataset_index]

# If the sample idx > dataset size, reset to 0.
if sample_idx > len(self.datasets[dataset_index]):
sample_idx = 0
self.dataset_index[dataset_index] = 0

# Sample index -> shuffled sample index
shuffled_sample_idx = self.permuted_dataset_indices[dataset_index][sample_idx]

sample = self.datasets[dataset_index][shuffled_sample_idx]
self.dataset_index[dataset_index] += 1

return sample
dataset_id, dataset_index = self.indices[idx]
return self.datasets[dataset_id][dataset_index]

0 comments on commit d315225

Please sign in to comment.