Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A test that surrounds a sharding expression with segment_sets. #1571

Merged
merged 1 commit into from
Jan 4, 2024

Conversation

wujingyue
Copy link
Collaborator

As a follow-up to our previous meeting about segmentation for multi-device, I wrote this PR showing how to surround a sharding expression with segment_sets. This way, a sharding expression becomes its own segment without changing the current segmentation algorithm. I expect a pre-segmenter pass to add such segment_sets. See MarkAliasPreparePass for an example. I believe this is good enough for segmenting out CPU-initiated communication.

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice example. Thanks for adding this.

@samnordmann
Copy link
Collaborator

samnordmann commented Jan 4, 2024

Thank you very much! That's very interesting, I wasn't aware of the Segment_Set op type.

I think this way of doing has the advantage of being very is simple and of leaving the segmenter untouched.
However, I think in the future it won't be sufficient and we will need to modify the segmenter anyway (when we want to merge communications together, or comms + compute). So, later we will probably need to build upon the other method I proposed.

Another minor drawback of this method is that it will add a lot of trivial Exprs (i.e. Segmenter_Set) and Tvs which need to be discarded so it may make the analysis more cumbersome.
What do you think ?

In a near future, I will open a PR with the segmentation for multidevice, we will have to decide what method to use.
Which one do you think is preferable?

@wujingyue
Copy link
Collaborator Author

However, I think in the future it won't be sufficient and we will need to modify the segmenter anyway (when we want to merge communications together, or comms + compute)

The current implementation of segments heavily assumes that one segment maps to one CUDA kernel, and we don't nest one kernel into another. So I prefer not making segments hierarchical at this moment. This saves all of us development and maintenance.

That being said, I agree with you that we'll probably need to represent an expression group better without assuming how it'll be translated and executed. Are you back from your vacation? Can we talk about more about the future use cases next week? I'm very curious about use cases of merging communications themselves and merging communications and compute. We should definitely get nvFuser to support them.

Another minor drawback of this method is that it will add a lot of trivial Exprs (i.e. Segmenter_Set) and Tvs which need to be discarded so it may make the analysis more cumbersome.

Yes. That's usually solved by making analysis handle Segmenter.Set and/or by running a clean-up pass immediately after segmentation (because Segmenter.Set is useless afterwards).

@wujingyue wujingyue merged commit 91f10c8 into main Jan 4, 2024
4 checks passed
@wujingyue wujingyue deleted the wjy/shard branch January 4, 2024 17:58
@wujingyue
Copy link
Collaborator Author

Here's the summary of today's meeting.

We (@cowanmeg , @naoyam , @samnordmann , and @wujingyue) talked about the current state of @samnordmann 's draft nvFuser fork for multi-device and why certain things are implemented in certain ways. This is mainly to help me understand how segmentation needs to be improved for the multi-device use case.

Facts about the code:

  1. MultiDeviceExecutor (instead of FusionExecutorCache or FusionExecutor) is the entry point of compilation and execution, for multi-device.
  2. MultiDeviceExecutor does a partial segmentation, which segments only sharding ops.
  3. Sharding-op-only segments are compiled and executed by postCommunication and other segments by postKernel. The current implementation of postKernel has a bug where it uses FusionExecutor (not FusionExecutorCache) to compile and execution a fusion. This will fail for some large fusions that have to be segmented.

My takeaways:

  1. Device meshes are at this moment annotated by the user of nvFuser. We will want annotation to be done automatically. However, this will need to be done before segmentation and therefore not by (the current) scheduling, because device meshes affect segmentation decisions.
  2. I don't see a case for nested SegmentedGroups. A flat DAG of SegmentedGroups suffices for both CPU-initiated and GPU-initiated multi-device communication.
  3. We may need a multi-level segmentation. The current implementation applies segmentation in two levels (as mentioned above about the code). I think we could avoid even that for simplicity. One idea floating around in my head is to create a CommunicationScheduler to accept sharding-op-only Fusions and to let lowering lower such a fusion to a vector of communications. This way, we can reuse Executor for multi-device execution. Also, this plays better with GPU-initiated communication, which has to be lowered to code inside a CUDA kernel.

