Support Parallel Data Loading for Shufflable Iterable Datasets/DataStreams #100
base: main
Conversation
Signed-off-by: Alex-Brooks <[email protected]>
Force-pushed from 73e7af8 to 1392991
Minimal example showing shuffling (run from a venv on Linux):

```python
import torch

from caikit.core.data_model import DataStream
from caikit_nlp.toolkit.data_stream_wrapper import SimpleIterableStreamWrapper

SAMPLE_DATA = [{"label": "a"}, {"label": "b"}, {"label": "c"}, {"label": "d"}]
SAMPLE_STREAM = DataStream.from_iterable(SAMPLE_DATA)

wrapper = SimpleIterableStreamWrapper(stream=SAMPLE_STREAM, shuffle=True)
torch_loader = torch.utils.data.DataLoader(
    wrapper,
    num_workers=2,
    persistent_workers=True,  # Needed, otherwise every iteration shuffles the same way!
)

for epoch in range(3):
    for idx, x in enumerate(torch_loader):
        print(x)
    print("Finished iteration: {}".format(epoch))
```

Sample output:
In some preliminary benchmarking I did, this is unfortunately slower than running with no worker processes, at least for the way we handle tokenizer mapping onto train streams in prompt tuning (on the order of 2-3x slower). While a bit of a bummer, this is a generic utility for datastreams, and it may still be beneficial for re-entrant streams that have hefty iteration costs, like loading from files, etc.
There are some other potential optimizations that can be made around this, but they do break the generality a bit; it might be better to consider getting this in first, and having the optimizations as a follow-up. The two main ones I can think of are:
```python
import caikit

s = caikit.core.data_model.DataStream.from_iterable([1])


def map_func(example):
    # Printed 10 times, since every iteration calls this again upon reentry
    print(f"Called the map func on example: {example}")
    return example + 1


mapped_s = s.map(map_func)
for _ in range(10):
    for x in mapped_s:
        pass
```
This PR adds support for multiple workers when processing an iterable dataset, in such a way that the shuffled order stays consistent across workers and each sample is yielded by exactly one worker per epoch.
The caveat to this is that for the shuffling to work correctly, we need to use `persistent_workers=True` when creating our data loader. This is accomplished by defining a `shuffle_seed`, which is essentially a random seed that gets incremented every time we cycle through our data. This is used as the random seed when creating the shuffled stream generator; the workers must be persistent, otherwise the `shuffle_seed` will get reset with every iteration. This approach lets us shuffle consistently across workers without them communicating. Then, to divide the data, we create an iterator yielding every nth item of the preprocessed stream (which would be shuffled by now) given n workers, with an offset based on the worker ID.
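
To make the mechanics concrete, here is a minimal sketch of the idea using a plain torch `IterableDataset`; the class name and internals are illustrative assumptions, not the actual `SimpleIterableStreamWrapper` implementation:

```python
import random

from torch.utils.data import IterableDataset, get_worker_info


class ShuffledShardedStream(IterableDataset):
    """Illustrative sketch: shuffle-consistent, sharded iteration."""

    def __init__(self, data, shuffle=True):
        self.data = list(data)
        self.shuffle = shuffle
        # Incremented each epoch; only survives across epochs when the
        # DataLoader is created with persistent_workers=True
        self.shuffle_seed = 0

    def __iter__(self):
        order = list(self.data)
        if self.shuffle:
            # Every worker seeds identically, so all workers agree on the
            # shuffled order without communicating
            random.Random(self.shuffle_seed).shuffle(order)
            self.shuffle_seed += 1
        info = get_worker_info()
        num_workers, worker_id = (
            (1, 0) if info is None else (info.num_workers, info.id)
        )
        # Each worker yields every nth item, offset by its worker ID, so
        # each sample goes to exactly one worker per epoch
        yield from order[worker_id::num_workers]
```

Without persistent workers, the loader re-creates each worker's dataset copy every epoch, resetting `shuffle_seed` to 0, which is why every epoch would otherwise shuffle the same way.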
Also adds docstrings to the stream wrapper & caches the stream length, since `len()` is an expensive operation for the data stream.

Closes: #74
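
As a rough illustration of why the caching helps (hypothetical names, not the exact PR code), computing the length once and reusing it avoids re-traversing the stream on every call:

```python
class LenCachingWrapper:
    """Hypothetical sketch of caching an expensive len() computation."""

    def __init__(self, stream):
        self.stream = stream
        self._cached_len = None

    def __len__(self):
        # len() on a data stream traverses it end to end, so compute the
        # length once and reuse the cached value on later calls
        if self._cached_len is None:
            self._cached_len = len(self.stream)
        return self._cached_len
```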