# What

Add an option in the segmenter to place each resharding Expr in its own singleton segment. To trigger it, set the segmenter's options as follows:

```
SegmentCandidateFinderOptions options{
    .run_translate_welford = false,
    .run_combine_reductions = false,
    .run_herrmann_merge = true,
    .run_final_merge = true,
    .only_segment_resharding_exprs = true};
```

and use the segmenter as follows with any (possibly dummy) inputs:

```
KernelArgumentHolder dummy_inputs;
auto segmented_fusion = SegmentCandidateFinder::segment(std::move(fusion), dummy_inputs, options);
```

If `only_segment_resharding_exprs` is `false` (the default), the behavior of the segmenter is unchanged. We also provide a fairly broad test suite to validate the implementation.

# Why

Resharding Exprs need to be handled differently from other Exprs because we want them to result in posting a network collective from the host. Therefore those expressions cannot (for now) be fused into any kernel. For this reason, the fusion must be segmented both before and after each of those Exprs.

# How

_**Remark:** For now, the segmenter is only used [in one place, before scheduling and compiling the fusion](https://github.com/NVIDIA/Fuser/blob/1603f39bab8c1bbe12e38f2b5de53dec3b7cc373/csrc/kernel_cache.cpp#L990)._

Recall that the segmenter first creates as many segments as there are Exprs and then eagerly tries to merge neighboring segments, one step at a time. The method

```
bool SegmentCandidateFinder::codeGenSupportedMerge(
    SegmentedGroup* group1,
    SegmentedGroup* group2)
```

returns whether two groups can be merged (i.e. fused into one kernel). With the current patch, if `SegmentCandidateFinderOptions::only_segment_resharding_exprs` is set to `true`, the usual behavior of `codeGenSupportedMerge` is bypassed and the result depends only on whether any Expr in either group is resharding: if one is, the two groups are not merged.
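
The bypass described above can be sketched on toy types. This is only an illustration under stated assumptions: `ToyExpr`, `ToyGroup`, `containsResharding` and `canMergeReshardingOnly` are hypothetical names that do not exist in nvFuser, and the `is_resharding` flag stands in for whatever check the patch uses to detect a resharding Expr. The actual logic lives in `codeGenSupportedMerge` and operates on `SegmentedGroup*`.

```cpp
#include <cassert>
#include <vector>

// Toy stand-ins for illustration only; these are NOT the nvFuser classes.
struct ToyExpr {
  bool is_resharding;  // assumed to be computed elsewhere
};

struct ToyGroup {
  std::vector<ToyExpr> exprs;
};

// Returns true if at least one Expr in the group is resharding.
bool containsResharding(const ToyGroup& group) {
  for (const ToyExpr& expr : group.exprs) {
    if (expr.is_resharding) {
      return true;
    }
  }
  return false;
}

// Sketch of the merge predicate when only resharding Exprs are segmented:
// two neighboring groups may be merged only if neither contains a
// resharding Expr. No scheduler or runtime information is consulted.
bool canMergeReshardingOnly(const ToyGroup& group1, const ToyGroup& group2) {
  // The segmenter starts from one group per Expr, so groups are non-empty.
  assert(!group1.exprs.empty() && !group2.exprs.empty());
  return !containsResharding(group1) && !containsResharding(group2);
}
```

With this predicate, a group holding a resharding Expr can never grow, which is what keeps it a singleton segment.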

Because this segmentation should not depend on the input data, we use a default (i.e. empty) `KernelArgumentHolder`, from which it is invalid to instantiate a `SchedulerRuntimeInfo runtime_info_`. For this reason, we had to make that attribute optional.

# Future/other directions

Another way to achieve the same result is to manually add segment bounds around the resharding Exprs, as was suggested by @wujingyue here #1571

The current implementation looks a bit "hacky" and should be integrated more properly once multidevice schedulers are implemented and/or the segmenter is refactored.

Later, we might want to fuse communications with computes, and also communications with each other. This would require a more advanced segmenter and scheduler, but hopefully this patch can serve as a good basis.

# Example

Consider the fusion:

```
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
TensorView* tv0 = makeContigTensor({4});
fusion->addInput(tv0);
TensorView* tv1 = sum(tv0, {3});
TensorView* tv2 = set(tv1);
TensorView* tv3 = sum(tv2, {2});
fusion->addOutput(tv3);
```

Manually scheduled as follows:

```
DeviceMesh mesh({0, 1, 2, 3});
for (auto tv : {tv0, tv1, tv2, tv3}) {
  tv->setDeviceMesh(mesh);
}
tv0->axis(0)->parallelize(ParallelType::DIDx);
tv1->axis(0)->parallelize(ParallelType::DIDx);
```

This scheduling implies that
- `tv0` and `tv1` are fully sharded on the devices {0,1,2,3}
- `tv2` and `tv3` are fully replicated on those same devices
- consequently, the "set" operation on the line `tv2 = set(tv1)` actually embeds an "AllGather" network collective.

This Expr is resharding while all the other Exprs are not. We thus expect this expression to constitute an unmergeable segment.
With the option `SegmentCandidateFinderOptions::only_segment_resharding_exprs` set to `true`, the segmenter produces three segments in this situation:
- Compute segment 1: with the expr `tv1 = sum(tv0,{3})`
- Communication segment 1: with the expr `tv2 = set(tv1)`
- Compute segment 2: with the expr `tv3 = sum(tv2, {2})`
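
To make the expected outcome concrete, here is a minimal sketch that replays the example on a toy model. It assumes a simple linear chain of Exprs and a greedy left-to-right merge, which is a simplification of the real segmenter's merge passes. `ToyExpr`, `ToySegment`, `segmentChain` and `exampleChain` are hypothetical names used for illustration only, and the `is_resharding` flags are set by hand to match the scheduling above.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-ins for illustration only; these are NOT the nvFuser classes.
struct ToyExpr {
  std::string text;
  bool is_resharding;  // set by hand here, not derived from a device mesh
};

using ToySegment = std::vector<ToyExpr>;

bool containsResharding(const ToySegment& segment) {
  for (const ToyExpr& expr : segment) {
    if (expr.is_resharding) {
      return true;
    }
  }
  return false;
}

// Greedy segmentation of a linear chain: an Expr joins the previous segment
// only if neither it nor that segment contains a resharding Expr.
std::vector<ToySegment> segmentChain(const std::vector<ToyExpr>& chain) {
  std::vector<ToySegment> segments;
  for (const ToyExpr& expr : chain) {
    const bool can_merge = !segments.empty() && !expr.is_resharding &&
        !containsResharding(segments.back());
    if (can_merge) {
      segments.back().push_back(expr);
    } else {
      segments.push_back(ToySegment{expr});
    }
  }
  // Sanity check: every Expr lands in exactly one segment.
  std::size_t total = 0;
  for (const ToySegment& segment : segments) {
    total += segment.size();
  }
  assert(total == chain.size());
  return segments;
}

// The three Exprs of the example; only the `set` is marked as resharding.
std::vector<ToyExpr> exampleChain() {
  return {
      {"tv1 = sum(tv0, {3})", false},
      {"tv2 = set(tv1)", true},
      {"tv3 = sum(tv2, {2})", false}};
}
```

Calling `segmentChain(exampleChain())` yields three singleton segments in this toy model, matching the compute / communication / compute split listed above, whereas a chain with no resharding Expr collapses into a single segment.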
1 parent 01ff3a2, commit 3e1e11e
Showing 6 changed files with 75 additions and 21 deletions.