diff --git a/nemo/collections/common/data/dataset.py b/nemo/collections/common/data/dataset.py
index da549aa1a2db..283d34e07b1f 100644
--- a/nemo/collections/common/data/dataset.py
+++ b/nemo/collections/common/data/dataset.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List
+from typing import Any, List, Optional, Tuple
 
 import numpy as np
 import torch.utils.data as pt_data
@@ -189,6 +189,7 @@ class ConcatMapDataset(Dataset):
         sampling_temperature (int): Temperature value for sampling. Only used when sampling_technique = 'temperature'.
             Defaults to 5.
         sampling_probabilities (list): Probability values for sampling. Only used when sampling_technique = 'random'.
+        seed: Optional value to seed the numpy RNG.
     """
 
     def __init__(
@@ -196,79 +197,75 @@ def __init__(
         datasets: List[Any],
         sampling_technique: str = 'temperature',
         sampling_temperature: int = 5,
-        sampling_probabilities: List[float] = None,
-        consumed_samples: int = 0,
+        sampling_probabilities: Optional[List[float]] = None,
+        seed: Optional[int] = None,
     ):
         super().__init__()
         self.datasets = datasets
-        self.sampling_kwargs = {}
-        self.size = 0
+        self.lengths = [len(x) for x in self.datasets]
         self.sampling_technique = sampling_technique
         self.sampling_temperature = sampling_temperature
         self.sampling_probabilities = sampling_probabilities
-        self.consumed_samples = consumed_samples
-        self.np_rng = np.random.RandomState(consumed_samples)
-
-        for dataset in datasets:
-            self.size += len(dataset)
-
-        # Pointer into the next index to fetch from each dataset
-        self.dataset_index = np.zeros(len(self.datasets), dtype=np.uint8)
-        self.permuted_dataset_indices = []
-        for dataset in self.datasets:
-            permuted_indices = np.arange(len(dataset))
-            self.np_rng.shuffle(permuted_indices)
-            self.permuted_dataset_indices.append(permuted_indices)
-
-        if self.sampling_technique == 'temperature':
-            lengths = []
-            for dataset in datasets:
-                lengths.append(len(dataset))
-
-            p = np.array(lengths) / np.sum(lengths)
-            p = np.power(p, 1 / self.sampling_temperature)
+        self.np_rng = np.random.RandomState(seed)
+
+        # Build a list of size `len(self)`. Each tuple contains (dataset_id, dataset_index)
+        self.indices: List[Tuple[int, int]] = []
+        # Current position as we consume indices from each dataset
+        dataset_positions = [0] * len(self.datasets)
+        # Random permutation of each dataset. Will be regenerated when exhausted.
+        shuffled_indices = [self.np_rng.permutation(len(x)) for x in self.datasets]
+        # Build the list of randomly-chosen datasets spanning the entire length, adhering to sampling technique
+        if self.sampling_technique == "round-robin":
+            # To exhaust the longest dataset, we need to draw `num_datasets * max_dataset_len` samples
+            total_length = max(self.lengths) * len(self.lengths)
+            # For round-robin, iterate through each dataset in turn
+            dataset_ids = np.arange(total_length) % len(self.datasets)
+            for dataset_id in dataset_ids:
+                position = dataset_positions[dataset_id]
+                index = shuffled_indices[dataset_id][position]
+                self.indices.append((dataset_id, index))
+                dataset_positions[dataset_id] += 1
+                if dataset_positions[dataset_id] == len(shuffled_indices[dataset_id]):
+                    dataset_positions[dataset_id] = 0
+                    shuffled_indices[dataset_id] = self.np_rng.permutation(len(self.datasets[dataset_id]))
+        else:
+            # Resolve probabilities of drawing from each dataset
+            if self.sampling_technique == "random":
+                if sampling_probabilities is None or len(sampling_probabilities) != len(self.datasets):
+                    raise ValueError(
+                        f"Need {len(self.datasets)} probabilities; got "
+                        f"{len(sampling_probabilities) if sampling_probabilities is not None else 'None'}"
+                    )
+                p = np.array(self.sampling_probabilities)
+            elif self.sampling_technique == "temperature":
+                p = np.array([len(x) for x in self.datasets])
+                p = np.power(p, 1 / self.sampling_temperature)
+            else:
+                raise ValueError(f"Couldn't interpret sampling technique: {sampling_technique}")
+            # Normalize probabilities
             p = p / np.sum(p)
-            self.p = p
-
-        elif self.sampling_technique == 'random':
-            if not self.sampling_probabilities:
-                raise ValueError(
-                    "Random generator expects a 'sampling_probabilities' - a list of probability values corresponding to each dataset."
-                )
-
-            if len(self.sampling_probabilities) != len(self.datasets):
-                raise ValueError(
-                    f"Length of probabilities list must be equal to the number of datasets. Found {len(sampling_probabilities)} probs and {len(self.datasets)} datasets."
-                )
-
-            p = np.array(self.sampling_probabilities)
-            self.p = p / np.sum(p)  # Ensure probabilities sum to 1
+            # Will randomly choose from datasets
+            choices = np.arange(len(self.datasets))
+            # Keep drawing until every dataset has been exhausted at least once
+            exhausted_datasets = set()
+            while len(exhausted_datasets) < len(self.datasets):
+                # Randomly choose a dataset for each position in accordance with p
+                dataset_id = self.np_rng.choice(a=choices, p=p)
+                dataset = self.datasets[dataset_id]
+                # Pick the next index from this dataset
+                position = dataset_positions[dataset_id]
+                index = shuffled_indices[dataset_id][position]
+                self.indices.append((dataset_id, index))
+                # Maybe reset this dataset's permutation
+                dataset_positions[dataset_id] += 1
+                if dataset_positions[dataset_id] >= len(dataset):
+                    shuffled_indices[dataset_id] = self.np_rng.permutation(len(dataset))
+                    dataset_positions[dataset_id] = 0
+                    exhausted_datasets.add(dataset_id)
 
     def __len__(self):
-        return self.size
-
-    def _get_dataset_index(self, idx):
-        if self.sampling_technique == 'temperature' or self.sampling_technique == 'random':
-            return self.np_rng.choice(np.arange(len(self.datasets)), p=self.p)
-        elif self.sampling_technique == 'round-robin':
-            return idx % len(self.datasets)
+        return len(self.indices)
 
     def __getitem__(self, idx):
-        # Get the dataset we want to sample from
-        dataset_index = self._get_dataset_index(idx)
-
-        # Get the index of the sample we want to fetch from the dataset
-        sample_idx = self.dataset_index[dataset_index]
-
-        # If the sample idx > dataset size, reset to 0.
-        if sample_idx > len(self.datasets[dataset_index]):
-            sample_idx = 0
-            self.dataset_index[dataset_index] = 0
-
-        # Sample index -> shuffled sample index
-        shuffled_sample_idx = self.permuted_dataset_indices[dataset_index][sample_idx]
-
-        sample = self.datasets[dataset_index][shuffled_sample_idx]
-        self.dataset_index[dataset_index] += 1
-
-        return sample
+        dataset_id, dataset_index = self.indices[idx]
+        return self.datasets[dataset_id][dataset_index]
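A minimal usage sketch of the rewritten class (the `RangeDataset` toy wrapper and all argument values are illustrative assumptions; only the `ConcatMapDataset` signature comes from the diff above):

```python
import torch.utils.data as pt_data

from nemo.collections.common.data.dataset import ConcatMapDataset


# Hypothetical toy map-style dataset, used only to exercise the API.
class RangeDataset(pt_data.Dataset):
    def __init__(self, n: int):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx


# The full (dataset_id, dataset_index) plan is now built once in __init__,
# so two instances constructed with the same seed yield identical epochs.
dataset = ConcatMapDataset(
    datasets=[RangeDataset(100), RangeDataset(10)],
    sampling_technique="temperature",
    sampling_temperature=5,
    seed=42,
)
print(len(dataset))  # number of draws made until every dataset was exhausted
sample = dataset[0]  # plain indexed lookup; __getitem__ consumes no RNG state
```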
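For reviewers, a quick sketch of what the temperature branch computes (the sizes 100 and 10 are made-up numbers): raising the dataset lengths to the power 1/T before normalizing flattens the sampling distribution toward uniform as T grows, and T=1 recovers pure size-proportional sampling.

```python
import numpy as np

lengths = np.array([100, 10])  # assumed toy dataset sizes

for T in (1, 5):
    # Same computation as the 'temperature' branch in __init__
    p = np.power(lengths, 1 / T)
    p = p / np.sum(p)
    print(T, p)
# T=1 -> [0.909 0.091]  (size-proportional)
# T=5 -> [0.613 0.387]  (flattened toward uniform)
```

Note that normalizing before or after taking the power only differs by a constant factor that the final normalization removes, so the new power-then-normalize order yields the same `p` as the old normalize-power-normalize sequence.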
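And the round-robin branch in isolation (toy sizes assumed): with `total_length = max(lengths) * num_datasets`, the modulo walk visits datasets in strict rotation, so the longest dataset is exhausted exactly once per epoch while shorter ones are reshuffled and repeated.

```python
import numpy as np

lengths = [3, 2]  # assumed toy dataset sizes
total_length = max(lengths) * len(lengths)  # 6 draws
dataset_ids = np.arange(total_length) % len(lengths)
print(dataset_ids)  # [0 1 0 1 0 1]: strict rotation across datasets
```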