Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKEND] Implement 3xTF32 trick #3234

Merged
merged 3 commits into from
Mar 28, 2024
Merged

[BACKEND] Implement 3xTF32 trick #3234

merged 3 commits into from
Mar 28, 2024

Conversation

lezcano
Copy link
Contributor

@lezcano lezcano commented Feb 28, 2024

This PR implements the 3xTF32 trick to make use of the TCs on F32 tensors without sacrificing accuracy. This is particularly relevant for PyTorch, as TF32 is off by default.

Benchmarks on A100 from python/tutorials/03-matrix-multiplication.py run on float32 data using use_tf32=False:

         M       N       K     cuBLAS     Triton    This PR
0    256.0   256.0   256.0   1.927529   1.092267   1.489455
1    384.0   384.0   384.0   5.026909   3.567484   3.686400
2    512.0   512.0   512.0   8.192000   6.553600   6.898527
3    640.0   640.0   640.0  12.190476  10.448980  10.666666
4    768.0   768.0   768.0  13.405091  10.287628  14.503869
5    896.0   896.0   896.0  14.049280  13.380267  20.070399
6   1024.0  1024.0  1024.0  15.887515  12.264046  19.239927
7   1152.0  1152.0  1152.0  16.681475  15.633424  24.883201
8   1280.0  1280.0  1280.0  16.516129  15.340824  28.248276
9   1408.0  1408.0  1408.0  17.090206  14.774461  24.016635
10  1536.0  1536.0  1536.0  17.014154  15.624477  26.021647
11  1664.0  1664.0  1664.0  17.043394  15.073554  25.858942
12  1792.0  1792.0  1792.0  17.107190  16.171833  29.577431
13  1920.0  1920.0  1920.0  17.883570  15.762828  26.331430
14  2048.0  2048.0  2048.0  17.623127  17.032706  27.413751
15  2176.0  2176.0  2176.0  17.887688  16.686275  29.945905
16  2304.0  2304.0  2304.0  19.019006  17.933838  33.787654
17  2432.0  2432.0  2432.0  17.940270  17.288901  31.181425
18  2560.0  2560.0  2560.0  18.164080  17.075561  31.844508
19  2688.0  2688.0  2688.0  17.594183  16.703239  30.370742
20  2816.0  2816.0  2816.0  18.766871  18.089676  33.242537
21  2944.0  2944.0  2944.0  18.735350  17.855977  33.695763
22  3072.0  3072.0  3072.0  18.420008  17.766898  32.768000
23  3200.0  3200.0  3200.0  18.470418  17.704011  33.255391
24  3328.0  3328.0  3328.0  18.253370  17.710036  32.753092
25  3456.0  3456.0  3456.0  18.546485  17.793328  33.634362
26  3584.0  3584.0  3584.0  18.368824  17.833278  33.142423
27  3712.0  3712.0  3712.0  18.665424  17.938112  34.036574
28  3840.0  3840.0  3840.0  18.638578  18.076496  33.794348
29  3968.0  3968.0  3968.0  18.965486  18.190808  34.324595
30  4096.0  4096.0  4096.0  19.035276  18.365864  34.450135

It's an overall win, getting roughly a 85% speed-up on large sizes.

Note that the rounding is differs a little bit to the one implemented in CUTLASS. We could implement that rounding if we wanted though.

This is still a bit far from the 2x speed-ups announced by CUTLASS. To get close to those numbers, we should probably need to remove the stores to shared before ldmatrix.

@lezcano lezcano requested a review from ptillet as a code owner February 28, 2024 17:52
Copy link
Collaborator

@jlebar jlebar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is awesome.

python/triton/language/semantic.py Outdated Show resolved Hide resolved
@@ -62,6 +62,7 @@ class CUDAOptions:
ptx_version: int = None
enable_fp_fusion: bool = True
allow_fp8e4nv: bool = False
allow_tf32: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about the name of this option. Without the context in this patch, I'd think this means, do we allow implicitly upgrading fp32 dots to tf32, which is not what this controls.

Perhaps supports_tf32, with a comment saying that this indicates whether the hw supports tf32 but doesn't give us permission to use tf32 to silently, except where the result is "as if" we'd computed in fp32?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I wasn't sure about this name. I followed the convention of allow_fp8e4nv, but I agree that supports_tf32 is a much better name. Let me change it.

Copy link
Collaborator

@jlebar jlebar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@jlebar jlebar enabled auto-merge (squash) February 28, 2024 18:32
@lezcano lezcano disabled auto-merge February 28, 2024 18:39
Copy link
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is going to be a precision problem. Maybe we would need a separate control for it but we want the precision to match fp32

python/triton/language/semantic.py Outdated Show resolved Hide resolved
@sbodenstein
Copy link

Some comments:

  1. We should keep the ability to do IEEE standard FP32 computations (there are scientific ML applications where this may be important).
  2. From NVIDIA's talk about the precision, it seems this is only more precise for larger matrices (from their graph: the transition point is 200x200) using a particular definition of precision. So its not uniformly more precise, and decisions on implementation might need to be made based on matrix size.
  3. There exist many variants: CUTLASS has a 4xTF32 matmul that doesn't drop the smallest term in the expansion (and hence more precise than 3xTF32). In addition, the same trick can be done with BF16 (https://arxiv.org/abs/1904.06376) for 3xBF16 or 6xBF16. Seems like having a general API to support this type of pattern is the way to go, rather than baking a particular choice as FP32 dot.

@lezcano
Copy link
Contributor Author

lezcano commented Feb 29, 2024

A few updates:

  • After looking a bit more closely at the code with @alexsamardzic and @peterbell10, we found that we were able to squeeze some extra performance going up to 33TFPS. To do this, we need to implement this patch after the c += tl.dot(a,b) -> tl.dot(a,b,c) transformation has been done. This performance is pretty much 1/3 of that of tf32, which makes sense. Pipelining should take us close to the 40TFPS announced by CUTLASS
  • It is not kosher to implement optimisations in the frontend. Even less so those that are backend-specific.
  • For these two reasons, it makes much more sense to write this optimisation as a TTGIR optimisation pass. I hope to be able to put a PR tomorrow.

As for how to expose this. I agree that it'd be good to give the user the possibility to use the IEEE version. I can either add a new f32_backend: str flag (happy to bikeshed the name) exposing this. Now, I really think that even if the user can use the IEEE version, the default behaviour should be to use the 3xtf32 version. We already default to use_tf32=true for speed disregarding accuracy. I think we should also default to speed (but even accuracy!) in this case.

@ThomasRaoux
Copy link
Collaborator

after the c += tl.dot(a,b) -> tl.dot(a,b,c) transformation has been done.

tl.dot supports 3 arguments so you should be able to write it directly in the language. Would that allow you to do it as a library?

It is not kosher to implement optimisations in the frontend. Even less so those that are backend-specific.

It depends if we consider it to be a library level or a backend optimization.

For these two reasons, it makes much more sense to write this optimisation as a TTGIR optimisation pass. I hope to be able to put a PR tomorrow.

The downside of this strategy is that we need to add another attribute to language. As @sbodenstein pointed out there may be a bunch of other algorithm and we don't want to keep adding attributes to the language.

As for how to expose this. I agree that it'd be good to give the user the possibility to use the IEEE version. I can either add a new f32_backend: str flag (happy to bikeshed the name) exposing this. Now, I really think that even if the user can use the IEEE version, the default behaviour should be to use the 3xtf32 version. We already default to use_tf32=true for speed disregarding accuracy. I think we should also default to speed (but even accuracy!) in this case.

The current default behavior is tf32, are you suggesting changing it to 3xtf32? Or are you saying it should be default if use_tf32=False? In this case default behavior doesn't mean much and the best would be having an enum of precisions.

I would still think a library solution if possible would be nice.

@Jokeren
Copy link
Contributor

Jokeren commented Feb 29, 2024

For these two reasons, it makes much more sense to write this optimisation as a TTGIR optimisation pass. I hope to be able to put a PR tomorrow.

I have a different thought regarding this. I'm curious about what optimizations you'd like to implement?

Maybe we can do this:

  1. The backend can integrate optimizations for cases in which multiple dots exist in the loop (if I understand it correct) similar to your case.
  2. The 3XTF32 or 4xTF32 solution can be a library API.

@jlebar
Copy link
Collaborator

jlebar commented Feb 29, 2024

tl.dot supports 3 arguments so you should be able to write it directly in the language. Would that allow you to do it as a library?

I think the concern is that if the user writes c += core.dot(a, b), we need to first convert this into tl.dot(a, b, c) before we run the 3xtf32 optimization.

You could write the user code differently, but I think he's trying to make "reasonable user code" do the right thing. Which, given our goal of portability, makes sense to me.

The downside of this strategy is that we need to add another attribute to language.

We already have a tf32 attribute. It actually sounds like he wants to generalize this into f32_backend: str, rather than necessarily adding a new attribute. The generalized attribute would allow us to say that a dot runs in IEEE-fp32, tf32, 3xtf32, 4xtf32, or whatever other formats a backend supports.

OTOH if we don't do this, then it seems to me that every hardware vendor with its different dot precisions is going to want attributes on dot similar to the use_tf32 that we have today.

IOW it feels like this is a solution to the problem of attribute-creep?

@peterbell10
Copy link
Contributor

tl.dot supports 3 arguments so you should be able to write it directly in the language. Would that allow you to do it as a library?

I think the concern is that if the user writes c += core.dot(a, b), we need to first convert this into tl.dot(a, b, c) before we run the 3xtf32 optimization.

Exactly. If you write c += dot_3xtf32(a, b) where dot_3xtf32 is a library function then the first dot doesn't get optimized and you end up with an extra arith.addf and layout conversions which cause ~30% hit to perf.

If you expect the user to always write dot_3xtf32(a, b, c) then a library solution is fine.

@ptillet
Copy link
Collaborator

ptillet commented Feb 29, 2024

I personally think that replacing allow_tf32=True with an fp32_backend flag is a good change. It is more future-proof, as if TF32 stops existing (or the specs change) the allow_tf32 attribute will become awkward to maintain. Not to mention that TF32 is an nvidia-specific format, so even today it is weird to specify allow_tf32=True on non-nvidia hardware.

However, I do think the default for fp32_backend should be target.default_fp32_backend, which should be tf32 on nvidia. And of course each target should also have a supported_fp32_backend attribute so that users get an error if they try to use tf32 on AMD GPUs.

Leaving the perf considerations aside, doing it in user code with the current limitations of triton metaprogramming seems troublesome. That would force users to carry over constexpr flags that depend on the hardware they're targeting

@lezcano
Copy link
Contributor Author

lezcano commented Mar 1, 2024

I completely agree with Philippe's comment. I'll go ahead and implement that.

Yes, I meant the default behaviour for use_tf32=False. I do think that the default behaviour should be to do the matmul in TF32, in line with the rest of the language (cf. tl.sin). If we do the BC breaking change use_tf32 -> f32_backend, then this point is moot. Now, this point is still relevant if we want to keep both in the API for BC compat reasons. For now, I'll implement it in a BC compat way mapping use_tf32=False -> f32_backend="3xtf32" and throwing a deprecating warning, but I'm happy to make the BC breaking change if that's alright. Whatever you guys prefer cc @ptillet

@lezcano
Copy link
Contributor Author

lezcano commented Mar 5, 2024

Updated the PR. In particular:

  • This transformation is implemented as a TTGIR pass
  • Changed allow_tf32: bool into f32_backend. Removed allow_tf32 (bc-breaking) as otherwise the semantics with both args present are rather unclear.
  • Documented this flag + added per-device defaults for NVIDIA and AMD
  • Updated the speed-ups in the OP: We have a 85% speed-up.
  • Tested in test_dot

The PR may be easier to review commit-by-commit, given the amount of files it touches.

Copy link
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks good, added couple minor comments

lib/Dialect/TritonGPU/Transforms/FP32DotTC.cpp Outdated Show resolved Hide resolved
@@ -565,7 +565,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
TT_FpIntTensor:$a,
TT_FpIntTensor:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
StrAttr:$f32Backend,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: making it an enum would be a bit less error-prone. You can use I32EnumAttr for that

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use an enum, do we have to list all possible values in TritonOps.td? That doesn't work well for out-of-tree backends.

(I think it would be good to have the list of valid strings written down somewhere in the nvidia backend, though.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct the enum has to be in TritonOps.td, it still feels better than doing string checks? I believe we could decouple it using interfaces but that's probably an overkill to do it right now. Even if we have a lit of valid strings, it feels easy to make a typo when checkin the value.
I'd be fine with the possible constants be declared in a header as constexpr and used instead of literal strings

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct the enum has to be in TritonOps.td, it still feels better than doing string checks?

I just thought we had it as a goal to make it possible for people to write out-of-tree backends? If you think that's not relevant here, then definitely enums would be better.

I'd be fine with the possible constants be declared in a header as constexpr and used instead of literal strings

The problem is there's no way to enforce this, and people will use strings anyway. (And indeed the IR uses a string, so that's the place where we're most likely to have a typo.) So I think there's an argument for leaning into the strings in the C++ and accepting that we need to write tests. At least, that is my experience with this sort of thing from XLA, where it's used a lot.

Again I do think it needs to be documented in the backend which strings are acceptable. Right now I only see it in the frontend. I also think (orthogonal to all this) that we need a check during lowering that we know how to lower the string we got (i.e. check for invalid strings during lowering -- I don't see this anywhere) in order to catch typos in the IR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with the enum implementation, as we already use enums in quite a few places. Going from enums to any other representation in the future should be trivial tho.

include/triton/Dialect/Triton/IR/TritonOps.td Outdated Show resolved Hide resolved
@@ -565,7 +565,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
TT_FpIntTensor:$a,
TT_FpIntTensor:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
StrAttr:$f32Backend,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps "f32Backend" isn't the best name. We can use this field for the f16 (or whatever) backend too, right? No need to have multiple fields depending on the dtype.

Copy link
Contributor Author

@lezcano lezcano Mar 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed #3234 (comment). I'm happy to bikeshed if a better name / convention is proposed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$precision seems good to me.

@@ -565,7 +565,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
TT_FpIntTensor:$a,
TT_FpIntTensor:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
StrAttr:$f32Backend,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use an enum, do we have to list all possible values in TritonOps.td? That doesn't work well for out-of-tree backends.

(I think it would be good to have the list of valid strings written down somewhere in the nvidia backend, though.)

lib/Dialect/TritonGPU/Transforms/FP32DotTC.cpp Outdated Show resolved Hide resolved
lib/Dialect/TritonGPU/Transforms/FP32DotTC.cpp Outdated Show resolved Hide resolved
lib/Dialect/TritonGPU/Transforms/FP32DotTC.cpp Outdated Show resolved Hide resolved
lib/Dialect/TritonGPU/Transforms/FP32DotTC.cpp Outdated Show resolved Hide resolved
lib/Dialect/TritonGPU/Transforms/FP32DotTC.cpp Outdated Show resolved Hide resolved
lib/Dialect/TritonGPU/Transforms/FP32DotTC.cpp Outdated Show resolved Hide resolved
lib/Dialect/TritonGPU/Transforms/FP32DotTC.cpp Outdated Show resolved Hide resolved
@lezcano
Copy link
Contributor Author

lezcano commented Mar 6, 2024

Addressed the reviews. Should we bikeshed the name tho? @jlebar proposes "precision" for it to be more generic.

@ThomasRaoux
Copy link
Collaborator

ThomasRaoux commented Mar 6, 2024

Addressed the reviews. Should we bikeshed the name tho? @jlebar proposes "precision" for it to be more generic.

The name what proposed by @ptillet who is in vacation until Monday. Not sure we can discuss it without him.

Other than the name where I don't have a strong opinion I think this looks good.

Note that I did add an env var override for TF32 behavior (based on user request): #3290 which might be in your way. If you want to rebase with it that's great but feel free to break it and I'll add it back after your change

Copy link
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks good to me modulo the naming decision

@jlebar
Copy link
Collaborator

jlebar commented Mar 6, 2024

If you're not comfortable making an executive decision, then I propose we use the current name and I volunteer as tribute to run sed if I can convince Phil that precision is a better name.

tf32: use TC with tf32 ops.
tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp
ieee: don't use TC, implement dot in software.
If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably these descriptions should be on the enum now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment still applies?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figured it'd be better to leave them in the place where one would first look at when finding what this flag really does, but sure, I can move them

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example the enum is also used on dot_async.

(If you wanted a "see X" comment I'd be fine with that, although it's probably pretty obvious where to look.)

Copy link
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, let's wait @jlebar to approve as well

Copy link
Collaborator

@jlebar jlebar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks generally good to me, but I think we need to rename f32_backend to input_precision everywhere. That and a few other smallish changes.

OAI is going on vacation next week and there's some ongoing discussion about whether we want to merge this on the Friday before everyone is out or if that's a Bad Idea. In any case if we decide this should wait there shouldn't be as many merge conflicts as usual.

tf32: use TC with tf32 ops.
tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp
ieee: don't use TC, implement dot in software.
If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment still applies?

include/triton/Dialect/Triton/IR/TritonOps.td Outdated Show resolved Hide resolved
python/test/unit/language/test_core.py Outdated Show resolved Hide resolved
python/src/ir.cc Outdated Show resolved Hide resolved
@@ -1473,16 +1473,16 @@ def dot(input, other, acc=None, allow_tf32=None, max_num_imprecise_acc=None, out
:type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
:param other: The second tensor to be multiplied.
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
:param input_precision: How to exercise the Tenors cores for f32 x f32. If the device does not have Tensor Cores
or the inputs are not of dtype f32, this option is ignored.
:type other: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the one place where we describe in the user documentation what these fields mean. Seems like we should write a lot more and cite sources for further reading.

python/triton/language/semantic.py Outdated Show resolved Hide resolved
@@ -29,7 +29,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
%a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT>
%c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%c = tt.dot %a, %b, %prev_c {transA = false, transB = false, f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like Thomas does not have time to help out with this right now. We're OK merging as-is and figuring out how to fix this as a follow-up.

@jlebar
Copy link
Collaborator

jlebar commented Mar 15, 2024

The result of the discussion about waiting a week is: If we keep the allow_tf32 (i.e. don't break bw compat on the function in core.py) then we're OK merging this sooner, because it won't require changes to our internal workloads.

@jlebar
Copy link
Collaborator

jlebar commented Mar 15, 2024

(We don't have a full story on bw compatibility, but for us, having a week or two where the old and new APIs are usable makes integration easier. Because the next release will be 3.0, we're ok making bw-incompat changes in general.)

@lezcano
Copy link
Contributor Author

lezcano commented Mar 15, 2024

Alas, I'm leaving now (European time) and I don't have time to do those changes. If anyone wants to champion this through the finish line, feel free to push to my branch to get it merged.

@jlebar
Copy link
Collaborator

jlebar commented Mar 15, 2024

Sorry we've been so slow on this one. I'll probably be working some next week and I have approval from folks to merge this. You have been very patient and I don't want to keep you in rebase hell indefinitely.

@lezcano
Copy link
Contributor Author

lezcano commented Mar 16, 2024

fwiw, I'll be on PTO on Monday / Tuesday, will be back on Wed. If that's too late for you, we can work on mergin this the Monday after. No rush.

@lezcano lezcano force-pushed the dot branch 5 times, most recently from e614d5b to e3dde35 Compare March 21, 2024 16:59
@lezcano
Copy link
Contributor Author

lezcano commented Mar 21, 2024

@jlebar #3234 and the previous comment still hold. Will add those tomorrow.
Other than that, I rebased, fixed the lit issues and renamed everything to "Input Precision". Would you mind running the CI? I'll fix the comments and anything that comes up in CI tomorrow first thing.

@jlebar
Copy link
Collaborator

jlebar commented Mar 21, 2024

Just kicked CI for you.

Thanks again for pushing through with this one.

@jlebar jlebar merged commit 47a35b6 into triton-lang:main Mar 28, 2024
5 checks passed
@lezcano lezcano deleted the dot branch March 28, 2024 19:19
ptillet pushed a commit that referenced this pull request Apr 1, 2024
This PR implements the [3xTF32
trick](NVIDIA/cutlass#385) to make use of
the TCs on F32 tensors without sacrificing accuracy. This is
particularly relevant for PyTorch, as TF32 is off by default.

Benchmarks on A100 from `python/tutorials/03-matrix-multiplication.py`
run on `float32` data using `use_tf32=False`:
```
         M       N       K     cuBLAS     Triton    This PR
0    256.0   256.0   256.0   1.927529   1.092267   1.489455
1    384.0   384.0   384.0   5.026909   3.567484   3.686400
2    512.0   512.0   512.0   8.192000   6.553600   6.898527
3    640.0   640.0   640.0  12.190476  10.448980  10.666666
4    768.0   768.0   768.0  13.405091  10.287628  14.503869
5    896.0   896.0   896.0  14.049280  13.380267  20.070399
6   1024.0  1024.0  1024.0  15.887515  12.264046  19.239927
7   1152.0  1152.0  1152.0  16.681475  15.633424  24.883201
8   1280.0  1280.0  1280.0  16.516129  15.340824  28.248276
9   1408.0  1408.0  1408.0  17.090206  14.774461  24.016635
10  1536.0  1536.0  1536.0  17.014154  15.624477  26.021647
11  1664.0  1664.0  1664.0  17.043394  15.073554  25.858942
12  1792.0  1792.0  1792.0  17.107190  16.171833  29.577431
13  1920.0  1920.0  1920.0  17.883570  15.762828  26.331430
14  2048.0  2048.0  2048.0  17.623127  17.032706  27.413751
15  2176.0  2176.0  2176.0  17.887688  16.686275  29.945905
16  2304.0  2304.0  2304.0  19.019006  17.933838  33.787654
17  2432.0  2432.0  2432.0  17.940270  17.288901  31.181425
18  2560.0  2560.0  2560.0  18.164080  17.075561  31.844508
19  2688.0  2688.0  2688.0  17.594183  16.703239  30.370742
20  2816.0  2816.0  2816.0  18.766871  18.089676  33.242537
21  2944.0  2944.0  2944.0  18.735350  17.855977  33.695763
22  3072.0  3072.0  3072.0  18.420008  17.766898  32.768000
23  3200.0  3200.0  3200.0  18.470418  17.704011  33.255391
24  3328.0  3328.0  3328.0  18.253370  17.710036  32.753092
25  3456.0  3456.0  3456.0  18.546485  17.793328  33.634362
26  3584.0  3584.0  3584.0  18.368824  17.833278  33.142423
27  3712.0  3712.0  3712.0  18.665424  17.938112  34.036574
28  3840.0  3840.0  3840.0  18.638578  18.076496  33.794348
29  3968.0  3968.0  3968.0  18.965486  18.190808  34.324595
30  4096.0  4096.0  4096.0  19.035276  18.365864  34.450135
```

It's an overall win,  getting roughly a 85% speed-up on large sizes.

Note that the rounding is differs a little bit to the one [implemented
in
CUTLASS](https://github.com/NVIDIA/cutlass/blob/a8f2c80db0564c74f4efccac71993b971dfc448b/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h#L99-L100).
We could implement that rounding if we wanted though.

This is still a bit far from the 2x speed-ups announced by CUTLASS. To
get close to those numbers, we should probably need to remove the stores
to shared before `ldmatrix`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants