[MXNET-105] Fix CuDNN performance after code refactor #10116

Merged
merged 25 commits into apache:master from zheng-da:fix_cudnn_perf
Mar 22, 2018

Conversation

zheng-da
Contributor

@zheng-da zheng-da commented Mar 14, 2018

Description

This PR tries to fix the performance degradation reported in #9874.

We observed about a 4% performance decrease. Multiple factors contribute to it:

  • The first one is that the refactored code passes more arrays to the backward of BatchNorm.
  • The second one is that the refactored code needs to reinitialize the CuDNN states in every forward and backward.
  • The third one is that the refactored code leads to more memory allocation (e.g., creating std::vector) in forward and backward. (However, the test shows that this doesn't cause much of the performance decrease.)

This PR tries to reduce these overheads.

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

@zheng-da zheng-da requested a review from cjolivier01 as a code owner March 14, 2018 20:33
for (uint32_t i = 0; i < out_data.size(); ++i) {
out_data[i] = nnvm::NodeEntry{n, i, 0};
}
std::vector<nnvm::NodeEntry> heads;
Member

Please use reserve()

Contributor Author

This code runs to build the computation graph. It only runs once. Do we still need to call reserve()?

Member

yes, please
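
For reference, a minimal sketch of what the reviewer is asking for, assuming out_data holds one entry per output of node n and that heads is filled by push_back afterwards (both assumptions, not the exact PR code):

// Sketch only: reserve capacity up front so the vectors never reallocate
// while being filled.
std::vector<nnvm::NodeEntry> out_data;
out_data.reserve(n->num_outputs());
for (uint32_t i = 0; i < n->num_outputs(); ++i) {
  out_data.emplace_back(nnvm::NodeEntry{n, i, 0});
}

std::vector<nnvm::NodeEntry> heads;
heads.reserve(out_data.size());  // adjust to the number of entries pushed later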

// add all the auxiliary data
//for (uint32_t i = 0; i < prop.aux_states.size(); ++i) {
// inputs.emplace_back(ptr->inputs[i + prop.arguments.size()]);
//}
Member

?

std::vector<TBlob> in_data(3);
in_data[batchnorm::kData] = inputs[3];
in_data[batchnorm::kGamma] = inputs[4];
std::vector<TBlob> aux_states(2);
Member

What happens to aux states (running mean and variance)?

Contributor Author

The CuDNN version doesn't need aux_states, but the CUDA version does, so aux_states is set properly to run the CUDA code.
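
To illustrate the reply (a rough sketch; the index positions follow the quoted snippets, but the exact layout in the PR may differ): the fused backward receives one flat inputs array, and only the non-CuDNN kernels read the running statistics through aux_states, so those slots are filled even though the CuDNN path ignores them.

// Illustrative sketch: split the flat input list of BatchNorm backward into
// the pieces the legacy CUDA/CPU kernel expects.
std::vector<TBlob> in_data(3);
in_data[batchnorm::kData]  = inputs[3];   // x
in_data[batchnorm::kGamma] = inputs[4];   // gamma

std::vector<TBlob> aux_states(2);
aux_states[batchnorm::kMovingMean] = inputs[6];  // running mean
aux_states[batchnorm::kMovingVar]  = inputs[7];  // running variance
// The CuDNN kernel keeps its own saved mean/inv-variance and never reads
// aux_states; the plain CUDA/CPU kernel does, hence the assignments above.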

Member

So, this fix only improves the CuDNN operator? Wouldn't we expect the other non-CuDNN operators to also be slower by the same amount?

@cjolivier01
Member

This needs a JIRA ticket

inputs.begin() + out_data_start);
std::vector<NDArray> out_data(inputs.begin() + out_data_start, inputs.end());
std::vector<NDArray> in_grad(outputs.begin(), outputs.begin() + 3);
static thread_local std::vector<NDArray> out_grad(1);
Member

How does thread_local help here?

Member

@cjolivier01 cjolivier01 Mar 14, 2018

Won't these hold a reference to the NDArray's Chunk data indefinitely?

Contributor Author

Here I'm trying to avoid memory allocation for std::vector.

But you are right, it could potentially cause a memory leak.
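
A small, hypothetical sketch of the concern: a function-local static thread_local vector outlives each call, so any NDArray left in it keeps its underlying Chunk storage alive until the slot is overwritten (or forever, if the thread stops calling the operator).

#include <mxnet/ndarray.h>
#include <vector>

void BackwardSketch(const std::vector<mxnet::NDArray>& inputs) {
  // Reuses the vector across calls to avoid re-allocation, but the NDArray
  // stored here holds a reference to its data between calls.
  static thread_local std::vector<mxnet::NDArray> out_grad(1);
  out_grad[0] = inputs[0];

  // ... run the backward kernel with out_grad ...

  // One way to avoid pinning memory: drop the reference before returning.
  out_grad[0] = mxnet::NDArray();
}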

@zheng-da zheng-da changed the title Fix CuDNN performance after code refactor [MXNET-105] Fix CuDNN performance after code refactor Mar 14, 2018
@zheng-da zheng-da changed the title [MXNET-105] Fix CuDNN performance after code refactor [MXNET-105][WIP] Fix CuDNN performance after code refactor Mar 14, 2018
inputs.begin() + out_data_start);
std::vector<TBlob> out_data(inputs.begin() + out_data_start, inputs.end());
std::vector<TBlob> in_grad(outputs.begin(), outputs.begin() + 3);
static thread_local std::vector<TBlob> out_grad(1);
Contributor

@piiswrong piiswrong Mar 16, 2018

This is probably too many thread_locals.
Why are we copying the vectors in the first place?
Why not change the interface of the operator's Forward/Backward?

Contributor Author

Good question. I'll do that instead.
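
As a rough sketch of the direction suggested here (the signature below is hypothetical, not the interface that was ultimately adopted), the backward entry point could take the executor's vectors by const reference together with an offset, instead of copying slices into fresh std::vectors on every call:

// Hypothetical interface change: pass the full lists plus an offset, and index
// into them directly instead of materializing per-call copies.
void BackwardEx(const OpContext& ctx,
                const std::vector<TBlob>& inputs,   // full list from the executor
                size_t out_data_start,              // where out_data begins in `inputs`
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& outputs) {
  // e.g. use inputs[out_data_start + k] rather than building
  // out_grad/in_data/out_data vectors for every invocation.
}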

@zheng-da
Contributor Author

zheng-da commented Mar 17, 2018

I measured the performance with (opts) and without (no opts) this PR and compared it with the commit before #8302 (original) in the master branch. I ran each version 20 times and calculated the average and standard deviation. The performance is measured in images/second. This fix gets close to the original version; it's unclear where the remaining perf loss comes from.

        opts      no opts   original
avg     5200.78   5033.98   5251.90
std     43.41     64.05     62.53

The command to run the test:

for i in {1..20}; do
python example/image-classification/train_imagenet.py --benchmark 1 --gpu 0,1,2,3,4,5,6,7 --batch-size 1024 --num-epochs 1 --disp-batches 100 --network resnet-v1 --num-layers 50 --data-nthreads 40 --min-random-scale 0.533 --max-random-shear-ratio 0 --max-random-rotate-angle 0 --max-random-h 0 --max-random-l 0 --max-random-s --dtype float16 --kv-store device
done

@zheng-da zheng-da changed the title [MXNET-105][WIP] Fix CuDNN performance after code refactor [MXNET-105] Fix CuDNN performance after code refactor Mar 21, 2018
@piiswrong
Contributor

@cjolivier01

@piiswrong piiswrong merged commit 46e47cb into apache:master Mar 22, 2018
@zheng-da zheng-da deleted the fix_cudnn_perf branch March 24, 2018 05:51
})
}
#else
aux_states[batchnorm::kMovingMean] = inputs[6];
aux_states[batchnorm::kMovingVar] = inputs[7];
Member

@zheng-da aux_states is not defined if USE_CUDNN is not enabled. @marcoabreu it seems there is no pure CUDA CI environment that is not built with cuDNN.

Contributor Author

I see. I'll update it.
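
A minimal sketch of the kind of fix implied here (the guard condition and structure are assumptions, not the exact PR code): declare aux_states outside the cuDNN-specific branch so the non-cuDNN build still compiles.

// Sketch only: keep the declaration visible to both branches.
std::vector<TBlob> aux_states(2);
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
// CuDNN path: the entries stay default-constructed; the cuDNN handle tracks
// the statistics itself.
#else
// Plain CUDA/CPU path: the kernel reads the running statistics from aux_states.
aux_states[batchnorm::kMovingMean] = inputs[6];
aux_states[batchnorm::kMovingVar]  = inputs[7];
#endif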

Contributor Author

Agree. @marcoabreu could you add a CI build with CUDA only?

Contributor

@marcoabreu marcoabreu Mar 27, 2018

Sure, no problem at all! Compilation only or do we need tests as well?

Contributor Author

I think it's better to run the code at least once. We probably don't need to try both Python 2 and Python 3, something like that.

Contributor

Done: #10281

ashokei pushed a commit to ashokei/incubator-mxnet that referenced this pull request Mar 27, 2018
* Reduce #inputs/outputs of batchnorm backward.

* Pass more arrays to BN.

* Make std::vector thread local.

* Set inputs of BN backward for other cases.

* Fix for other cases.

* remove commented code.

* fix a potential mem leak.

* Fix a compile error in mkldnn.

* Fix an error.

* reserve space for std::vector.

* Fix alignment.

* Fix cpp unit test.

* Fix BN CPP unit tests.

* Fix a compile error.

* Fix compilation error.

* Move Op signature.

* Cache CuDNN conv op.

* Fix compile error.

* Fix compile error.

* Remove thread_local.

* Reduce mem alloc when caching cudnn conv.

* Fix a lint error.

* Cache CuDNN deconv.

* Fix lint error.
jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request Mar 30, 2018
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da added a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018