@naoyam
Copy link
Collaborator

naoyam commented Jan 11, 2024

If we do a multi-level segmentation, wouldn't it make sense to have SegmentedGroups to be further segmented? Wouldn't it make sense to have some form of nested groups?

@wujingyue
Copy link
Collaborator Author

If we do a multi-level segmentation, wouldn't it make sense to have SegmentedGroups to be further segmented?

Yes, it's done by calling makeFusion and then potentially calling FusionExecutorCache::runFusionWithInputs.

Wouldn't it make sense to have some form of nested groups?

I doubt that. What's the practical benefit? Do we need to traverse from a parent SegmentedGroup to a child or from child to parent?

@naoyam
Copy link
Collaborator

naoyam commented Jan 11, 2024

Wouldn't it make sense to have some form of nested groups?

I doubt that. What's the practical benefit? Do we need to traverse from a parent SegmentedGroup to a child or from child to parent?

I don't have any concrete use case, but I'm just feeling that since there's nested structures it seems to make sense to represent the nesting in some way. That said, I also agree that we shouldn't make things overly complicated than necessary.

@samnordmann
Copy link
Collaborator

Thank you very much @wujingyue ! Very useful discussion and summary

About nested SegmentedGroups, I must say I have the same intuition as Naoya, I think it would be a natural and useful thing to have. If we are thinking of complex pipelinings, we would probably need to represent different segments of the fusion at different scales.

Since we are already using several levels of segmentations, why would it be overcomplicated to use nested SegmentedGroup? For now, we need to transform the SegmentedGroup to a Fusion before segmenting it again, which seems to me unnatural

@wujingyue wujingyue added the testing e.g. improving test infra and test coverage label Jan 12, 2024
@naoyam
Copy link
Collaborator

naoyam commented Jan 16, 2024

As I think more about this, I'm getting feeling that separating inter-device and intra-device segmentations as two separate processes may not be what we eventually would like to do. I think that, for a given fusion, how we should distribute tensors is likely to depend on how we could schedule computations within a device. That is, the intra-device segmentation/scheduling would affect the intra-kernel segmentation. While there's certainly conceptual hierarchy between the two segmentations, in reality the hierarchy may be blurred as we may need to think about the inter and intra device segmentations as a whole, so the enforcing the hierarchy between the two may not make much sense.

samnordmann added a commit that referenced this pull request Feb 1, 2024
# What
Add an option in the segmenter to segment resharding Expr in separate
singleton segment.
To trigger it, set the segmenter's options as follows:
```
    SegmentCandidateFinderOptions options{
        .run_translate_welford = false,
        .run_combine_reductions = false,
        .run_herrmann_merge = true,
        .run_final_merge = true,
        .only_segment_resharding_exprs = true};
```
and use the segmenter as follows with any (possibly dummy) inputs:
```
KernelArgumentHolder dummy_inputs;
auto segmented_fusion = SegmentCandidateFinder::segment(std::move(fusion), dummy_inputs, options);
```
If `only_segment_resharding_exprs` is set to `false` (which is the case
by default), the behavior of the segmenter is unchanged.


We also provide a quite wide testing suite to validate our
implementation.

# Why 
Resharding Exprs need to be handled differently than other Exprs because
we want them to result in posting a network collective from the host.
Therefore those expressions cannot (for now) be fused to any kernel. For
this reason, we need those Expr to be segmented before and after.

