Hey @anko-intel , Thanks for submitting the PR
CI supported jobs: [clang, unix-cpu, unix-gpu, centos-gpu, windows-cpu, windows-gpu, miscellaneous, website, sanity, edge, centos-cpu]
OneDNN doesn't support the float16 format, so a fallback to the standard implementation is needed. This fixes issue 19631.
4d22ab7 to aed0619
@rongzha1 - could you review?
src/operator/tensor/amp_cast.cc
Outdated
mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
for (size_t i = 0; i < i_ndim; i++) {
  i_dims[i] = static_cast<int>(data.shape()[i]);
if (data.dtype() != mshadow::kFloat16) {
Shall we add isValidMKLDNNDataType() to check whether the data type is supported by MKLDNN? mshadow has many data types, and some of them are not supported: https://github.com/apache/incubator-mxnet/blob/64f737cdd59fe88d2c5b479f25d011c5156b6a8a/3rdparty/mshadow/mshadow/base.h#L364:3
I considered that. If a new isValidMKLDNNDataType() function could be reused in many places, such as MKLDNNStorageType() for FInferStorageType, it would make sense to create it. But in this particular case, the amp_cast operator only accepts 3 float types (see https://github.com/apache/incubator-mxnet/blob/v1.x/src/operator/tensor/amp_cast.h#L70 ), so I just excluded float16 as not supported in MKLDNN.
OK. LGTM
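To make the trade-off in this thread concrete, here is a minimal, self-contained sketch comparing the two checks. It does not use the real MXNet or MKLDNN headers: the `DType` enum is a stand-in for the mshadow type flags, `IsValidMKLDNNDataType` is a hypothetical version of the helper suggested above, and the set of types it reports as supported is an assumption for illustration. The sketch also assumes the three float types accepted by amp_cast are float32, float16 and bfloat16.

```cpp
// Stand-in for the mshadow type flags (hypothetical, not the real enum).
enum class DType { kFloat32, kFloat64, kFloat16, kUint8, kInt32, kInt8, kInt64, kBfloat16 };

// Hypothetical general helper in the spirit of the reviewer's suggestion.
// The supported set below is assumed for illustration only.
bool IsValidMKLDNNDataType(DType t) {
  switch (t) {
    case DType::kFloat32:
    case DType::kBfloat16:
    case DType::kInt32:
    case DType::kInt8:
    case DType::kUint8:
      return true;
    default:
      return false;
  }
}

// Narrow check in the style of the PR: only float16 is excluded.
// This is only safe when the caller already restricts the possible types.
bool AmpCastCanUseMKLDNN(DType t) {
  return t != DType::kFloat16;
}
```

Under these assumptions the two checks agree on float32, float16 and bfloat16, which is why the narrow check is enough for amp_cast. They disagree on a type such as float64, which is the case a general helper would guard against in operators that accept more types.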
@PatricZhao, @szha could you review and merge if everything is ok?
Thanks for the fix! Could you add a test for verification?
@mxnet-bot run ci [centos-cpu, unix-gpu]
Jenkins CI successfully triggered : [centos-cpu, unix-gpu]
* Fix AmpCast for float16
  OneDNN doesn't support float16 format so fallback to standard implementation is needed. It fixes issue 19631.
* Enable amp_cast test for float16 on CPU context
Description
OneDNN doesn't support the float16 format, so a fallback to the standard
implementation is needed.
This fixes issue #19631.
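The fallback described above can be sketched roughly as follows. This is an illustration of the dispatch pattern only, not the code from the PR: `OneDNNCast`, `FallbackCast` and `AmpCastDispatch` are hypothetical names, the `DType` enum is a stand-in for the mshadow type flags, and the two cast functions only report which path was taken instead of doing any real work.

```cpp
#include <string>

// Stand-in for the mshadow type flags (hypothetical, not the real enum).
enum class DType { kFloat32, kFloat16, kBfloat16 };

// Hypothetical placeholders for the two implementations; each only
// returns the name of the path so the dispatch can be observed.
std::string OneDNNCast(DType /*in_dtype*/) { return "onednn"; }
std::string FallbackCast(DType /*in_dtype*/) { return "fallback"; }

// Take the OneDNN path unless the input is float16, in which case
// fall back to the standard implementation.
std::string AmpCastDispatch(DType in_dtype) {
  if (in_dtype != DType::kFloat16) {
    return OneDNNCast(in_dtype);
  }
  return FallbackCast(in_dtype);
}
```

With this shape, float16 input never reaches the OneDNN path, while the other types keep using it as before.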