Make BufferShuffledExamplesIterable resumable #7056

Open · wants to merge 2 commits into base: main
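For context: on main, loading a checkpoint into a shuffle buffer only logs a warning and refills the buffer from scratch, so iteration does not resume where it left off. This PR records enough state to rebuild the buffer and resume exactly. A minimal sketch of the intended behavior, using the existing `state_dict()`/`load_state_dict()` API of `IterableDataset` (the toy dataset and numbers are illustrative):

```python
from datasets import Dataset

# Build a streaming dataset and shuffle it with a small buffer.
ds = Dataset.from_dict({"id": list(range(100))}).to_iterable_dataset()
ds = ds.shuffle(seed=42, buffer_size=30)

# Consume part of the stream, then checkpoint mid-iteration.
it = iter(ds)
first_half = [next(it)["id"] for _ in range(50)]
checkpoint = ds.state_dict()

# Restore into a fresh copy of the pipeline: with this change the iterator
# resumes where it stopped instead of refilling the buffer and losing examples.
resumed = Dataset.from_dict({"id": list(range(100))}).to_iterable_dataset()
resumed = resumed.shuffle(seed=42, buffer_size=30)
resumed.load_state_dict(checkpoint)
second_half = [ex["id"] for ex in resumed]

assert sorted(first_half + second_half) == list(range(100))
```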
117 changes: 100 additions & 17 deletions src/datasets/iterable_dataset.py
@@ -1379,19 +1379,20 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generator: np.random.Generator):
         # TODO(QL): implement iter_arrow

     def _init_state_dict(self) -> dict:
-        self._state_dict = self.ex_iterable._init_state_dict()
-        self._original_state_dict = self.state_dict()
+        self._state_dict = {
+            "ex_iterable": self.ex_iterable._init_state_dict(),
+            "num_taken": 0,
+            "global_example_idx": 0,
+            "buffer_state_dict": {
+                "num_taken": 0,
+                "global_example_idx": 0,
+                "index_offset": 0,
+                "bit_generator_state": None,
+                "first_state": self.ex_iterable.state_dict()
+            }
+        }
         return self._state_dict

-    def load_state_dict(self, state_dict: dict) -> dict:
-        if self._state_dict:
-            if state_dict != self._original_state_dict:
-                logger.warning(
-                    "Loading a state dict of a shuffle buffer of a dataset without the buffer content."
-                    "The shuffle buffer will be refilled before starting to yield new examples."
-                )
-        return super().load_state_dict(state_dict)
-
     @staticmethod
     def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batch_size=1000) -> Iterator[int]:
         while True:
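To spell out the new checkpoint layout: `num_taken` counts examples already yielded from the buffer, `global_example_idx` counts examples consumed from the wrapped iterable, and `buffer_state_dict` snapshots everything needed to replay the stream from the oldest example still sitting in the buffer. An illustrative checkpoint (values made up, nested state dicts elided with `...`):

```python
# Illustrative checkpoint for buffer_size=30, after 87 examples were
# consumed from the wrapped iterable and 57 yielded from the buffer:
state = {
    "ex_iterable": ...,            # live state of the wrapped iterable
    "num_taken": 57,               # examples yielded from the shuffle buffer so far
    "global_example_idx": 87,      # examples consumed from the wrapped iterable
    "buffer_state_dict": {
        "num_taken": 57,           # num_taken when this snapshot was last refreshed
        "global_example_idx": 57,  # position of the oldest example still buffered
        "index_offset": 28,        # random indices drawn when it was buffered
        "bit_generator_state": ...,  # RNG state at the last random-batch boundary
        "first_state": ...,        # wrapped-iterable state just before that example
    },
}
```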
@@ -1400,19 +1401,101 @@ def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batch_size=1000) -> Iterator[int]:
     def __iter__(self):
         buffer_size = self.buffer_size
         rng = deepcopy(self.generator)
-        indices_iterator = self._iter_random_indices(rng, buffer_size)
+
+        global_example_idx = 0
+        index_offset = 0
+        bit_generator_state = rng.bit_generator.state
+        final_skipped = 0
+        buffer_state_dict = {
+            "num_taken": 0,
+            "global_example_idx": global_example_idx,
+            "index_offset": index_offset,
+            "bit_generator_state": bit_generator_state,
+            "first_state": None
+        }
+        random_batch_size = 1000
+
+        if self._state_dict:
+            global_example_idx = self._state_dict["global_example_idx"]
+            buffer_state_dict = self._state_dict["buffer_state_dict"]
+            index_offset = buffer_state_dict["index_offset"]
+            if buffer_state_dict["bit_generator_state"] is not None:
+                bit_generator_state = buffer_state_dict["bit_generator_state"]
+                rng.bit_generator.state = bit_generator_state
+        indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size)
         # this is the shuffle buffer that we keep in memory
-        mem_buffer = []
+        mem_buffer = {}
+
+        # resume the buffer if necessary
+        if global_example_idx > 0:
+            self.ex_iterable.load_state_dict(buffer_state_dict["first_state"])
+            global_example_idx_start = buffer_state_dict["global_example_idx"]
+            final_skipped = self._state_dict["num_taken"] - buffer_state_dict["num_taken"]
+            num_consumed = global_example_idx - global_example_idx_start
+            # skip consumed random indices
+            for _ in range(max(index_offset - 1, 0) % random_batch_size):
+                i = next(indices_iterator)
+            for x in islice(self.ex_iterable, num_consumed):
+                if global_example_idx_start < buffer_size:
+                    i = global_example_idx_start
+                else:
+                    i = next(indices_iterator)
+                    index_offset += 1
+                # pop the existing entry first so that examples stay in insertion order
+                if i in mem_buffer:
+                    mem_buffer.pop(i)
+                mem_buffer[i] = {
+                    "x": x,
+                    "global_example_idx": global_example_idx_start,
+                    "index_offset": index_offset,
+                    "bit_generator_state": bit_generator_state,
+                    "state_dict": buffer_state_dict["first_state"]
+                }
+                global_example_idx_start += 1
+
+        state_dict = self.ex_iterable.state_dict() if self._state_dict else None
         for x in self.ex_iterable:
             if len(mem_buffer) == buffer_size:  # if the buffer is full, pick an example from it
+                if index_offset % random_batch_size == 0:
+                    bit_generator_state = rng.bit_generator.state
                 i = next(indices_iterator)
-                yield mem_buffer[i]
-                mem_buffer[i] = x  # replace the picked example by a new one
+                index_offset += 1
+                example = mem_buffer.pop(i)
+                if self._state_dict:
+                    first = mem_buffer[next(iter(mem_buffer))]
+                    self._state_dict["num_taken"] += 1
+                    self._state_dict["global_example_idx"] = global_example_idx + 1
+                    # record everything needed (by making a copy) to resume the buffer from the first example
+                    self._state_dict["buffer_state_dict"] = {
+                        "num_taken": self._state_dict["num_taken"],
+                        "global_example_idx": first["global_example_idx"],
+                        "index_offset": first["index_offset"],
+                        "bit_generator_state": first["bit_generator_state"],
+                        "first_state": first["state_dict"]
+                    }
+                yield example["x"]
             else:  # otherwise, keep filling the buffer
-                mem_buffer.append(x)
+                i = len(mem_buffer)
+            mem_buffer[i] = {
+                "x": x,
+                "global_example_idx": global_example_idx,
+                "index_offset": index_offset,
+                "bit_generator_state": bit_generator_state,
+                "state_dict": state_dict
+            }
+            if self._state_dict:
+                state_dict = self.ex_iterable.state_dict()
+            global_example_idx += 1
+        mem_buffer = [mem_buffer[i]["x"] for i in sorted(mem_buffer.keys())]
         # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
         rng.shuffle(mem_buffer)
-        yield from mem_buffer
+        for i, x in enumerate(mem_buffer):
+            if i < final_skipped:
+                continue
+            if self._state_dict:
+                self._state_dict["num_taken"] += 1
+                self._state_dict["global_example_idx"] = global_example_idx
+            yield x

     def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffledExamplesIterable":
         """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
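The heart of the change is that `mem_buffer` becomes an insertion-ordered dict whose entries each carry the RNG and iterable state in effect when they were buffered. A checkpoint then only needs the snapshot attached to the oldest entry; resuming replays the wrapped iterable from that point to rebuild the buffer deterministically. Below is a self-contained sketch of the same replay trick, with the batched RNG bookkeeping simplified to a plain draw counter (all names are illustrative, not `datasets` API):

```python
import numpy as np

def resumable_shuffle(read, n, buffer_size, seed, state=None):
    """Buffer-shuffle the restartable source read(i), i in range(n).

    Yields (example, checkpoint) pairs. A checkpoint (first_idx, first_drawn,
    consumed) stores the source position and RNG draw count of the oldest
    buffered example, plus how many examples were consumed in total: replaying
    the source slice [first_idx, consumed) rebuilds the buffer exactly.
    """
    rng = np.random.default_rng(seed)
    drawn = 0

    def draw():  # one random buffer slot, counting how many we have drawn
        nonlocal drawn
        drawn += 1
        return int(rng.integers(0, buffer_size))

    buf = {}  # slot -> (example, source_idx, draws_before_insertion)
    consumed = 0
    if state is not None:
        first_idx, first_drawn, consumed = state
        for _ in range(first_drawn):  # fast-forward the RNG
            draw()
        for idx in range(first_idx, consumed):  # replay the buffered span
            slot = idx if idx < buffer_size else draw()
            buf.pop(slot, None)  # pop first so dict order stays oldest-first
            buf[slot] = (read(idx), idx, drawn - (1 if idx >= buffer_size else 0))
    while consumed < n:
        x = read(consumed)
        if len(buf) == buffer_size:  # buffer full: evict one entry at random
            slot = draw()
            example = buf.pop(slot)[0]
            oldest = buf[next(iter(buf))]  # dict order == insertion order
            ckpt = (oldest[1], oldest[2], consumed + 1)
            buf[slot] = (x, consumed, drawn - 1)
            yield example, ckpt
        else:  # still filling the buffer
            buf[len(buf)] = (x, consumed, drawn)
        consumed += 1
    tail = [buf[s][0] for s in sorted(buf)]
    rng.shuffle(tail)
    for x in tail:  # tail checkpoints are omitted here for brevity
        yield x, None

# Resuming mid-stream reproduces exactly what the uninterrupted run yields:
read = lambda i: i
full = [x for x, _ in resumable_shuffle(read, 100, 16, seed=0)]
gen = resumable_shuffle(read, 100, 16, seed=0)
head = [next(gen) for _ in range(40)]
rest = [x for x, _ in resumable_shuffle(read, 100, 16, seed=0, state=head[-1][1])]
assert [x for x, _ in head] + rest == full
```

Keeping the buffer ordered by insertion is what makes `next(iter(mem_buffer))` the oldest entry, so a checkpoint never has to serialize the buffer contents themselves, only the position the buffer can be replayed from.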
12 changes: 7 additions & 5 deletions tests/test_iterable_dataset.py
@@ -59,7 +59,7 @@
 )


-DEFAULT_N_EXAMPLES = 20
+DEFAULT_N_EXAMPLES = 40
 DEFAULT_BATCH_SIZE = 4
 DEFAULT_FILEPATH = "file.txt"

@@ -307,14 +307,17 @@ def gen(tables):

 @pytest.mark.parametrize("seed", [42, 1337, 101010, 123456])
 def test_buffer_shuffled_examples_iterable(seed):
-    n, buffer_size = 100, 30
+    n, buffer_size, random_batch_size = 100, 30, 1000
     generator = np.random.default_rng(seed)
     base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
     ex_iterable = BufferShuffledExamplesIterable(base_ex_iterable, buffer_size=buffer_size, generator=generator)

     rng = deepcopy(generator)
     expected_indices_used_for_shuffling = list(
-        islice(BufferShuffledExamplesIterable._iter_random_indices(rng, buffer_size=buffer_size), n - buffer_size)
+        islice(
+            BufferShuffledExamplesIterable._iter_random_indices(rng, buffer_size=buffer_size, random_batch_size=random_batch_size),
+            n - buffer_size
+        )
     )
     # indices to pick in the shuffle buffer should all be in the right range
     assert all(0 <= index_to_pick < buffer_size for index_to_pick in expected_indices_used_for_shuffling)
@@ -1188,8 +1191,7 @@ def test_horizontally_concatenated_examples_iterable():
 )
 def test_no_iter_arrow(ex_iterable: _BaseExamplesIterable):
     assert ex_iterable.iter_arrow is None
-    if not isinstance(ex_iterable, BufferShuffledExamplesIterable):
-        assert_load_state_dict_resumes_iteration(ex_iterable)
+    assert_load_state_dict_resumes_iteration(ex_iterable)


 @pytest.mark.parametrize(
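With the buffer resumable, the generic resumption check can now cover `BufferShuffledExamplesIterable` too, which is what the last hunk enables. In spirit, `assert_load_state_dict_resumes_iteration` verifies something like the following (a paraphrase of the helper, not its exact code):

```python
from copy import deepcopy

def assert_resumes_like(ex_iterable):
    # First pass: record the yielded examples and a checkpoint after each one.
    ex_iterable._init_state_dict()
    examples, checkpoints = [], []
    for example in ex_iterable:
        examples.append(example)
        checkpoints.append(deepcopy(ex_iterable.state_dict()))
    # Restarting from checkpoint i must reproduce the remaining examples.
    for i, state in enumerate(checkpoints):
        ex_iterable.load_state_dict(state)
        assert list(ex_iterable) == examples[i + 1 :]
```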