Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Parallel Data Loading Shufflable Iterable Datasets/DataStreams #100

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

alex-jw-brooks
Copy link
Collaborator

@alex-jw-brooks alex-jw-brooks commented Jul 25, 2023

This PR adds support for multiple workers when processing an iterable dataset in such a way that:

  • Ensures that data is evenly split across the workers in a true partition, i.e., no sampling
  • We can still shuffle after every iteration, even if we have multiple workers

The caveat to this is that for the shuffling to work correctly, we need to use persistent_workers=True when creating our data loader.

This is accomplished by defining a shuffle_seed, which is essentially a random seed that gets incremented every time we cycle through our data. This is used as the random seed when creating the shuffled stream generator; the workers must be persistent, otherwise the shuffle_seed will get reset with every iteration, but this approach lets us shuffle consistently across workers without them communicating.

Then, to divide the data, we create an iterator yielding every nth item of the preprocessed stream (which would be shuffled by now) given n worker, with an offset based on the worker ID.

Also adds docstrings to the stream wrapper & caches the stream length, since len() is an expensive operation for the data stream.

Closes: #74

Signed-off-by: Alex-Brooks <[email protected]>
@alex-jw-brooks
Copy link
Collaborator Author

alex-jw-brooks commented Jul 25, 2023

Minimal example showing shuffling (run from a venv on Linux):

import torch
from caikit.core.data_model import DataStream
from caikit_nlp.toolkit.data_stream_wrapper import SimpleIterableStreamWrapper

SAMPLE_DATA = [{"label": "a"}, {"label": "b"}, {"label": "c"}, {"label": "d"}]
SAMPLE_STREAM = DataStream.from_iterable(SAMPLE_DATA)
wrapper = SimpleIterableStreamWrapper(stream=SAMPLE_STREAM, shuffle=True)

torch_loader = torch.utils.data.DataLoader(
    wrapper,
    num_workers=2,
    persistent_workers=True, # Needed, otherwise every iteration shuffles the same way!
)

for epoch in range(3):
    for idx, x in enumerate(torch_loader):
        print(x)
    print("Finished iteration: {}".format(epoch))

Sample output:

{'label': ['c']}
{'label': ['b']}
{'label': ['d']}
{'label': ['a']}
Finished iteration: 0
{'label': ['c']}
{'label': ['d']}
{'label': ['b']}
{'label': ['a']}
Finished iteration: 1
{'label': ['b']}
{'label': ['a']}
{'label': ['c']}
{'label': ['d']}
Finished iteration: 2

In some preliminary benchmarking I did, this is unfortunately slower than running with no worker processes, at least for the way we handle tokenizer mapping onto train streams in prompt tuning (on the order of 2-3x slower). While a bit of a bummer, this is a generic utility for datastreams, and may be beneficial for re-entrant streams that have heft iteration costs, like loading from files etc

@alex-jw-brooks
Copy link
Collaborator Author

alex-jw-brooks commented Jul 26, 2023

There are some other potential optimizations that can be made around this, but they do break the genericism a bit; it might be better to consider getting this in first, and having the optimizations as a follow up.

The two main ones I can think of are:

  • Tokenizer function mapping. Mapping over the data stream this way effectively builds an on the fly tokenizer that retokenizes every time because of the way reentry on the iterator works. I.e., same situation as the sample code below
import caikit
s = caikit.core.data_model.DataStream.from_iterable([1])

def map_func(example):
    print(f"Called the map func on example: {example}") # printed 10 times since every iteration calls this again upon reentry
    return example + 1

mapped_s = s.map(map_func)
for _ in range(10):
    for x in mapped_s:
        pass
  • skipping tokenization of unyielded samples while tokenizing; that ^ is much more of an issue with the approach we have here, which effectively divides the iterator across n processes, because we're wasting time tokenizing things that aren't even being yielded because they're supposed to be yielded by other processes. Since we know what things to yield, it's probably a good idea to actually hold the func to be mapped in the stream wrapper, and only apply it when we're iterating on things being yielded, assuming we want to do on the fly tokenization. I.e., if we have 4 processes, we only apply the mapped func on each 4th sample. This one could also be added to this PR since it's more isolated from the models using it - I don't have a strong preference either way

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Change SimpleIterableStreamWrapper to work with multiple workers allowing shuffling
1 participant