Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[1.x][FEATURE] CUDA graphs support #19142

Merged
merged 15 commits into from
Sep 19, 2020
Merged

Conversation

ptrendx
Copy link
Member

@ptrendx ptrendx commented Sep 14, 2020

Description

CUDA graphs is a feature of CUDA 10, which enables lowering CPU overhead by bundling multiple kernel launches together.

The main limitation of CUDA graphs is that they do require the graph to be static - no parameters to the kernels can change, otherwise the graph needs to be recaptured. That is why this feature is currently only enabled for the symbolic models and Gluon models hybridized with hybridize(static_alloc=True, static_shape=True).

The feature is not enabled by default and requires environment variable MXNET_ENABLE_CUDA_GRAPHS to be set. In order to not capture the operations, the execution of which may change during the course of the training job, stateful operators and operators relying on resources other than workspace are not included in the graph.

Since the feature lowers the CPU overhead of the execution, the impact is most visible in the inference or scale-out workloads with small batch size. Let us consider following script comparing fp16, batch size 1 inference of RN50v2 model from GluonCV with and without graphs:

import mxnet as mx
import gluoncv as gcv
from gluoncv.model_zoo import get_model
import time
import os

net = get_model('ResNet50_V2')
net2 = get_model('ResNet50_V2')

net.initialize(ctx=mx.gpu())
net2.initialize(ctx=mx.gpu())

net.cast('float16')
net2.cast('float16')

net.hybridize(static_alloc=True, static_shape=True)
net2.hybridize(static_alloc=True, static_shape=True)

img = mx.random.uniform(shape=(1, 3, 224, 224), ctx=mx.gpu(), dtype='float16')

os.environ["MXNET_ENABLE_CUDA_GRAPHS"] = "0"

for _ in range(10):
    o = net(img)

mx.nd.waitall()

s = time.time()
for _ in range(1000):
    o = net(img)

mx.nd.waitall()
e = time.time()
print("No graphs: ", e - s)
mx.nd.waitall()

os.environ["MXNET_ENABLE_CUDA_GRAPHS"] = "1"

for _ in range(10):
    o = net2(img)

mx.nd.waitall()

s = time.time()
for _ in range(1000):
    o = net2(img)

mx.nd.waitall()
e = time.time()
print("With graphs: ", e - s)

The result obtained on V100 16GB:

[20:33:46] ../src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
No graphs:  2.8153152465820312
With graphs:  2.3230245113372803

so over 17% increase in performance. The same script but with 128 batch size gives much smaller improvement:

[20:38:34] ../src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
No graphs:  47.31952977180481
With graphs:  47.152849197387695

so 0.3%.

Checklist

Essentials

  • PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

@ptrendx ptrendx added the pr-work-in-progress PR is still work in progress label Sep 14, 2020
@mxnet-bot
Copy link

Hey @ptrendx , Thanks for submitting the PR
All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands:

  • To trigger all jobs: @mxnet-bot run ci [all]
  • To trigger specific jobs: @mxnet-bot run ci [job1, job2]

CI supported jobs: [clang, centos-cpu, windows-gpu, windows-cpu, unix-cpu, website, miscellaneous, sanity, edge, unix-gpu, centos-gpu]


Note:
Only following 3 categories can trigger CI :PR Author, MXNet Committer, Jenkins Admin.
All CI tests must pass before the PR can be merged.

@lanking520 lanking520 added pr-awaiting-testing PR is reviewed and waiting CI build and test pr-awaiting-review PR is waiting for code review pr-work-in-progress PR is still work in progress and removed pr-work-in-progress PR is still work in progress pr-awaiting-testing PR is reviewed and waiting CI build and test pr-awaiting-review PR is waiting for code review labels Sep 14, 2020
@lanking520 lanking520 added pr-awaiting-testing PR is reviewed and waiting CI build and test and removed pr-work-in-progress PR is still work in progress labels Sep 14, 2020
@ptrendx
Copy link
Member Author

ptrendx commented Sep 14, 2020

