diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index 139801feef15..e1d6eb61979f 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -189,8 +189,29 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") entry_ptr->conv_entry.data_type, y->shape[0], y->shape[1], y->shape[2], y->shape[3])); + // Set workspace + size_t workspace_size = 0; + MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize( + entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, &workspace_size)); + entry_ptr->conv_entry.UpdateWorkspace(workspace_size); + const float alpha = 1.f; const float beta = 0.f; + + const int request_algo_count = 4; + const bool exhaustive_search = true; + void* workspace = entry_ptr->conv_entry.workspace; + if (workspace_size == 0) workspace = nullptr; + int returned_algo_count = 0; + miopenConvAlgoPerf_t perfs[4]; + + MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, x->data, + entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.output_desc, y->data, request_algo_count, &returned_algo_count, + perfs, workspace, workspace_size, exhaustive_search)); + MIOPEN_CALL(miopenConvolutionForward( entry_ptr->handle, &alpha, entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.conv_desc,