From 83d7ba287a9cb2607728857b86551bd1ef10afe9 Mon Sep 17 00:00:00 2001 From: yzhangcs Date: Mon, 22 Jul 2024 15:40:33 +0800 Subject: [PATCH] Make `BufferShuffledExamplesIterable` resumable --- src/datasets/iterable_dataset.py | 117 ++++++++++++++++++++++++++----- tests/test_iterable_dataset.py | 12 ++-- 2 files changed, 107 insertions(+), 22 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index c23f45570b4..63784438c6e 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1379,19 +1379,20 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat # TODO(QL): implement iter_arrow def _init_state_dict(self) -> dict: - self._state_dict = self.ex_iterable._init_state_dict() - self._original_state_dict = self.state_dict() + self._state_dict = { + "ex_iterable": self.ex_iterable._init_state_dict(), + "num_taken": 0, + "global_example_idx": 0, + "buffer_state_dict": { + "num_taken": 0, + "global_example_idx": 0, + "index_offset": 0, + "bit_generator_state": None, + "first_state": self.ex_iterable.state_dict() + } + } return self._state_dict - def load_state_dict(self, state_dict: dict) -> dict: - if self._state_dict: - if state_dict != self._original_state_dict: - logger.warning( - "Loading a state dict of a shuffle buffer of a dataset without the buffer content." - "The shuffle buffer will be refilled before starting to yield new examples." - ) - return super().load_state_dict(state_dict) - @staticmethod def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batch_size=1000) -> Iterator[int]: while True: @@ -1400,19 +1401,101 @@ def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batc def __iter__(self): buffer_size = self.buffer_size rng = deepcopy(self.generator) - indices_iterator = self._iter_random_indices(rng, buffer_size) + + global_example_idx = 0 + index_offset = 0 + bit_generator_state = rng.bit_generator.state + final_skipped = 0 + buffer_state_dict = { + "num_taken": 0, + "global_example_idx": global_example_idx, + "index_offset": index_offset, + "bit_generator_state": bit_generator_state, + "first_state": None + } + random_batch_size = 1000 + + if self._state_dict: + global_example_idx = self._state_dict["global_example_idx"] + buffer_state_dict = self._state_dict["buffer_state_dict"] + index_offset = buffer_state_dict["index_offset"] + if buffer_state_dict["bit_generator_state"] is not None: + bit_generator_state = buffer_state_dict["bit_generator_state"] + rng.bit_generator.state = bit_generator_state + indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size) # this is the shuffle buffer that we keep in memory - mem_buffer = [] + mem_buffer = {} + + # resume the buffer if necessary + if global_example_idx > 0: + self.ex_iterable.load_state_dict(buffer_state_dict["first_state"]) + global_example_idx_start = buffer_state_dict["global_example_idx"] + final_skipped = self._state_dict["num_taken"] - buffer_state_dict["num_taken"] + num_consumed = global_example_idx - global_example_idx_start + # skip consumed random indices + for _ in range(max(index_offset - 1, 0) % random_batch_size): + i = next(indices_iterator) + for x in islice(self.ex_iterable, num_consumed): + if global_example_idx_start < buffer_size: + i = global_example_idx_start + else: + i = next(indices_iterator) + index_offset += 1 + # pop the existed first to make examples inserted by order + if i in mem_buffer: + mem_buffer.pop(i) + mem_buffer[i] = { + "x": x, + "global_example_idx": global_example_idx_start, + "index_offset": index_offset, + "bit_generator_state": bit_generator_state, + "state_dict": buffer_state_dict["first_state"] + } + global_example_idx_start += 1 + + state_dict = self.ex_iterable.state_dict() if self._state_dict else None for x in self.ex_iterable: if len(mem_buffer) == buffer_size: # if the buffer is full, pick and example from it + if index_offset % random_batch_size == 0: + bit_generator_state = rng.bit_generator.state i = next(indices_iterator) - yield mem_buffer[i] - mem_buffer[i] = x # replace the picked example by a new one + index_offset += 1 + example = mem_buffer.pop(i) + if self._state_dict: + first = mem_buffer[next(iter(mem_buffer))] + self._state_dict["num_taken"] += 1 + self._state_dict["global_example_idx"] = global_example_idx + 1 + # record everything needed (by making a copy) to resume the buffer from the first example + self._state_dict["buffer_state_dict"] = { + "num_taken": self._state_dict["num_taken"], + "global_example_idx": first["global_example_idx"], + "index_offset": first["index_offset"], + "bit_generator_state": first["bit_generator_state"], + "first_state": first["state_dict"] + } + yield example["x"] else: # otherwise, keep filling the buffer - mem_buffer.append(x) + i = len(mem_buffer) + mem_buffer[i] = { + "x": x, + "global_example_idx": global_example_idx, + "index_offset": index_offset, + "bit_generator_state": bit_generator_state, + "state_dict": state_dict + } + if self._state_dict: + state_dict = self.ex_iterable.state_dict() + global_example_idx += 1 + mem_buffer = [mem_buffer[i]["x"] for i in sorted(mem_buffer.keys())] # when we run out of examples, we shuffle the remaining examples in the buffer and yield them rng.shuffle(mem_buffer) - yield from mem_buffer + for i, x in enumerate(mem_buffer): + if i < final_skipped: + continue + if self._state_dict: + self._state_dict["num_taken"] += 1 + self._state_dict["global_example_idx"] = global_example_idx + yield x def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffledExamplesIterable": """Shuffle the wrapped examples iterable as well as the shuffling buffer.""" diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 6d21eda3863..96e24675e28 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -59,7 +59,7 @@ ) -DEFAULT_N_EXAMPLES = 20 +DEFAULT_N_EXAMPLES = 40 DEFAULT_BATCH_SIZE = 4 DEFAULT_FILEPATH = "file.txt" @@ -307,14 +307,17 @@ def gen(tables): @pytest.mark.parametrize("seed", [42, 1337, 101010, 123456]) def test_buffer_shuffled_examples_iterable(seed): - n, buffer_size = 100, 30 + n, buffer_size, random_batch_size = 100, 30, 1000 generator = np.random.default_rng(seed) base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) ex_iterable = BufferShuffledExamplesIterable(base_ex_iterable, buffer_size=buffer_size, generator=generator) rng = deepcopy(generator) expected_indices_used_for_shuffling = list( - islice(BufferShuffledExamplesIterable._iter_random_indices(rng, buffer_size=buffer_size), n - buffer_size) + islice( + BufferShuffledExamplesIterable._iter_random_indices(rng, buffer_size=buffer_size, random_batch_size=random_batch_size), + n - buffer_size + ) ) # indices to pick in the shuffle buffer should all be in the right range assert all(0 <= index_to_pick < buffer_size for index_to_pick in expected_indices_used_for_shuffling) @@ -1188,8 +1191,7 @@ def test_horizontally_concatenated_examples_iterable(): ) def test_no_iter_arrow(ex_iterable: _BaseExamplesIterable): assert ex_iterable.iter_arrow is None - if not isinstance(ex_iterable, BufferShuffledExamplesIterable): - assert_load_state_dict_resumes_iteration(ex_iterable) + assert_load_state_dict_resumes_iteration(ex_iterable) @pytest.mark.parametrize(