
[ReadingService] Add round robin sharding to support non-replicable DataPipe for Multiprocessing #919

Closed
wants to merge 18 commits

Conversation

ejguan
Contributor

@ejguan ejguan commented Dec 9, 2022

This PR is created on top of #555 and extends PrototypeMultiprocessingReadingService to accept non-replicable DataPipes.

This PR also depends on pytorch/pytorch#90769

Main Changes

  • Add a way to launch a process to fetch data from non-replicable DataPipes and send data to worker processes in a round-robin manner
    • Add ShardingRoundRobinDispatcher (functional name sharding_round_robin_dispatch) to indicate a non-replicable DataPipe
    • Add MultipleDataPipesToQueuesLoop to connect the non-sharding process to request/response queues
    • Add find_lca_non_replicable_dp as a graph function to determine the lowest common ancestor of all non-replicable DataPipes. This guarantees that all non-replicable DataPipes run in a single dispatching process
    • In each multiprocessing worker process,
      • If all DataPipes are replicable, apply multiprocessing sharding to the whole graph
      • If not, the worker uses find_replicable_branches to apply mp sharding only to the replicable branches, because the non-replicable branches are already properly sharded by routing data round-robin to the worker processes
  • Properly get ResetEpochResponse from the protocol via get_response_reset_epoch
  • Add tests for the two graph functions
  • Add a test that launches the non-shardable DataPipe process
  • Add documentation

Nit changes

  • Rename Spawn to Create, as the process has not yet been started
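
For context, a minimal usage sketch of the new API (hedged: IterableWrapper and SHARDING_PRIORITIES are existing torchdata/PyTorch utilities, and the functional form is assumed to take the sharding group as its argument, as introduced in this PR):

from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.datapipes.iter import IterableWrapper

# A source that must live in a single process, e.g. a remote system that
# only accepts one client per machine.
source_dp = IterableWrapper(range(100)).shuffle()

# Mark the branch as non-replicable: instead of copying these DataPipes into
# every worker process, data is dispatched round-robin from a single process.
dp = source_dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)

# Everything after the dispatcher is replicable and runs in each worker.
dp = dp.map(lambda x: x + 1)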

@@ -225,6 +225,16 @@ def get_response_reset_iterator(self, block=False):
    if not isinstance(response, communication.messages.ResetIteratorResponse):
        raise Exception("Invalid response received")

def get_response_reset_epoch(self, block=False):
Contributor Author

Not sure why we didn't have this response before

Contributor

Non-blocking, but - hmmm.... did this cause any bug or unhandled messages? Do you happen to know why?
I know you were looking into unusual messages and responses

Contributor Author

When a new iteration starts, reset will be called for _IterateQueueDataPipes. And an extra get_response_next is invoked to drop the pending response, so all requests are served.

for dp in self.datapipes:
    if dp.protocol.waiting_for_response():
        dp.protocol.get_response_next(block=True)

I can try removing this part.

Contributor Author

It seems removing it is fine

Contributor

noob question: are there any docs / entry points to understand how the dataloader2/communication pieces are designed and how they work? ~

Contributor Author
@ejguan ejguan Dec 19, 2022

Unfortunately no doc. I can describe the components based on my understanding:

  • ProtocolClient remains in the main process to pass a Request via the request_queue to the corresponding worker process
  • ProtocolServer is created in the worker process; it takes a Request and then sends a Response back to the main process via the response_queue
  • DataPipeBehindQueues is the worker loop that holds a ProtocolServer to manipulate the DataPipe based on the Request
  • QueueWrapper is the DataPipe that holds a ProtocolClient instance to issue Requests and yield data from the response_queue to the subsequent DataPipe graph.

We can discuss in more detail offline if you want
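
For intuition, a stripped-down toy version of that request/response flow (illustrative only; the real classes exchange typed message objects and support non-blocking semantics):

import multiprocessing as mp

def server_loop(datapipe, req_queue, res_queue):
    # Plays the role of ProtocolServer + DataPipeBehindQueues in a worker.
    it = iter(datapipe)
    while True:
        request = req_queue.get()
        if request == "GetNext":
            try:
                res_queue.put(("Data", next(it)))
            except StopIteration:
                res_queue.put(("StopIteration", None))
        elif request == "Terminate":
            res_queue.put(("Terminated", None))
            break

if __name__ == "__main__":
    req_q, res_q = mp.Queue(), mp.Queue()
    worker = mp.Process(target=server_loop, args=(range(3), req_q, res_q))
    worker.start()
    # Plays the role of ProtocolClient + QueueWrapper in the main process.
    while True:
        req_q.put("GetNext")
        kind, value = res_q.get()
        if kind == "StopIteration":
            break
        print(value)  # 0, 1, 2
    req_q.put("Terminate")
    res_q.get()  # ("Terminated", None)
    worker.join()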

Contributor

Thanks! Something like this? ~

[attached hand-drawn diagram of the components]

Contributor Author

Correct.

Contributor
@wenleix wenleix Dec 21, 2022

Let me try to draw it with mermaid

graph TD;
    Worker_1_1-->ProtocolServer_1_1-->ProtocolClient_1
    Worker_1_2-->ProtocolServer_1_2-->ProtocolClient_1
    Worker_1_3-->ProtocolServer_1_3-->ProtocolClient_1
    ProtocolClient_1-->GPU1
    ProtocolClient_2-->GPU2
    ProtocolClient_3-->GPU3

@facebook-github-bot
Contributor

@ejguan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Contributor
@NivekT NivekT left a comment

I am going to need more time to look through this. What is the use case of a non-shardable DataPipe?

Contributor
@NivekT NivekT left a comment

I think we should provide a definition of a non-shardable DataPipe and why one may occur (an abstract example would be helpful as well). In particular, as you mentioned, note that we do not want to duplicate such a DataPipe into multiple workers, and the sharding filter should not be applied to it. Instead, it should be read round-robin (or something else) by the downstream DataPipes that are sharded?

@ejguan
Contributor Author

ejguan commented Dec 14, 2022

I think we should provide a definition of a non-shardable DataPipe and why one may occur (an abstract example would be helpful as well).

Makes sense. I will add documentation about non-shardable vs. shardable DataPipes to the dataloader2 documents.

In particular, as you mentioned, note that we do not want to duplicate such a DataPipe into multiple workers, and the sharding filter should not be applied to it. Instead, it should be read round-robin (or something else) by the downstream DataPipes that are sharded?

Actually, this is not true. sharding_filter is still needed if distributed sharding is required. We only disabled multiprocessing sharding for the non-shardable branch. And round-robin reading is achieved automatically with ProtoTypeMPRS.

Edit: Updated the summary with a few topics that need to be covered. Let me know if there is any other concern about the documentation. (See the sketch below for how the two kinds of sharding compose.)
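
To make the distinction concrete, a hedged sketch of a pipeline that participates in both kinds of sharding (illustrative only; the exact placement of sharding_filter relative to the dispatcher depends on the pipeline):

from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(1000))
# Distributed sharding across ranks still relies on sharding_filter.
dp = dp.sharding_filter()
# Multiprocessing sharding of this branch is replaced by round-robin dispatch.
dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)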


@@ -84,9 +153,22 @@ def process_reset_fn(
    reset the random state of the ``DataPipe`` graph and the global random states for ``torch``,
    ``random`` and ``numpy``.
    """
    # Reset non-sharding process first
    graph = traverse_dps(datapipe)
    non_sharding_process_dps = find_dps(graph, communication.iter._IterateQueueDataPipes)
Contributor

Question: Is it always the case that _IterateQueueDataPipes is the only non-sharding process?

Contributor Author

Yes, because we find the lowest common ancestor of all non-shardable DataPipes in the main process and replace it with this _IterateQueueDataPipes in the worker process.
So it's guaranteed that there is only a single non-sharding process.

@ejguan
Contributor Author

ejguan commented Dec 15, 2022

@NivekT
Thanks for your question about examples of non-shardable DataPipes. There are two kinds of non-shardable DataPipes, and we have to treat them differently:

  1. A non-shardable data source that connects to a remote system that only accepts one client per machine.
    For such a data source, we need to launch the non-sharding process to transfer data from the source to the worker processes.
  2. A fullsync DataPipe that should not run in a worker process.
    This DataPipe is normally appended at the end of the pipeline. If we used the lowest common ancestor of all non-shardable DataPipes as the non-shardable partial graph, the whole pipeline would be sent to the non-sharding process. We have to make sure fullsync is attached to the output of the worker processes in the main process.

I propose changing this to find the lowest common ancestor of the non-replicable data sources only and send that partial graph to the non-sharding process.
Simple example:

graph TD;
DP1(non-replicable DP1)-->DP2;
DP2-->DP5;
DP3(non-replicable DP3)-->DP4;
DP4-->DP5;
DP5-->DP6;
DP6-->fullsync;
fullsync-->output;

The lowest common ancestor of all non-shardable data sources is DP5. If we included fullsync, the whole pipeline would become non-shardable. Given the above example, the non-sharding process should contain the two branches up to DP5, the worker processes should execute only DP6, and fullsync is attached at the end of the pipeline in the main process. (A small sketch of the LCA computation follows.)
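
A minimal sketch of the LCA computation on this example (illustrative only; the real find_lca_non_replicable_dp walks the DataPipeGraph returned by traverse_dps):

def downstream(node, edges):
    # All nodes the data from `node` flows through, including `node` itself.
    seen, stack = set(), [node]
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(edges.get(n, []))
    return seen

def lowest_common_ancestor(non_replicable, edges):
    # Nodes that every non-replicable branch flows through...
    common = set.intersection(*(downstream(n, edges) for n in non_replicable))
    # ...and among them, the earliest one: its downstream set covers all the others.
    return next(n for n in common if downstream(n, edges) == common)

edges = {"DP1": ["DP2"], "DP2": ["DP5"], "DP3": ["DP4"], "DP4": ["DP5"],
         "DP5": ["DP6"], "DP6": ["fullsync"], "fullsync": ["output"]}
print(lowest_common_ancestor({"DP1", "DP3"}, edges))  # DP5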

@NivekT
Contributor

NivekT commented Dec 16, 2022

That makes sense. How do you plan to implement that? Will we still have users call .fullsync() and make the modification to the graph in DL2? Or will full_sync be an option within DL2? Probably the former?

@ejguan
Contributor Author

ejguan commented Dec 16, 2022

How do you plan to implement that? Will we still have users call .fullsync() and make the modification to the graph in DL2?

We should either allow users to call fullsync themselves or let DistributedRS attach fullsync automatically. Then ProtoMPRS should do the graph modification. WDYT? I am not going to implement the fullsync part in this PR though.

@wenleix
Contributor

wenleix commented Dec 19, 2022

Replying to #919 :

Actually, this is not true. sharding_filter is still needed if distributed sharding is required. We only disabled multiprocessing sharding for the non-shardable branch. And round-robin reading is achieved automatically with ProtoTypeMPRS.

An alternative approach for distributed sharding would be to distribute the workload based on filename, or on some compression/encoding unit within a file (in Parquet it's called a "Page": https://parquet.apache.org/docs/concepts/, and in ORC I think it's called a "Stripe"). That would avoid reading the original data multiple times?
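
A tiny sketch of that alternative (shard_files_by_rank is a hypothetical helper; file names stand in for pages/stripes): each rank takes a disjoint subset of files, so no rank reads data it will discard.

def shard_files_by_rank(filenames, rank, world_size):
    # Deterministic round-robin assignment of whole files (or pages/stripes) to ranks.
    return [f for i, f in enumerate(sorted(filenames)) if i % world_size == rank]

files = ["a.parquet", "b.parquet", "c.parquet", "d.parquet", "e.parquet"]
print(shard_files_by_rank(files, rank=0, world_size=2))  # ['a.parquet', 'c.parquet', 'e.parquet']
print(shard_files_by_rank(files, rank=1, world_size=2))  # ['b.parquet', 'd.parquet']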


@ejguan
Contributor Author

ejguan commented Dec 20, 2022

Non-shardable is an extremely bad name, as the data is actually sharded via round-robin dispatching. The actual meaning here is to prevent copying the DataPipe to multiple processes. I might rename it to non-replicable DataPipe / dispatching process.

@NivekT
Contributor

NivekT commented Dec 20, 2022

Non-shardable is an extremely bad name, as the data is actually sharded via round-robin dispatching. The actual meaning here is to prevent copying the DataPipe to multiple processes. I might rename it to non-replicable DataPipe / dispatching process.

I agree with renaming "non-shardable" to "non-replicated" DataPipe. I suppose sometimes it is replicable, but users won't want to replicate it?

@ejguan ejguan changed the title [ReadingService] Add mp support for non-shardable DataPipe [ReadingService] Add round robin sharding to support non-replicable DataPipe for Multiprocessing Dec 20, 2022
@ejguan
Contributor Author

ejguan commented Dec 20, 2022

@wenleix @NivekT This PR has been updated. An updated document can be found at https://ejguan.github.io/dataloader2.html#dynamic-sharding

Comment on lines +61 to +63
# Lazily import to prevent circular import
from torchdata.dataloader2 import communication

Contributor Author

This is my temporary fix for the circular import problem.
cc: @NivekT

Contributor Author
@ejguan ejguan left a comment

The following review steps are added to make it easier to review the main logic in this PR. Let me know if anything is unclear.

Comment on lines +16 to +17
@functional_datapipe("sharding_round_robin_dispatch")
class ShardingRoundRobinDispatcherIterDataPipe(IterDataPipe):
Contributor Author

Review Step 1: ShardingRoundRobinDispatcher is introduced to indicate where the pipeline should be non-replicable.

I am open to any suggestions on the name/functional name.

Comment on lines +41 to +42
def __iter__(self) -> Iterator[T_co]:
    yield from self.source_datapipe
Contributor Author

Review Step 1.1: Keep __iter__ as a simple pass-through here, rather than raising an error, to support the single-process use case.

Contributor

by "single-process use case". does it mean "eager mode"?

Contributor Author

Yeah, pure eager mode currently. In the future, we might provide a default SingleProcessReadingService for users.
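
So, assuming the functional API from this PR, plain iteration keeps working without any ReadingService:

from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(5)).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
print(list(dp))  # [0, 1, 2, 3, 4] -- __iter__ is a pass-through in eager mode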

res_queue: Queue


def find_lca_non_replicable_dp(graph: DataPipeGraph) -> Optional[DataPipe]:
Contributor Author

Review Step 2: Add this graph function to find the lowest common ancestor of the non-replicable DataPipes (ShardingRoundRobinDispatcher)

graph = traverse_dps(end_dp)
return single_br_dp, multi_br_dp, ch1, ch2, fork_zip_dp, cir_br_dp, cir_map_dp, end_dp, graph

def test_single_non_replicable_dp(self):
Contributor Author

Review Step 3.1: Tests for single non-replicable DataPipe

graph, cir_map_dp = make_dp_non_replicable(graph, cir_map_dp)
self.assertEqual(find_lca_non_replicable_dp(graph), cir_map_dp)

def test_multi_non_replicable_dps(self):
Contributor Author

Review Step 3.2: Tests for multiple non-replicable DataPipes

@@ -91,3 +131,20 @@ def SpawnThreadForDataPipeline(datapipe):

    process = threading.Thread(target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue), daemon=True)
    return process, req_queue, res_queue, new_datapipe


def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes):
Contributor Author

Review Step 7.1: Create num_workers pairs of req_queue and res_queue, and launch MultipleDataPipesToQueuesLoop to iterate over the non-replicable DataPipe. A rough sketch of the setup follows.

]


def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call_on_process_init=None):
Contributor Author

Review Step 7.2: Launch a non-blocking DataPipeBehindQueues while-loop per child DataPipe from round_robin_demux, using zip_longest to mimic round-robin next calls over the child DataPipes.
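
The zip_longest trick in isolation (toy generators stand in for the per-child DataPipeBehindQueues loops):

from itertools import zip_longest

def serve(name, items):
    for x in items:
        yield f"{name} handled {x}"  # one non-blocking step of a child's loop

loops = [serve("worker0", [1, 2, 3]), serve("worker1", [4, 5])]
for steps in zip_longest(*loops):
    # Steps are interleaved round-robin; exhausted loops yield None.
    for s in steps:
        if s is not None:
            print(s)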

Comment on lines +239 to +243
# Dispatching process for non-replicable DataPipes exists
if self._dispatch_process is not None:
    # Use the placeholder to pass the request/response queues to each worker process
    dummy_dp.req_queue = self._dispatch_process[1][worker_id]
    dummy_dp.res_queue = self._dispatch_process[2][worker_id]
Contributor Author

Review Step 6.2: We only have one _DummyIterDataPipe in the main process, but num_workers pairs of req_queue and res_queue. To connect each pair to the corresponding worker process, we inject the queue attributes into the _DummyIterDataPipe before sending it to that worker process.

Comment on lines +68 to +97
# Find if there is non-replicable DataPipe
graph = traverse_dps(datapipe)
non_replicable_dp = find_dps(graph, _DummyIterDataPipe)  # type: ignore

# There are two cases for DataPipe graph in terms of mp sharding:
# 1) All DataPipes are replicable, apply mp sharding to the whole graph
if len(non_replicable_dp) == 0:
    torch.utils.data.graph_settings.apply_sharding(
        datapipe, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING
    )
# 2) There is non-replicable DataPipe. Since we have replaced the lowest common
#    ancestor by a `_DummyIterDataPipe`, we would only apply mp sharding
#    to replicable branches that don't have `_DummyIterDataPipe`.
else:
    assert len(non_replicable_dp) == 1
    replicable_branches = find_replicable_branches(graph)
    for dp in replicable_branches:
        torch.utils.data.graph_settings.apply_sharding(
            dp, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING
        )

    req_queue = non_replicable_dp[0].req_queue
    res_queue = non_replicable_dp[0].res_queue

    queue_wrapper = communication.iter.QueueWrapper(
        communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)
    )
    dispatch_process_dp = communication.iter._IterateQueueDataPipes([queue_wrapper])
    graph = replace_dp(graph, non_replicable_dp[0], dispatch_process_dp)
    datapipe = list(graph.values())[0][0]
Contributor Author

Review Step 8: In the worker process, check whether there is a _DummyIterDataPipe.
If not, the whole pipeline is replicable, and we shard it with the sharding filter.
If there is, we apply sharding only over the replicable branches.

QueueWrapper and _IterateQueueDataPipes are used to wrap res_queue and req_queue into a DataPipe that can handle Request and Response messages based on the protocol.

Comment on lines 121 to 142
def dispatch_process_reset_fn(
    datapipe: DataPipe,
    worker_info: WorkerInfo,
    dist_info: _DistInfo,
) -> DataPipe:
    r"""
    Based on the distributed shared random seed, this function sets the random state
    of the non-replicable ``DataPipe`` graph and the global random states for the dispatch process.
    It guarantees that all distributed non-sharding processes share the
    same random states to ensure the same shuffle order.
    """
    worker_seed_generator = torch.Generator()
    worker_seed_generator.manual_seed(dist_info.shared_seed)
    torch.utils.data.graph_settings.apply_random_seed(
        datapipe,
        worker_seed_generator,
    )

    # Set global random states
    _set_global_random_state(worker_seed_generator)

    return datapipe
Contributor Author

Review Step 9: When a new epoch starts, we want to control the random seed based on the distributed information.
We need to guarantee that all distributed dispatching processes share the same random seed.
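
In spirit (a sketch assuming _set_global_random_state seeds torch, random and numpy from the generator; not the exact implementation):

import random
import numpy as np
import torch

def seed_dispatch_process(shared_seed: int) -> None:
    gen = torch.Generator()
    gen.manual_seed(shared_seed)
    # Same shared seed on every rank => same shuffle order everywhere.
    torch.manual_seed(gen.initial_seed())
    random.seed(gen.initial_seed())
    np.random.seed(gen.initial_seed() % (2 ** 32))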

Contributor
@wenleix wenleix left a comment

"Review Step 1: Add ShardingRoundRobinDispatcher is introduced to indicate where the pipeline should be non-replicable."

LGTM. Name (sharding_round_robin_dispatch) is a bit long but let's keep it for now...

Contributor
@wenleix wenleix left a comment

"Review Step 7: create req_queue and res_queueand execute with round_robin_demux"

High-level control flow looks good. But I don't have enough low-level context yet~

res_queues
), "``MultipleDataPipesToQueuesLoop`` requires the same number of datapipes, request queues and response queues"

torch.set_num_threads(1)
Contributor

noob question: what's this for?

Contributor Author

IIRC, this was introduced to disable OpenMP threading in DataLoader workers. By default, OpenMP creates a number of threads equal to the number of CPU cores, so with multiprocessing enabled, num_workers x num_threads_per_worker threads would be created without providing any further benefit.

Besides, OpenMP should not be used in a forked subprocess if any OpenMP features were already utilized in the main process before the subprocesses were forked.

Any suggestion?

Contributor
@wenleix wenleix left a comment

"Review Step 8: Worker process handling (graph rewrite and receiving demux result from dispatch process)"

LGTM.

queue_wrapper = communication.iter.QueueWrapper(
    communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)
)
dispatch_process_dp = communication.iter._IterateQueueDataPipes([queue_wrapper])
Contributor

IIUC:

IterDataPipeQueueProtocolClient will be wrapped into a QueueWrapper (but still not an IterDataPipe), and further wrapped into an _IterateQueueDataPipes, which is an IterDataPipe?

Contributor Author

Nope. Both QueueWrapper and _IterateQueueDataPipes are IterDataPipes; this is one of the things we can optimize later.

Contributor
@wenleix wenleix left a comment

Review Step 9: LGTM % minor question...


@facebook-github-bot
Contributor

@ejguan merged this pull request in e15e145.

        yield response.value

    def reset(self):
        # NonBlocking DataPipes do not reset automatically, have to do it manually
Contributor

@ejguan Just noticed that the reset method changed after the move. It used to have this:

# Collect all existing requests results to clear queues
for dp in self.datapipes:
    if dp.protocol.waiting_for_response():
        dp.protocol.get_response_next(block=True)

Is this no longer necessary?

Contributor Author

I don't think it's necessary, because we always want to do reset_epoch as well for NonBlocking, which discards all existing requests. So, at the point of reset, we should expect no requests remaining in the worker process queues.