Forgot to add in the description that most of the work here was done by @DickJC123.

.set_attr<FCompute>("FCompute<gpu>", EigvalsOpForward<gpu>);

#if MXNET_USE_CUSOLVER == 1

NNVM_REGISTER_OP(_npi_eigvalsh)
.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

istead of setting false everywhere, can we just check if hasAttr("FIsCUDAGraphsCompatible") so that by default its false?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the "everywhere" here is actually way less than if we went the other way around (only a bunch of operators with FCompute has synchronization that is not allowed under graphs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the long term goal here is actually to make those excluded operators compatible too.

static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
const auto& attrs = exec->attrs;
if (attrs.op != nullptr) {
const auto f = fgraphcompatible.get(attrs.op, nullptr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand - I do want to have the function call per Op here, not just a simple true/false.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whats the point of registering a lambda function that just returns false?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess look here: https://github.com/apache/incubator-mxnet/pull/19142/files#diff-789523bf443903e74acfa010a5d6b572R33-R37 - this is for dropout, which uses random resource for training (and thus is excluded), but is just a passthrough for inference (and so we want to include it there).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ya that one is fine, but "_npi_eig" is always false...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But then I would need to add that function for almost all operators for it to return true...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We went with the default to be include the operator, but instead have the functionality itself be non-default.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the majority of ops supported to be included in cuda graphs? If a new user comes in to write an op, do they need to be aware of how to handle cuda graph support?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the first question: yes - nearly all of the FCompute functions that do not use random resource are compatible with graphs.
For the second question - for FCompute operators to not be compatible you need to do synchronization inside the operator - either via stream synchronize or allocation. You generally do not want to do either one of those as it really hurts performance (and the operators that in this PR I marked incompatible in this PR do just that). If you just launch a kernel (or multiple), which is the case for vast majority of the operators, then you are good and do not even need to think about graphs - it will just work.

I'm still exploring the ways of automatically testing newly added operators in order for the feature to be able to be on by default, but I do not consider this the scope of this PR, as v1.x branch is not really supposed to get many more operators (I will do that in the PR to master). Generally this would involve testing operators with hybridize(static_alloc=True, static_shape=True) (which generally should be tested much more as right now testing of this functionality is really limited, even though it is widely used).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should be ok since its disabled by default, and only enabled explicitly by setting the env var.

@lanking520 lanking520 added pr-awaiting-review PR is waiting for code review pr-awaiting-testing PR is reviewed and waiting CI build and test pr-work-in-progress PR is still work in progress and removed pr-awaiting-testing PR is reviewed and waiting CI build and test pr-awaiting-review PR is waiting for code review labels Sep 14, 2020
@ptrendx ptrendx added pr-awaiting-review PR is waiting for code review and removed pr-work-in-progress PR is still work in progress labels Sep 18, 2020
Copy link
Contributor

@DickJC123 DickJC123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @ptrendx mentioned, I coded much of this facility, as embodied in the cuda_graphs.h file. I have reviewed that file carefully and found it to be true to the version we have been using in NVIDIA's container version of MXNET. Thanks @ptrendx for adding the FIsCUDAGraphsCompatible attribute, as is needed for a robust CUDA Graphs support, and other improvements.

LGTM.

@samskalicky samskalicky merged commit 0fce381 into apache:v1.x Sep 19, 2020
DickJC123 pushed a commit to DickJC123/mxnet that referenced this pull request Jun 1, 2021
* Initial cherry-pick

* Store NodeAttrs in OpExecutor

* Do not allow stateful operations in CUDA graphs and provide mechanism
for marking ops as safe

* Guard against using ops with synchronization

* Cleaning

* Properly guard graphs

* Limit graphs to CUDA 10.2+

* Fix the compilation when graphs are not available

* Guarding the libcuda.so usage behind RTC compilation flag

* Document the env variables

* Add test

* Fix the test

* Use with_environment
chinakook pushed a commit to chinakook/mxnet that referenced this pull request Aug 4, 2021
* Initial cherry-pick

* Store NodeAttrs in OpExecutor

* Do not allow stateful operations in CUDA graphs and provide mechanism
for marking ops as safe

* Guard against using ops with synchronization

* Cleaning

* Properly guard graphs

* Limit graphs to CUDA 10.2+

* Fix the compilation when graphs are not available

* Guarding the libcuda.so usage behind RTC compilation flag

* Document the env variables

* Add test

* Fix the test

* Use with_environment
chinakook pushed a commit to chinakook/mxnet that referenced this pull request Aug 8, 2021
* Initial cherry-pick

* Store NodeAttrs in OpExecutor

* Do not allow stateful operations in CUDA graphs and provide mechanism
for marking ops as safe

* Guard against using ops with synchronization

* Cleaning

* Properly guard graphs

* Limit graphs to CUDA 10.2+

* Fix the compilation when graphs are not available

* Guarding the libcuda.so usage behind RTC compilation flag

* Document the env variables

* Add test

* Fix the test

* Use with_environment
DickJC123 pushed a commit to DickJC123/mxnet that referenced this pull request Feb 15, 2022
* Initial cherry-pick

* Store NodeAttrs in OpExecutor

* Do not allow stateful operations in CUDA graphs and provide mechanism
for marking ops as safe

* Guard against using ops with synchronization

* Cleaning

* Properly guard graphs

* Limit graphs to CUDA 10.2+

* Fix the compilation when graphs are not available

* Guarding the libcuda.so usage behind RTC compilation flag

* Document the env variables

* Add test

* Fix the test

* Use with_environment
DickJC123 added a commit that referenced this pull request Mar 18, 2022
* [1.x][FEATURE] CUDA graphs support (#19142)

* Initial cherry-pick

* Store NodeAttrs in OpExecutor

* Do not allow stateful operations in CUDA graphs and provide mechanism
for marking ops as safe

* Guard against using ops with synchronization

* Cleaning

* Properly guard graphs

* Limit graphs to CUDA 10.2+

* Fix the compilation when graphs are not available

* Guarding the libcuda.so usage behind RTC compilation flag

* Document the env variables

* Add test

* Fix the test

* Use with_environment

* Fix compile and test_cuda_graphs

* Fix lint

* Mark more ops as not CUDA Graphs compatible

* Mark some linalg ops as not CUDA Graphs compatible

* Marked 2 ops CUDA Graphs incompatible due to cpu->gpu copy

* Mark cuDNN Dropout as fully CUDA Graphs compatible.  Reenable tests.

* clang-tidy fixes

* More clang-tidy fixes

* Avoid CUDA_CALL(e): improper macro expansion

* Add compile guard to Dropout's FIsCUDAGraphsCompatible def

* Temporarily add '-s' to pytest serial tests

* Fix DropoutOp.dropout_passthrough_ handling for CUDA Graphs

* Adapt test_gluon_gpu.py::test_cuda_graphs for gluon2.0

* Create CUDA Graph 'dot' files if MXNET_CUDA_GRAPHS_DBG_FILE=<file_prefix>

* Fix clang-tidy

* Fix more clang-tidy

* Skip test_np_standard_binary_funcs test of 0-dim array broadcast

* Improve test_rnn_layers_fp{16,32} invocation

* Run test_rnn_layers_fp32 only when cuDNN is present

* Fix potential out-of-bounds write in count_sketch.cu

* Add temp output to debug centos crash

* Mark InstanceNorm and LeakyRELU as not CUDA Graphs compatible

* Ops calling FStatefulCompute* are not CUDA Graphs compatible by default

* Fix clang-tidy

* Revert "Add temp output to debug centos crash"

This reverts commit e013a85.

* Quiet 'unused variable' compilation warning

* Trigger CI

* Check of FCreateOpState removed given new check for FStatefulCompute*

* Revert "Temporarily add '-s' to pytest serial tests"

This reverts commit 5a2f847.

Co-authored-by: Przemyslaw Tredak <[email protected]>
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
pr-awaiting-review PR is waiting for code review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants