Skip to content

Commit

Permalink
fix workspace limit in cudnn-8
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqiu committed Nov 13, 2020
1 parent c52fe48 commit 5c4a024
Showing 1 changed file with 44 additions and 5 deletions.
49 changes: 44 additions & 5 deletions paddle/fluid/operators/conv_cudnn_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
Expand Down Expand Up @@ -101,6 +102,24 @@ inline int MaxBwdFilterAlgos(cudnnHandle_t cudnn_handle) {
return max_algos;
}

template <typename PerfType, typename AlgoType>
void ChooseAlgoByWorkspace(PerfType* perf_results, size_t perf_num,
size_t workspace_byte, AlgoType* algo) {
for (size_t i = 0; i < perf_num; ++i) {
auto result = perf_results[i];
if (result.status == CUDNN_STATUS_SUCCESS &&
result.memory < workspace_byte) {
*algo = result.algo;
VLOG(3) << " algo: " << result.algo << ", time: " << result.time
<< " ms, wksp = " << result.memory
<< ", status = " << result.status;
return;
}
}
VLOG(3) << "Can not find alog that requires memory < "
<< static_cast<double>(workspace_byte) / (1 << 20) << " MB";
}

template <typename PerfType, typename AlgoType>
void ChooseAlgo(const std::vector<PerfType>& perf_results,
size_t workspace_byte, AlgoType* algo) {
Expand Down Expand Up @@ -219,7 +238,10 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {

if (workspace_size > workspace_size_limit) {
#if CUDNN_VERSION >= 8000
workspace_size_limit = workspace_size;
// cudnnGetConvolutionForwardAlgorithm is removed in CUDNN-8
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
kNUM_CUDNN_FWD_ALGS,
workspace_size_limit, &algo);
#else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
Expand Down Expand Up @@ -316,7 +338,6 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
size_t workspace_size = 0;
bool has_got_workspace_size = true;
algo_t algo;

#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
Expand Down Expand Up @@ -362,9 +383,10 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
if (workspace_size > workspace_size_limit) {
has_got_workspace_size = false;
#if CUDNN_VERSION >= 8000
// There is no cudnnGetConvolutionBackwardDataAlgorithm in CUDNN 8
// version.
workspace_size_limit = workspace_size;
// cudnnGetConvolutionBackwardDataAlgorithm is removed in CUDNN-8
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
kNUM_CUDNN_BWD_DATA_ALGS,
workspace_size_limit, &algo);
#else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
Expand Down Expand Up @@ -493,6 +515,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
workspace_size = GetWorkspaceSize(args, algo);
if (workspace_size > workspace_size_limit) {
workspace_size = workspace_size_limit;
#if CUDNN_VERSION >= 8000
// cudnnGetConvolutionBackwardFilterAlgorithm is removed in CUDNN-8
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
kNUM_CUDNN_BWD_FILTER_ALGS,
workspace_size_limit, &algo);
#else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< workspace_size << ") exceeds the limit("
<< workspace_size_limit << ")";
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
args.handle, args.idesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.wdesc.desc(),
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#endif
}
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
Expand Down

1 comment on commit 5c4a024

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 5c4a024 Nov 13, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍PR: #28611 Commit ID: 5c4a024 contains failed CI.

Please sign in to comment.