# How
_**Remark:** For now, the segmenter is only used [at one place before
scheduling and compiling the
fusion](https://github.com/NVIDIA/Fuser/blob/1603f39bab8c1bbe12e38f2b5de53dec3b7cc373/csrc/kernel_cache.cpp#L990)._

Recall that the segmenter first creates as many segments as there are
Expr and then tries to merge the neighbour segments incrementally in an
eager manner. The method
```
bool SegmentCandidateFinder::codeGenSupportedMerge(
    SegmentedGroup* group1,
    SegmentedGroup* group2) 
```
returns whether two groups can be merged (i.e. fused into one kernel). 

With the current patch, if
`SegmentCandidateFinderOptions::only_segment_resharding_exprs` is set to
`true`, then the usual behavior of `codeGenSupportedMerge` is bypassed
and the function returns whether one Expr among the groups is
resharding.

Because this segmentation shouldn't depend on the inputs data, we use
default (aka empty) `KernelArgumentHolder`, from which it is invalid to
instantiate a `SchedulerRuntimeInfo runtime_info_`. For this reason, we
had to make the latter attribute optional.

# Future/other directions

Another way to achieve the same result is to manually add segment bounds
surrounding the resharding Exprs as was suggested by @wujingyue here
#1571

The current implementation looks a bit "hacky" and should be be
integrated more properly once multidevice schedulers are implemented
and/or the segmenter is refactored.

Later, we might wanna be able to fuse communications and computes and
also communications between them. This would require a more advanced
segmenter and scheduler, but hopefully this patch could serve as a good
basis

# Example:
consider the fusion:
```
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  TensorView* tv0 = makeContigTensor({4});
  fusion->addInput(tv0);
  TensorView* tv1 = sum(tv0,{3});
  TensorView* tv2 = set(tv1);
  TensorView* tv3 = sum(tv2, {2});
  fusion->addOutput(tv3);
```

Manually scheduled as follows:
```
  DeviceMesh mesh ({0,1,2,3})
  for (auto tv : {tv0, tv1, tv2, tv3}) {
    tv->setDeviceMesh(mesh);
  }
  tv0->axis(0)->parallelize(ParallelType::DIDx);
  tv1->axis(0)->parallelize(ParallelType::DIDx);
```
This scheduling implies that
- `tv0` and `tv1` are fully sharded on the devices {0,1,2,3}
- `tv2` and `tv3` are fully replicated on those same devices
- consequently, the "set" operation on the line `tv2 = set(tv1)`
actually embedds an "AllGather" network collective. This Expr is
resharding while all the other exprs are not. We thus excpect this
expression to constitute an unmergeable segment.

The segmenter in this situation with the
option`SegmentCandidateFinderOptions::only_segment_resharding_exprs` set
to `true` will result in three segments:
- Compute segment 1: with the expr `tv1 = sum(tv0,{3})`
- Communication segment 1:  with the expr `tv2 = set(tv1)`
- Compute segment 2: with the expr `tv3 = sum(tv2, {2})`
cowanmeg added a commit to samnordmann/Fuser that referenced this pull request Feb 13, 2024
* print bandwidth when perf_debug_verbose is true (NVIDIA#1689)

print bandwidth when `perf_debug_verbose` is true.

* in vectorization validation, add err msg if tv has no definition (NVIDIA#1690)

check the existence of tv definition in vectorization validation

* Accomodate Reduction IterDomains when concretizing reshape extents (NVIDIA#1692)

We register extents for concretization when we concretize reshape. In
order to do that, we line up `IterDomain`s in the symbolic reshaped TV
and the new, concretized one. In cases where the concretized reshape is
trivial, such as when the output shape is the same as the input, we do
not create a new TV. In those cases, we will have the input to the
original `ViewOp` as the concretized output. That input TV might have
reduction domains, as in the provided test, in which case we need to
filter those out when doing this alignment. This small PR just
implements that filtering.

Fixes NVIDIA#1691.

* `MmaOp::evaluate` method (NVIDIA#1675)

* Fix some typos. (NVIDIA#1700)

* `torch.compile` and `eager` benchmarks for `softmax` (NVIDIA#1670)

Adds `torch.compile` and `eager` baseline benchmarks to be used in
weekly benchmark runs.
Issue NVIDIA#1668.

* Add a test for fusions with no inputs. (NVIDIA#1709)

As a follow up to
NVIDIA#1696 (comment).

* Double the size of the fusion cache to workaround a CI issue. (NVIDIA#1702)

By just removing entries when it fills up.

* Check that the reduced axis is sharded on producer in isLowerableToCommunication (NVIDIA#1695)

Currently, a reduction is lowerable to a communication iff only one axis
is reduced and this axis is sharded across devices on the **producer**
side.
Before this patch, we would mistakenly check that the axis is sharded on
**consumer** side, which led to some runtime assert error.

* Add blank impl of isLowerableToCommunication. (NVIDIA#1698)

isLowerableToCommunication is used in a few places to print error
messages or short-circuit loops. Those places appear to be places that
are intended to largely be used behind the distributed path. It's easier
to just define the API instead of trying to conditionalize all the use
sites and invent non-USE_DISTRIBUTED behavior.

* Multidevice segmenter (NVIDIA#1696)

# What
Add an option in the segmenter to segment resharding Expr in separate
singleton segment.
To trigger it, set the segmenter's options as follows:
```
    SegmentCandidateFinderOptions options{
        .run_translate_welford = false,
        .run_combine_reductions = false,
        .run_herrmann_merge = true,
        .run_final_merge = true,
        .only_segment_resharding_exprs = true};
```
and use the segmenter as follows with any (possibly dummy) inputs:
```
KernelArgumentHolder dummy_inputs;
auto segmented_fusion = SegmentCandidateFinder::segment(std::move(fusion), dummy_inputs, options);
```
If `only_segment_resharding_exprs` is set to `false` (which is the case
by default), the behavior of the segmenter is unchanged.


We also provide a quite wide testing suite to validate our
implementation.

# Why 
Resharding Exprs need to be handled differently than other Exprs because
we want them to result in posting a network collective from the host.
Therefore those expressions cannot (for now) be fused to any kernel. For
this reason, we need those Expr to be segmented before and after.

# How
_**Remark:** For now, the segmenter is only used [at one place before
scheduling and compiling the
fusion](https://github.com/NVIDIA/Fuser/blob/1603f39bab8c1bbe12e38f2b5de53dec3b7cc373/csrc/kernel_cache.cpp#L990)._

Recall that the segmenter first creates as many segments as there are
Expr and then tries to merge the neighbour segments incrementally in an
eager manner. The method
```
bool SegmentCandidateFinder::codeGenSupportedMerge(
    SegmentedGroup* group1,
    SegmentedGroup* group2) 
```
returns whether two groups can be merged (i.e. fused into one kernel). 

With the current patch, if
`SegmentCandidateFinderOptions::only_segment_resharding_exprs` is set to
`true`, then the usual behavior of `codeGenSupportedMerge` is bypassed
and the function returns whether one Expr among the groups is
resharding.

Because this segmentation shouldn't depend on the inputs data, we use
default (aka empty) `KernelArgumentHolder`, from which it is invalid to
instantiate a `SchedulerRuntimeInfo runtime_info_`. For this reason, we
had to make the latter attribute optional.

# Future/other directions

Another way to achieve the same result is to manually add segment bounds
surrounding the resharding Exprs as was suggested by @wujingyue here
NVIDIA#1571

The current implementation looks a bit "hacky" and should be be
integrated more properly once multidevice schedulers are implemented
and/or the segmenter is refactored.

Later, we might wanna be able to fuse communications and computes and
also communications between them. This would require a more advanced
segmenter and scheduler, but hopefully this patch could serve as a good
basis

# Example:
consider the fusion:
```
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  TensorView* tv0 = makeContigTensor({4});
  fusion->addInput(tv0);
  TensorView* tv1 = sum(tv0,{3});
  TensorView* tv2 = set(tv1);
  TensorView* tv3 = sum(tv2, {2});
  fusion->addOutput(tv3);
```

Manually scheduled as follows:
```
  DeviceMesh mesh ({0,1,2,3})
  for (auto tv : {tv0, tv1, tv2, tv3}) {
    tv->setDeviceMesh(mesh);
  }
  tv0->axis(0)->parallelize(ParallelType::DIDx);
  tv1->axis(0)->parallelize(ParallelType::DIDx);
```
This scheduling implies that
- `tv0` and `tv1` are fully sharded on the devices {0,1,2,3}
- `tv2` and `tv3` are fully replicated on those same devices
- consequently, the "set" operation on the line `tv2 = set(tv1)`
actually embedds an "AllGather" network collective. This Expr is
resharding while all the other exprs are not. We thus excpect this
expression to constitute an unmergeable segment.

The segmenter in this situation with the
option`SegmentCandidateFinderOptions::only_segment_resharding_exprs` set
to `true` will result in three segments:
- Compute segment 1: with the expr `tv1 = sum(tv0,{3})`
- Communication segment 1:  with the expr `tv2 = set(tv1)`
- Compute segment 2: with the expr `tv3 = sum(tv2, {2})`

* Vectorization Factor patch for computeInfoC2P with Broadcast in mapped IterDomain (NVIDIA#1625)

Fixes NVIDIA#1567

This PR patches vectorization factor in
`ContiguousInnerDimensionsMapper::computeInfoC2P`.

Handling of resolved broadcast dimension should be made on mapped
consumer tensors' from_ids, instead of the root_domain order. Added a
few tests per @zasdfgbnm 's suggestion:

```
Case 0:
T2[1024, 2, 512] = T0[1024, 2, 1] + T1[1024, 2, 512]
allocation = rfactor
--> T0 has no vectorization

Case 1:
T2[1024, 512, 2] = T0[1024, 1, 2] + T1[1024, 512, 2]
allocation = rfactor
--> T0 has vectorization 2

Case 2:
T2[1024, 512, 2] = T0[1024, 1, 2] + T1[1024, 512, 2];
T3[512, 1024, 2] = transpose(T2[1024, 512, 2])
allocation = rfactor
*except T1 has stride_order {1, 2, 0}
--> T0 has vectorization 4

Case 3:
T2[512, 1024, 2] = T0[1, 1024, 2] + T1[512, 1024, 2]
T3[1024, 512, 2] = transpose(T2[512, 1024, 2])
allocation = rfactor
--> T0 has vectorization 2
```

---------

Co-authored-by: Jacob Hinkle <[email protected]>
Co-authored-by: Gao, Xiang <[email protected]>

* transpose scheduler fix: reduction IterDomain on input tensors (NVIDIA#1661)

Fixes NVIDIA#1659 

Reorders reduction IterDomain so it won't interfere with
scheduling tiling from transpose scheduler.

* Convert reduction of expanded dims to squeeze (NVIDIA#1679)

See comment in arith.cpp for details.

One controversial change here is to allow squeezing expanded dimensions,
both in our IR's `SqueezeOp` and in the user-facing functions `squeeze`.
This results in actually removing those dimensions. This behavior
diverges from PyTorch, whose `squeeze` command will ignore requested
squeezes if the size is not 1 regardless of whether that dimension is
expanded. I'm happy to discuss this change and potentially take another
course, but I think we do need to be able to remove expanded axes (see
NVIDIA#1174 (comment) for
another case where I encountered this limitation).

Fixes NVIDIA#1678

* Make sure ValGraphs are created deterministically (NVIDIA#1714)

While I was working on NVIDIA#32, I sometimes saw non-deterministic results.
Hope this is the only source of non-determinism.

* Fix squeeze-related errors (NVIDIA#1717)

This fixes current failures in `pytest_ops.py -k squeeze` and some
integration failues.

This restores our previous semantics for squeeze, which **do not match
PyTorch**. Namely, if squeeze is provided a dimension that cannot be
squeezed, we will always raise an error.

* NVFUSER_DISTRIBUTED instead of USE_DISTRIBUTED (NVIDIA#1711)

* Add the missing `clang-format on` and reformat. (NVIDIA#1722)

* Print a newline before the header. (NVIDIA#1720)

* Associate each fusion cache with its local rank in distributed setting. (NVIDIA#1699)

### Problem:
Currently, automatic serialization saves a single cache regardless of
the number of devices. In a distributed setting, each process restores
its fusion cache from the same common workspace. However, this workspace
only contains the CUDA kernels for a single device. The remaining
processes must recompile the kernels for their devices.

### Solution:
A separate process is created for each device with `ddp` or `fsdp` and
each process contains a separate `FusionCache`. This PR associates each
fusion cache with its local rank in a distributed setting, allowing
automatic serialization to create a separate workspace for each device.
During deserialization, each process loads the workspace associated with
its local rank.

* Vectorized serial grid reduction (NVIDIA#1528)

This change allows us to use vectorized loads/stores in
`serialReductionStep`. The generated kernel now looks like
```c++
  NVFUSER_UPDATE_MAGIC_ZERO;                                        
  grid_sync::blockSerializeWait<false, false, true>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  #pragma unroll                                                                                                                         
  for(nvfuser_index_t i16 = 0; i16 < 4LL; ++i16) {                                                                                           nvfuser_index_t i17;                                                                                                                 
    i17 = 32LL * i16;                                                                                                                        nvfuser_index_t i18;                                                                                                                 
    i18 = 4096LL * i16;                                                                                                                  
    nvfuser_index_t i19;                                                                                                                 
    i19 = i5 + i18;                                                                                                                      
    nvfuser_index_t i20;                                                                                                                 
    i20 = -i18;                                                                                                                          
    #pragma unroll                                                                                                                       
    for(nvfuser_index_t i21 = 0; i21 < 8LL; ++i21) {                                                                                     
      nvfuser_index_t i22;                                                                                                               
      i22 = 512LL * (i21 + nvfuser_zero);                                                                                                
      Array<float, 4LL, 4> T3;                                                                                                           
      T3.set(float(0.000000000e+00f));                                                                                                   
      reduction::serialReductionStep</*vec_size=*/4>(                                                                                    
        &T3[0LL],                                                                                                                        
        &T2[(i17 + (4LL * i21))],                                                                                                        
        0.000000000e+00f,                                                                                                                
        &T6[(i19 + i22)],                                                                                                                
        [](float &a, float b) { a = a + b; },                                                                                            
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        true,                                                                                                                                    true);                                                                                                                           
      if ((b7 && (i6 < (i20 - i22)))) {                                                                                                  
        loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T1[(i19 + i22)], &T3[0LL]);                                    
      }                                                                                                                                  
    }                                                                                                                                    
  }                                                                                                                                      
  grid_sync::blockSerializeRelease<false, false, true>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);            
  NVFUSER_UPDATE_MAGIC_ZERO;       
```

* removing out-dated assert on python API (NVIDIA#1724)

removing out-dated asserts in python API `define_vector`;
adding a tests verifying the behavior

* make ci green again (NVIDIA#1730)

skip failing test.

Please enable it once we patch NVIDIA#1728

* Remove unnecessary `MATCHER_P`. (NVIDIA#1729)

* Fix Issue NVIDIA#1734 (NVIDIA#1735)

Closes Issue NVIDIA#1734

* Rename `AliasType` -> `AllocationType` (NVIDIA#1732)

* Skip executing a kernel if it's empty. (NVIDIA#1723)

I could change `compileFusion` to skip compilation as well. It turned
out to be more complicated than I expected, so I took the easier route
to skip just execution, which is at least an incremental improvement.

* don't cache slice input tv (NVIDIA#1705)

If the input tv is used by slice, don't cache it.
Fix NVIDIA#1697

* Make `MmaOp::evaluate` return output of the same dtype as `MmaOp` (NVIDIA#1733)

* Turing/Ampere Mma tests without `BroadcastOp` (NVIDIA#1672)

This PR renames `matmulAtInput` into `matmulAtInput2D`, explicitly
showing that it generates 2D inputs. This PR also adds a
`matmulAtInput3DTuring`, which is used to generate the 3D fusion inputs
(for example `[M, 1, K]` and `[1, K, N]`) for matmul. The `MmaTest` for
Turing and Ampere is modified to exclude the `BroadcastOp` and use the
3D version for generating fusion inputs. This is only the initial step
for making `scheduleMatmul` schedule a fusion not containing
`BroadcastOp`, I intentionally keep it small. Other changes will be
added in followup PRs.

Fixes NVIDIA#1628

* io_alias_ const update (NVIDIA#1740)

* Add benchmarks for RoPE. (NVIDIA#1739)

This PR adds two implementations of the RoPE module and benchmarks them
for NVIDIA#1597.

`rope_with_cat_fusion` mimics the Hugging Face implementation.
`rope_without_cat_fusion` implements an idea from @nikitaved to avoid
concatenation. Even though it looks difficult for the compiler to do it
all automatically, it's still useful to keep a record of the idea.

As a side change, I made `fd.define_tensor` to accept empty contiguity.

* Make nvfuser matmul benchmarks HSH instead of HSS (NVIDIA#1712)

This matches the `at::matmul` baselines.

This PR also adds a few more problem sizes, and runs each eagermode
baseline with and without FP16 reduction allowed.

* Reduce number of `MmaTest`s (NVIDIA#1738)

This PR is stacked on top of NVIDIA#1672

Turing/Ampere mma is only TN, so it makes no sense to test other layouts
in `MmaTest`s. These tests are intended to test mma instructions,
`ldmatrix` and `ldmatrix.trans` is tested separately in other unit
tests. Similar for `HopperRS` tests.

* Weekly Benchmarks Input Range (NVIDIA#1708)

* Rename axes= to dims= in frontend (NVIDIA#1741)

Currently we accept `axes=` for some ops like `fd.ops.sum` and `dims=`
for others like `fd.ops.squeeze`.

This is a small attempt to make the frontend arguments more consistent.
This change renames the `axis=` kwarg to `dim=` and the same for `axes=`
-> `dims=`.

I think we're free to set our own convention, but for reference:
- PyTorch uses `dim=` in most places and accepts either a single dim or
multiple using that same argument name, where applicable.
- Numpy uses `axis=` and, like PyTorch, accepts a list where applicable.
- `jax.lax` uses `dimensions=`

* Avoid unused smem workspace for serial grid reductions (NVIDIA#1727)

GridReduction can be lowered to either `gridReduce` or
`serialReductionStep`. `gridReduce` requires a smem workspace in order
to use multiple threads to aggregate partial sums. However,
`serialReductionStep` does not coordinate among threads and has no use
for a workspace. This change simply disables allocating that little bit
of extra shared memory if our only grid reductions are serial, which
currently only happens in split-K GEMM.

This reduces the smem allocated in a simple test from 16896 B to 16384 B
(about 97%). More importantly, this makes the computation in
`mma_utils::generateSharedMemoryEpilogueHeuristics()` more accurate.
Tests are updated to check that this computation is accurate.

The change in `kernel.cpp` is responsible for reducing actual smem usage
for split-K. The changes to `mma_utils` and `test_gpu_tensorcore.cpp`
are needed for adding testing that our expected smem usage matches the
actual usage.

* Issue NVIDIA#1748 (NVIDIA#1749)

Closes Issue NVIDIA#1748.
Apart from `c10::cuda::GetDevice`, no other functionality seems
affected.

* Rename `axes` to `dims` in benchmarks fusion definitions (NVIDIA#1751)

Changes the kwarg `axes` to `dims` following the API change in PR NVIDIA#1741.

* Bump matmul benchmark checkMatch() tolerance (NVIDIA#1747)

This is necessary due to recent switch to HSH

Fixes NVIDIA#1746

* linter

* change guard USE_DISTRIBUTED to NVFUSER_DISTRIBUTED in test/test_multidevice_sharding.cpp

* linting

* linter and cleanup

* remove allocator.h/cpp files

* Device index patch (NVIDIA#1752)

Fixes NVIDIA#1748 

guard c10::cuda::GetDevice API change on TORCH_VERSION

with this change, it ensures that we can build against stable release `<
2.2.0`, as well as TOT after
pytorch/pytorch#119142

For 2.3.0 nightly, if someone accidentally checkout a commit before the
patch, the build will still fail.

* fixing multidevice build (NVIDIA#1753)

API change coming from pytorch/pytorch#119421

* patching API GUARD (NVIDIA#1754)

patching API version guard so we'll still be able to build against older
pytorch version.

* Add a visitor for ValGraph (NVIDIA#1713)

Used in the loop promotion analysis. Extracted from NVIDIA#32

* empty commit for triggering CI

---------

Co-authored-by: Liqiang Lu <[email protected]>
Co-authored-by: Jacob Hinkle <[email protected]>
Co-authored-by: Priya Mishra <[email protected]>
Co-authored-by: Jingyue Wu <[email protected]>
Co-authored-by: Tom Fogal <[email protected]>
Co-authored-by: jjsjann123 <[email protected]>
Co-authored-by: Gao, Xiang <[email protected]>
Co-authored-by: Naoya Maruyama <[email protected]>
Co-authored-by: Meghan Cowan <[email protected]>
Co-authored-by: Ryan Spring <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing e.g. improving test infra and test coverage
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants