[RFC] Auto TensorCore CodeGen #4105
Comments
Awesome solution! Just curious: for shapes that are worse than cuDNN/cuBLAS, what kind of tuning is used?
We haven't spent much effort on performance tuning yet. For cases with bad performance we plan to do profiling first to figure out the causes. One possible way of optimization is to manually modify the generated code. If the manual optimization really works and is general enough, we can try to implement it in the schedule.
Good point! We have had some internal discussions about whether we need to automatically search the schedule space over both TensorCore and non-TensorCore kernels, since the TensorCore implementation may not beat the non-TensorCore version for every shape. This is one of the planned features, and further comments and inputs are welcome. One possible solution is to expose TensorCore as another schedule configuration knob and let the auto-tuner decide whether to turn it on. Another potential solution is to decide in the IR pass, with heuristics, whether a certain shape is likely to perform better with TensorCore. There are pros and cons to both solutions. With the former, the tuning space becomes somewhat larger. With the latter, the tuning space stays almost the same because the decision is made internally in the IR pass, but we become dependent on the accuracy of the heuristics; although the hardware nature of TensorCore makes it fairly clear whether a shape is TensorCore-friendly, there is still a possibility that we choose a low-performance kernel.
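For instance, a rough sketch of how the knob-based option could look inside an AutoTVM schedule template (the knob name "use_tensor_core" is illustrative, not an existing AutoTVM option):

```python
from tvm import autotvm

# Inside an AutoTVM schedule template (sketch only; "use_tensor_core" is an
# illustrative knob name, not an existing AutoTVM option).
cfg = autotvm.get_config()
cfg.define_knob("use_tensor_core", [0, 1])

if cfg["use_tensor_core"].val:
    # apply the TensorCore-friendly schedule (warp-level tiling, wmma, ...)
    pass
else:
    # fall back to the regular CUDA schedule
    pass
```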
Thanks for the RFC, also cross-linking to #4052.
Non-standard buffer allocation
We are moving toward using special memory scopes to annotate special memory (e.g. mma). The use of new_expr is something we plan to deprecate. Here is an alternative solution: introduce a new scope for the special memory needed for lowering, so that a special rule can be used to generate the corresponding memory. Of course there could be additional hints needed to lower the allocation code; you can likely embed that additional information in a special AttrStmt outside the allocation scope.
Place of Pattern Matching
Right now, from reading the RFC, it seems the early pattern matching is done before flattening and depends on the compute structure. I wonder if we could de-couple this: with some annotations, run some of the rewriting after storage flatten. Of course the low-level code does not enjoy the benefit of multi-dimensional indices, but the access pattern can still be detected by DetectLinearEquation. One possible limitation I see in the current approach is whether we can support operations like conv2d, as we would need to explicitly express the compute in this form (which is fine for now).
Complement and Combine with Tensor Intrinsics based TensorCore support
It would be great to hear more thoughts from @Hzfengsy @minminsun about how we can combine the tensor-intrinsics-based approach with the more automatic pattern-detector one. Our philosophy has always been to enable manual scheduling options that give us a way to specify the search space, then build automation on top. This lets us take a spectrum of approaches: use the more manual one if necessary, and build more diverse automated solutions. Our eventual goal would still be to unify all tensorization support under tensor intrinsics and build automation on top. One idea would be to still declare the lowering rules via tensor intrinsics, but reuse the pattern-matching techniques in this RFC to rewrite to hints that apply the tensor intrinsics. This way we can organically combine the two ideas.
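To illustrate the DetectLinearEquation point above, a minimal sketch (using the Python binding tvm.arith.detect_linear_equation; illustrative, not code from the RFC) of recovering strides from a flattened access index:

```python
import tvm
from tvm import te, arith

i = te.var("i")
j = te.var("j")

# After storage flatten, a 2-D access like A[i][j] on a 1024-wide buffer
# becomes a 1-D index such as i * 1024 + j. DetectLinearEquation recovers
# the per-variable coefficients (the strides) plus a constant base term.
coeffs = arith.detect_linear_equation(i * 1024 + j, [i, j])
print(coeffs)  # expected: stride 1024 for i, stride 1 for j, and base 0
```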
This is really impressive work, congrats!
Our fp16 TensorCore kernels are tuned on V100 with CUDA toolkit 9.0 and driver 396.44. The int8 TensorCore kernels are tuned on T4 with CUDA toolkit 10.1 and driver 418.39. On different GPUs, the performance of tuned kernels can be different.
Thanks @tqchen and @Hzfengsy for your valuable feedback. We are trying out some of your suggestions and will have further discussions with you after we have made some evaluations and trials.
I doubt that "using TensorCores will decrease precision" if the inputs are already in fp16 or int8. We did try to add an "enable_tensor_core" option in tvm.build_config, but it seems that build_config can't be passed through AutoTVM building. Any suggestion on where to add this option is welcome. But I think eventually we will not need this option, once the implementation is proven to be robust enough. For example, in TensorFlow, MatMul/Conv on fp16 data uses the TensorCore kernels of cuBLAS/cuDNN by default.
Thanks for correcting my understanding. So it seems the TensorCore operation is more like c = float(a)*float(b) + c rather than c = float(a*b) + c.
We had a meeting with @Hzfengsy today and discussed the differences and similarities of our solutions. They differ in the front-end: our solution tries to be as transparent as possible to make it easy to use, while #4095 provides more controllability to the user (schedule developer). They actually target different users, so we think both solutions can co-exist. But we both agreed that the intrinsics in the back-end should be combined. As to fragment allocation, we are OK with changing from new_expr to introducing new scopes, but currently the new scope introduced in #4052 is not enough for the codegen of fragment allocation if it is extended to support different warp tile sizes and data layouts (col_major/row_major). One possible but not so elegant solution we proposed is to extend the scopes to also include tile size and data layout. @Hzfengsy is also trying to figure out a solution here. We will have more discussions on this.
I have a proposal to minimize the invasiveness in TVM and also fundamentally support TensorCore in TVM. It sits between the methodologies of #4052 and this RFC.
Sorry for the late reply. We were occupied with refactoring our implementation to combine with #4052. Thanks a lot for your proposal. Generating PTX or even SASS assembly is really an interesting topic and we may investigate and discuss it later. As to the TensorCore CodeGen, I think the data structure may not be the only pain point. The root cause is the programming model of TensorCore, in which the threads inside a warp are no longer individual threads, and high-level information such as matrix_a/b, row/col_major and the strides of a buffer is required in low-level operations. So I guess generating PTX directly may not relieve these pains. @Hzfengsy, what do you think?
We propose a solution for TensorCore CodeGen with significant transparency, flexibility and usability. In this solution, the algorithm description and schedule for TensorCore CodeGen are no different from those of normal CUDA CodeGen. All the information needed by the wmma API, such as matrix_a/matrix_b/accumulator, row_major/col_major, warp tile size and so on, is automatically derived from the AST. Of course, not every algorithm and schedule is suitable for TensorCore computation. This solution checks for that and falls back to normal CUDA CodeGen for those that are not qualified for TensorCore CodeGen.
In this solution, 3 IRVisitors and 1 IRMutator are added.
BodyVisitor, which is called by ScheduleAnalyser, visits the body stmt of the original ComputeOp to get the access indices of the input matrices if the computation is recognized as a matrix multiply. ScheduleAnalyser compares the access indices with the axis/reduce_axis of the ComputeOp to figure out whether an input matrix is matrix_a or matrix_b, row_major or col_major.
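For illustration, a minimal matmul compute (a sketch, not the exact code from this RFC) showing the access indices that ScheduleAnalyser compares against the axes:

```python
import tvm
from tvm import te

n, m, l = 1024, 1024, 1024
A = te.placeholder((n, l), name="A", dtype="float16")
B = te.placeholder((l, m), name="B", dtype="float16")
k = te.reduce_axis((0, l), name="k")

# C's output axes are (i, j) and the reduce axis is k. A shares C's row
# axis i, so it plays the role of matrix_a; accessed as A[i, k] it is
# row_major. B shares C's column axis j, so it is matrix_b; accessed as
# B[k, j] it is row_major. Swapping an operand's index order (e.g. A[k, i])
# would make that operand col_major.
C = te.compute(
    (n, m),
    lambda i, j: te.sum(
        A[i, k].astype("float32") * B[k, j].astype("float32"), axis=k
    ),
    name="C",
)
```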
MMAMatcher does pattern matching on the AST stmt. The pattern it tries to find is the following:

If matched, a, b and c are recorded as fragment registers, which are important inputs to the next visitor.
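For reference, the same pattern can be written out with tvm.tir nodes (a standalone sketch; in the real matcher a, b and c are buffer accesses rather than free variables):

```python
from tvm import tir

# Stand-ins for the fp16 operands and the fp32 accumulator; in the real
# matcher these are buffer accesses, not free variables.
a = tir.Var("a", "float16")
b = tir.Var("b", "float16")
c = tir.Var("c", "float32")

# The tree shape MMAMatcher keys on: Add(c, Mul(Cast(a), Cast(b))).
mma_like = c + tir.Cast("float32", a) * tir.Cast("float32", b)
print(mma_like)
```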
BufferAnalyser, the last visitor, gathers the remaining information needed by TensorCoreIRMutator: the strides of the source/destination buffers for the wmma load/store matrix operations, the warp tile size for fragment allocation (also used to check whether the schedule is qualified for TensorCore), the loops that need to be rescaled once normal load/store and compute operations are replaced by TensorCore operations, and so on.
TensorCoreIRMutator mutates the AST stmt for TensorCore CodeGen. The subtree matched by MMAMatcher is replaced with an "mma_sync" extern call. Loads/stores of fragments are replaced with "load/store_matrix_sync" extern calls, with the thread index unified within a warp. Thread index unification, i.e. changing the index of every thread to that of the first thread of its warp, is done by ThreadIdxMutator on the subtree.
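As a plain-Python sketch of the index arithmetic (assuming 32-thread warps and that threadIdx.x enumerates the threads so that each consecutive group of 32 forms a warp):

```python
WARP_SIZE = 32  # assumption: one warp = 32 threads

def unify_within_warp(tx):
    # Rewrite a thread index so that every lane uses the index of the first
    # thread of its warp; the wmma load/store/mma_sync calls then compute
    # identical addresses and are effectively issued once per warp.
    return tx // WARP_SIZE * WARP_SIZE

# Lanes 32..63 (the second warp) all map to thread 32.
assert all(unify_within_warp(t) == 32 for t in range(32, 64))
```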
The TensorCore IR passes are applied before StorageFlatten because they need the stride/shape and the index of specific dimensions before these are flattened into one. Before StorageFlatten, an allocation is represented by the Realize IR node, which has no new_expr member as the Allocate IR node has. So we added one to the Realize IR node to carry the expression for fragment allocation and pass it on to the Allocate IR node. We noticed the comment about deprecating new_expr when merging with the latest TVM codebase. We would like to ask for a reconsideration of this decision, because it is really useful for some non-standard buffer allocations.
This solution is evaluated on a sample schedule of Matmul, which is based on AutoTVM. It supports fp16 and int8 data types, and three kinds of data layouts: NN, NT, TN.
On some model layers, we have already achieved better performance than CUBLAS/CUDNN:
[performance chart] FP16 on V100, CUDA 9.0, Driver 396.44
[performance chart] Int8 on T4, CUDA 10.1, Driver 418.39
There are also many shapes on which CUBLAS/CUDNN is much better. The performance tuning is still ongoing.
Thanks!
-- Minmin Sun, Lanbo Li, Chenfan Jia and Jun Yang of Alibaba PAI team