
Enable SequentialReadingService to support MP + Distributed #985

Closed
ejguan wants to merge 6 commits

Conversation

Contributor

@ejguan ejguan commented Feb 3, 2023

Fixes #911

Changes

  • Remove distributed code from PrototypeMPRS
  • Fix a bug where blocking_request_get was not sent to the worker process
  • Enable SequentialReadingService to combine both the Distributed and MP ReadingServices
  • Add tests for SequentialReadingService
  • Add tutorial for SequentialReadingService

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 3, 2023
@ejguan ejguan requested review from NivekT and wenleix February 3, 2023 21:26
@facebook-github-bot
Contributor

@ejguan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -92,7 +92,7 @@ def _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, blocking_
     return pipe_type.DataPipeBehindQueues(
         source_datapipe,
         protocol_type(req_queue, res_queue),
-        blocking_request_get=True,
+        blocking_request_get=blocking_request_get,
Contributor Author

This was a bug that would block the dispatching process, since multiple loops run in the same process and we need to make sure the loops do not block each other.
It should also fix the problem for DDP + MPRS with fullsync.
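
For context, a minimal sketch of the non-blocking pattern this fix relies on (illustrative only, not the torchdata implementation; every name below is made up): when several request loops share one process, each loop has to poll its queue with a non-blocking get and back off when idle, otherwise one loop waiting on an empty queue starves the others.

import queue
import time

def serve_round_robin(request_queues, handlers, stop_event):
    # Illustrative sketch: several request loops share this process, so each
    # queue is polled with a non-blocking get; a blocking get() on any single
    # queue would stall every other loop in the same process.
    while not stop_event.is_set():
        idle = True
        for req_queue, handle in zip(request_queues, handlers):
            try:
                request = req_queue.get(block=False)
            except queue.Empty:
                continue  # nothing pending on this queue, check the next one
            idle = False
            handle(request)
        if idle:
            time.sleep(0.001)  # back off briefly when every queue is empty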

Contributor Author

And, to be bulletproof, I have changed both the distributed and non-distributed tests to use non-balanced data shards to guard the dispatching use cases.

Contributor

@wenleix wenleix left a comment

LGTM.

Multiprocessing + Distributed
------------------------------

``SequentialReadingService`` can be used to combine both ``ReadingServices`` together to achieve multiprocessing and distributed training at the same time.
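
A minimal usage sketch of how the two services can be chained (the constructor arguments and the ordering of the two services here are assumptions on my part, not a verbatim copy of the tutorial added in this PR; the class names reflect the torchdata API at the time of this PR):

from torchdata.dataloader2 import (
    DataLoader2,
    DistributedReadingService,
    PrototypeMultiProcessingReadingService,
    SequentialReadingService,
)
from torchdata.datapipes.iter import IterableWrapper

datapipe = IterableWrapper(range(1000)).shuffle().sharding_filter()

dist_rs = DistributedReadingService()
mp_rs = PrototypeMultiProcessingReadingService(num_workers=2)
# Assumed order: distributed sharding is applied first, then
# multiprocessing within each rank.
rs = SequentialReadingService(dist_rs, mp_rs)

dl = DataLoader2(datapipe, reading_service=rs)
for d in dl:
    model(d)  # placeholder, as in the Distributed example above
dl.shutdown()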
Contributor

Do we expect that, in the future, it could also be used in OSS to chain a disaggregated reading service + last-mile "on-trainer" Python transformation? :)

Contributor Author

I guess so, when an AIStoreRS or RayRS is provided.

torchdata/dataloader2/communication/iter.py (outdated review thread, resolved)
@@ -236,13 +223,13 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:

         # Launch dispatching process for the lowest common ancestor of non-replicable DataPipes
         graph = traverse_dps(datapipe)
-        non_replicable_dp = find_lca_round_robin_sharding_dp(graph)
-        if non_replicable_dp is not None:
+        dispatching_dp = find_lca_round_robin_sharding_dp(graph)
Contributor

Thanks. It's now clearer which part of the dp it is ~
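
For readers less familiar with this code path, a small sketch of the kind of graph that triggers it (the functional names below are my illustration based on the public torchdata datapipe APIs around this PR, not part of the diff): everything before a sharding_round_robin_dispatch call is non-replicable and runs in a single dispatching process, while everything after it is replicated across workers.

from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.datapipes.iter import IterableWrapper

# The stages before sharding_round_robin_dispatch run once, in the dispatching
# process; the stages after it are replicated across the worker processes,
# which receive items from the dispatcher in round-robin order.
source_dp = IterableWrapper(range(100)).shuffle()
non_replicable_dp = source_dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
replicable_dp = non_replicable_dp.map(lambda x: x * 2)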


    def initialize(self, datapipe: DataPipe) -> DataPipe:
        r"""
        ``PrototypeMultiProcessingReadingService`` finds information about sharding,
        separates graph by multiple pieces and reconnects it using queues.
        creates subprocesses.
        """
-       if dist.is_available() and dist.is_initialized():
Contributor

Nice :)

pass


def _launch_distributed_training(world_size, *args, fn):
Contributor

I asked ChatGPT what this function does:

The program is a function _launch_distributed_training that launches a distributed training process. The function takes in parameters world_size, *args, and fn. The environment variable MASTER_ADDR is set to TEST_MASTER_ADDR, and the environment variable MASTER_PORT is set to a value returned from a call to the _get_open_port function. The function creates a multiprocessing context using the spawn method, and creates a queue q using the context.

The function then creates world_size processes using the Process method of the context and starts each process. The target of each process is the function fn, and the arguments for each process are rank, world_size, q, and *args. The function stores the created processes in a list ps.

The function then uses a while loop to get data from the queue q and append it to a list res. The loop breaks when a TerminateSignal is received from the queue. After the loop, the function joins all processes in the ps list. Finally, the function returns the res list.

Seems quite correct? ~


What happens if we give the description back to ChatGPT and ask it to write the code~~ "write a function that (quote~)"

Contributor Author

Haha, writing those comments would take me more time than writing code.
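
For fun, here is a sketch reconstructed purely from that description (TEST_MASTER_ADDR, _get_open_port, and TerminateSignal come from the description itself, so their definitions below are assumptions rather than the code in this PR):

import multiprocessing as mp
import os
import socket

TEST_MASTER_ADDR = "127.0.0.1"  # assumed value


class TerminateSignal:
    # Assumed sentinel placed on the queue when the workers are done.
    pass


def _get_open_port() -> str:
    # Assumed helper: ask the OS for a free port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return str(s.getsockname()[1])


def _launch_distributed_training(world_size, *args, fn):
    os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR
    os.environ["MASTER_PORT"] = _get_open_port()
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    ps = []
    # One process per rank, each running fn(rank, world_size, q, *args).
    for rank in range(world_size):
        p = ctx.Process(target=fn, args=(rank, world_size, q, *args))
        p.start()
        ps.append(p)
    # Drain results from the queue until a TerminateSignal arrives.
    res = []
    while True:
        item = q.get()
        if isinstance(item, TerminateSignal):
            break
        res.append(item)
    for p in ps:
        p.join()
    return res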

Contributor

@NivekT NivekT left a comment

LGTM!

torchdata/dataloader2/communication/iter.py (outdated review thread, resolved)
@@ -50,3 +50,21 @@ Distributed
     for d in dl:
         model(d)
     dl.shutdown()

+Multiprocessing + Distributed
Contributor

Non-blocking, but it would be nice to add this to our Colab example as well!

Contributor Author

Sounds reasonable to me.

@facebook-github-bot
Contributor

@ejguan merged this pull request in 89be152.

Labels
CLA Signed, Merged
Development

Successfully merging this pull request may close these issues.

DistributedReadingService supports multi-processing reading
5 participants