Skip to content

Commit

Permalink
[Fix] Fix a few bugs when dtype is fp16 (apache#4088)
Browse files Browse the repository at this point in the history
* Fix layer norm for fp16

* [Fix] Fix arange for fp16

* [Fix] Fix mxnet frontend for fp16

* [Fix] Fix arange for fp16

* remove comments

* x

* fix nnvm
  • Loading branch information
icemelon authored and Animesh Jain committed Oct 17, 2019
1 parent 4edd1ee commit 959bcf7
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 15 deletions.
18 changes: 12 additions & 6 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,12 +615,17 @@ def _mx_arange(inputs, attrs):
if attrs.get_int("repeat", 1) != 1:
raise tvm.error.OpAttributeUnimplemented(
'Attribute "repeat" is not supported in operator arange.')
new_attrs = {}
new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
dtype = attrs.get_str("dtype", "float32")
stop = attrs.get_str("stop", "None")
new_attrs["stop"] = None if stop == "None" else _expr.const(float(stop))
new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
if stop == "None":
stop = None
else:
stop = _expr.const(float(stop), dtype=dtype)
new_attrs = {}
new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0), dtype=dtype)
new_attrs["stop"] = stop
new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0), dtype=dtype)
new_attrs["dtype"] = dtype
return _op.arange(**new_attrs)


Expand Down Expand Up @@ -863,7 +868,8 @@ def _mx_contrib_div_sqrt_dim(inputs, _):
assert len(inputs) == 1
ndim = len(_infer_type(inputs[0]).checked_type.shape)
dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
sqrt_dim = _op.sqrt(dim.astype('float32'))
dtype = _infer_type(inputs[0]).checked_type.dtype
sqrt_dim = _op.sqrt(dim.astype(dtype))
out = inputs[0] / sqrt_dim
return out

Expand Down
11 changes: 7 additions & 4 deletions python/tvm/relay/frontend/nnvm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .. import expr as _expr
from .. import op as _op
from .common import get_relay_op
from .common import infer_type as _infer_type

def _warn_not_used(attr, op='nnvm'):
import warnings
Expand Down Expand Up @@ -123,20 +124,22 @@ def _elemwise_sum(inputs, _, _dtype='float32'):


def _binop_scalar(new_op):
def _impl(inputs, attrs, odtype='float32'):
def _impl(inputs, attrs, odtype=None):
assert len(inputs) == 1
scalar = attrs.get_float("scalar")
# Note: binary scalar only works for float op for now
if odtype is None:
odtype = _infer_type(inputs[0]).checked_type.dtype
scalar = _expr.const(scalar, dtype=odtype)
return new_op(inputs[0], scalar)
return _impl


def _rbinop_scalar(new_op):
def _impl(inputs, attrs, odtype='float32'):
def _impl(inputs, attrs, odtype=None):
assert len(inputs) == 1
scalar = attrs.get_float("scalar")
# Note: binary scalar only works for float op for now
if odtype is None:
odtype = _infer_type(inputs[0]).checked_type.dtype
scalar = _expr.const(scalar, dtype=odtype)
return new_op(scalar, inputs[0])
return _impl
Expand Down
39 changes: 35 additions & 4 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/expr_operator.h>
#include <tvm/ir.h>
#include <tvm/data_layout.h>
#include <tvm/runtime/packed_func.h>
#include <topi/transform.h>
#include <topi/elemwise.h>
#include <topi/broadcast.h>
Expand Down Expand Up @@ -1139,11 +1140,41 @@ and type as the input array.
TVM_REGISTER_NODE_TYPE(ArangeAttrs);

double ToScalar(const runtime::NDArray& array) {
if (array->dtype.code == kDLInt || array->dtype.code == kDLUInt) {
return reinterpret_cast<int32_t*>(array->data)[0];
} else {
return reinterpret_cast<float*>(array->data)[0];
if (array->dtype.code == kDLInt) {
if (array->dtype.bits == 8) {
return reinterpret_cast<int8_t*>(array->data)[0];
} else if (array->dtype.bits == 16) {
return reinterpret_cast<int16_t*>(array->data)[0];
} else if (array->dtype.bits == 32) {
return reinterpret_cast<int32_t*>(array->data)[0];
} else if (array->dtype.bits == 64) {
return reinterpret_cast<int64_t*>(array->data)[0];
}
} else if (array->dtype.code == kDLUInt) {
if (array->dtype.bits == 8) {
return reinterpret_cast<uint8_t*>(array->data)[0];
} else if (array->dtype.bits == 16) {
return reinterpret_cast<uint16_t*>(array->data)[0];
} else if (array->dtype.bits == 32) {
return reinterpret_cast<uint32_t*>(array->data)[0];
} else if (array->dtype.bits == 64) {
return reinterpret_cast<uint64_t*>(array->data)[0];
}
} else if (array->dtype.code == kDLFloat) {
#if (__ARM_FP16_FORMAT_IEEE == 1)
if (array->dtype.bits == 16) {
return reinterpret_cast<__fp16*>(array->data)[0];
}
#endif
if (array->dtype.bits == 32) {
return reinterpret_cast<float*>(array->data)[0];
} else if (array->dtype.bits == 64) {
return reinterpret_cast<double*>(array->data)[0];
}
}
LOG(FATAL) << "Unknown data type: " << tvm::runtime::TVMType2String(array->dtype);
// make compiler happy
return -std::numeric_limits<double>::infinity();
}

bool ArangeRel(const Array<Type>& types,
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/simplify_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
const auto param = attrs.as<LayerNormAttrs>();
CHECK(param);

Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
Expr mean = Mean(data, {param->axis}, true, false);
Expr var = Variance(data, mean, {param->axis}, true, false);
Expr denom = Sqrt(Add(var, epsilon));
Expand Down

0 comments on commit 959bcf7

Please sign in to comment.