
[TOPI] Add proper scheduling for dense on CUDA #3923

Merged — 6 commits merged into apache:master on Sep 19, 2019

Conversation

comaniac (Contributor) commented Sep 9, 2019

@icemelon9 please review this PR. It adds proper scheduling for the dense op on CUDA, since the original schedule was too basic to achieve reasonable performance.

The added schedule is adapted from topi/recipe/gemm/cuda_gemm_square.py and achieves high performance (6 TFlop/s for a 2048x2048 dense matmul after AutoTVM tuning). For small-batch (<64) dense, we still rely on the original schedule but expose one parameter (tile_k) for AutoTVM to tune (~370 GFlop/s for batch-size-1 dense).
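For context, here is a rough sketch of how a tile_k split can be exposed as an AutoTVM knob on top of the rfactor-based small-batch schedule. This is not the code in this PR; the helper name and the simplification that the dense result is the final output op are assumptions.

```python
# Hedged sketch only (not the exact code in this PR): expose the reduction
# split factor "tile_k" to AutoTVM in an rfactor-style small-batch dense schedule.
# Assumes the dense result C is itself the final output op (no fused epilogue).
import tvm

def _schedule_dense_small_batch_sketch(cfg, s, C):
    """cfg: AutoTVM config space, s: schedule, C: dense output (batch x out_dim)."""
    k = C.op.reduce_axis[0]                        # reduction over in_dim
    cfg.define_split("tile_k", k, num_outputs=2)   # tunable split instead of a fixed factor
    _, kf = cfg["tile_k"].apply(s, C, k)
    CF = s.rfactor(C, kf)                          # partial sums along the inner split

    # One CUDA block per output element; threads cooperate on the reduction.
    s[C].bind(s[C].op.axis[0], tvm.thread_axis("blockIdx.y"))
    s[C].bind(s[C].op.axis[1], tvm.thread_axis("blockIdx.x"))

    tx = s[C].op.reduce_axis[0]                    # cross-thread reduction axis
    thread_x = tvm.thread_axis("threadIdx.x")
    s[C].bind(tx, thread_x)
    s[CF].compute_at(s[C], tx)
    s[C].set_store_predicate(thread_x.var.equal(0))  # only thread 0 writes the result
```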

One reason for having separate schedules for different batch sizes is that I encountered invalid CUDA kernel errors when applying the high-performance schedule to small batches. I suspect this is caused by invalid splits, but I am not quite sure. Comments and suggestions for improvement are welcome.
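One possible way to express that split is simply to dispatch on the batch dimension. The helper names below are hypothetical, not taken from the PR:

```python
# Hypothetical dispatch sketch (helper names are made up, not from the PR):
# choose the schedule variant from the batch dimension so the blocked
# large-batch tiling never gets applied to tiny batches.
from topi.util import get_const_tuple

def _schedule_dense_sketch(cfg, s, C):
    batch, _ = get_const_tuple(C.shape)
    if batch < 64:
        _schedule_dense_small_batch_sketch(cfg, s, C)   # rfactor-based, tile_k knob
    else:
        _schedule_dense_large_batch_sketch(cfg, s, C)   # blocked GEMM-style tiling
```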

comaniac (Contributor, Author) commented Sep 9, 2019

@icemelon9 @vinx13 @Huyuwei could you help review this PR?

vinx13 (Member) commented Sep 9, 2019

@comaniac please look into the CI error.

icemelon (Member) commented

Could you also add a fallback config for AutoTVM?

comaniac (Contributor, Author) commented

@vinx13 It seems an NNVM test case failed. Will fix it soon.
@icemelon9 Will do that tomorrow.

Thanks!

comaniac (Contributor, Author) commented

The above commit:

  1. Fixed the failing unit test, which was caused by a too-small shape (=1) along the reduce axis.
  2. Added a fallback config for both cases (see the sketch below).
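For reference, a minimal sketch of what such a fallback can look like, assuming the tile_k knob from the earlier sketch; this is not necessarily the committed code:

```python
# Hedged sketch: seed a default tile_k when no tuned config is available,
# guarding against reduce axes shorter than the default split (e.g. in_dim == 1).
from tvm.autotvm.task.space import SplitEntity

def _fallback_tile_k(cfg, in_dim):
    if cfg.is_fallback:
        # -1 lets AutoTVM infer the outer extent from in_dim.
        cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [-1, in_dim])
```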

icemelon (Member) left a comment

LGTM

comaniac requested a review from Huyuwei on September 18, 2019
comaniac closed this on Sep 19, 2019
comaniac reopened this on Sep 19, 2019
comaniac (Contributor, Author) commented

The CI error is unrelated to this PR and the tests pass locally. Re-running without changes.

comaniac closed this on Sep 19, 2019
comaniac reopened this on Sep 19, 2019
Huyuwei (Contributor) commented Sep 19, 2019

@comaniac This is merged, thanks!

Huyuwei merged commit bec08fe into apache:master on Sep 19, 2019
comaniac deleted the dense_cuda_schedule branch on September 19, 2019
wweic pushed a commit to wweic/tvm that referenced this pull request Sep 30, 2019
* add proper scheduling for dense on CUDA

* add fallback config and fix unit test

* fix corner cases

* refactoring

* fix bias and add testcase

* let fusion happen
wweic pushed a commit to neo-ai/tvm that referenced this pull request Oct 1, 2019
* add proper scheduling for dense on CUDA

* add fallback config and fix unit test

* fix corner cases

* refactoring

* fix bias and add testcase

* let fusion happen