diff --git a/python/tvm/contrib/miopen.py b/python/tvm/contrib/miopen.py index 442ce99432dd..616bd5a420ac 100644 --- a/python/tvm/contrib/miopen.py +++ b/python/tvm/contrib/miopen.py @@ -49,7 +49,8 @@ def conv2d_forward(x, pad_w=0, dilation_h=1, dilation_w=1, - conv_mode=0): + conv_mode=0, + data_type=1): """Create an extern op that compute 2D convolution with MIOpen Parameters @@ -73,18 +74,22 @@ def conv2d_forward(x, conv_mode: int 0: miopenConvolution 1: miopenTranspose + data_type: int + 0: miopenHalf (fp16) + 1: miopenFloat (fp32) Returns ------- y: Tensor The result tensor """ - assert conv_mode == 0, "Transpose convolutions not supported yet." + assert (conv_mode == 0 or conv_mode == 1), "0: miopenConvolution / 1: miopenTranspose" oshape = np.zeros((len(x.shape)), dtype=np.int32) xshape = x.shape wshape = w.shape setup_func = _get_global_func("tvm.contrib.miopen.conv2d.setup") algo = setup_func(conv_mode, + data_type, pad_h, pad_w, stride_h, @@ -106,6 +111,7 @@ def conv2d_forward(x, lambda ins, outs: _intrin.call_packed( "tvm.contrib.miopen.conv2d.forward", conv_mode, + data_type, pad_h, pad_w, stride_h, diff --git a/src/contrib/miopen/conv_forward.cc b/src/contrib/miopen/conv_forward.cc index 705a0d47a835..baac86b8603d 100644 --- a/src/contrib/miopen/conv_forward.cc +++ b/src/contrib/miopen/conv_forward.cc @@ -35,21 +35,22 @@ using namespace runtime; TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") .set_body([](TVMArgs args, TVMRetValue *ret) { const int mode = args[0]; - const int pad_h = args[1]; - const int pad_w = args[2]; - const int stride_h = args[3]; - const int stride_w = args[4]; - const int dilation_h = args[5]; - const int dilation_w = args[6]; - const int x_dim0 = args[7]; - const int x_dim1 = args[8]; - const int x_dim2 = args[9]; - const int x_dim3 = args[10]; - const int w_dim0 = args[11]; - const int w_dim1 = args[12]; - const int w_dim2 = args[13]; - const int w_dim3 = args[14]; - void *out_shape = args[15]; + const int dtype = args[1]; + const int pad_h = args[2]; + const int pad_w = args[3]; + const int stride_h = args[4]; + const int stride_w = args[5]; + const int dilation_h = args[6]; + const int dilation_w = args[7]; + const int x_dim0 = args[8]; + const int x_dim1 = args[9]; + const int x_dim2 = args[10]; + const int x_dim3 = args[11]; + const int w_dim0 = args[12]; + const int w_dim1 = args[13]; + const int w_dim2 = args[14]; + const int w_dim3 = args[15]; + void *out_shape = args[16]; MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); // Set Mode @@ -57,7 +58,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") // Set Ctx entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0}; // Set Data Type - entry_ptr->conv_entry.data_type = miopenFloat; // MIOpen only suppports fp32 + entry_ptr->conv_entry.data_type = static_cast( + dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at + // this moment. // Set Desc MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.mode, @@ -170,16 +173,17 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") .set_body([](TVMArgs args, TVMRetValue *ret) { const int mode = args[0]; - const int pad_h = args[1]; - const int pad_w = args[2]; - const int stride_h = args[3]; - const int stride_w = args[4]; - const int dilation_h = args[5]; - const int dilation_w = args[6]; - const int algo = args[7]; - const DLTensor *x = args[8]; - const DLTensor *w = args[9]; - const DLTensor *y = args[10]; + const int dtype = args[1]; + const int pad_h = args[2]; + const int pad_w = args[3]; + const int stride_h = args[4]; + const int stride_w = args[5]; + const int dilation_h = args[6]; + const int dilation_w = args[7]; + const int algo = args[8]; + const DLTensor *x = args[9]; + const DLTensor *w = args[10]; + const DLTensor *y = args[11]; MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); entry_ptr->conv_entry.fwd_algo = static_cast(algo); @@ -188,7 +192,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") // Set Ctx entry_ptr->conv_entry.ctx = x->ctx; // Set Data Type - entry_ptr->conv_entry.data_type = miopenFloat; // MIOpen only suppports fp32 + entry_ptr->conv_entry.data_type = static_cast( + dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at + // this moment. // Set Desc MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.mode, diff --git a/tests/python/contrib/test_miopen.py b/tests/python/contrib/test_miopen.py index 4abaa8d06985..5d82f6a14ce5 100644 --- a/tests/python/contrib/test_miopen.py +++ b/tests/python/contrib/test_miopen.py @@ -50,7 +50,8 @@ def test_conv2d(): pad_w, dilation_h, dilation_w, - conv_mode=0) + conv_mode=0, + data_type=1) yshape = [x.value for x in Y.shape] import topi @@ -65,7 +66,7 @@ def verify(): y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx) f(x, w, y) - Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w)) + Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w)) with tvm.target.rocm(): s_ref = topi.generic.schedule_conv2d_nchw([Y_ref]) f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm") diff --git a/topi/python/topi/rocm/conv2d.py b/topi/python/topi/rocm/conv2d.py index aacdb90286a6..ce9e57e4061d 100644 --- a/topi/python/topi/rocm/conv2d.py +++ b/topi/python/topi/rocm/conv2d.py @@ -78,7 +78,8 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou pad_w, dilation_h, dilation_w, - conv_mode=0) + conv_mode=0, + data_type=1) return conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)