Skip to content

Commit

Permalink
[Relay] Stop ToMixedPrecision when constant is out of dtype range (ap…
Browse files Browse the repository at this point in the history
…ache#15461)

* [Relay] Stop ToMixedPrecision when constant is out of dtype range

In some layers, e.g. Clip, we might have a compilation error in the
case when operation takes on the input a constant which is out of
target data type range.

To prevent such situation, a new method was introduced. It compares
values of constant attributes with the range of the target data type. In
case if the value is out of range then float32 will be used.

* Fix lint
  • Loading branch information
echuraev authored Aug 3, 2023
1 parent 1b7175b commit 0e905aa
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 2 deletions.
41 changes: 39 additions & 2 deletions src/relay/transforms/to_mixed_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <utility>

#include "../../support/scalars.h"
#include "pattern_utils.h"

namespace tvm {
Expand Down Expand Up @@ -110,6 +111,39 @@ class MixedPrecisionPass : public MixedModeMutator {
std::vector<DataType> original_dtype_;
bool keep_orig_output_dtype_;

/*! \brief If some of the constant attributes are out of mixed_precision_type_ bounds, then
* computation cannot be performed in mixed precision. */
bool IsMixedPrecisionApplicableToAttrs(const Attrs& attrs) const {
if (attrs.get() != nullptr) {
double min_bound;
double max_bound;
if (mixed_precision_type_.is_float16()) {
min_bound = -support::kMaxFloat16;
max_bound = support::kMaxFloat16;
} else if (mixed_precision_type_.is_bfloat16()) {
min_bound = -support::kMaxBFloat16;
max_bound = support::kMaxBFloat16;
} else if (mixed_precision_type_.is_float8()) {
double bound = (mixed_precision_type_.code() == DataType::kE4M3Float) ? support::kMaxE4M3
: support::kMaxE5M2;
min_bound = -bound;
max_bound = bound;
} else if (mixed_precision_type_.is_float()) {
min_bound = std::numeric_limits<float>::lowest();
max_bound = std::numeric_limits<float>::max();
} else {
return true;
}

if (auto cur_attrs = attrs.as<ClipAttrs>()) {
if (cur_attrs->a_min < min_bound || cur_attrs->a_max > max_bound) {
return false;
}
}
}
return true;
}

Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
/* If the accumulation dtype is in the attributes make a copy and mutate the field. */
Attrs cur_attrs = call->attrs;
Expand Down Expand Up @@ -382,9 +416,12 @@ class MixedPrecisionPass : public MixedModeMutator {
all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
}

bool is_mixed_precision_applicable =
static_cast<bool>(final_category == MIXED_PRECISION_ALWAYS &&
IsMixedPrecisionApplicableToAttrs(pre_call_node->attrs));
// Create the new arguments to the call.
DataType wanted_arg_dtypes =
final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type_ : DataType::Float(32);
is_mixed_precision_applicable ? mixed_precision_type_ : DataType::Float(32);
auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes);
Array<Expr> new_args = call_args_and_types.first;
Array<Type> new_arg_types;
Expand All @@ -397,7 +434,7 @@ class MixedPrecisionPass : public MixedModeMutator {
}

// Finally create the new attributes.
if (final_category == MIXED_PRECISION_ALWAYS) {
if (is_mixed_precision_applicable) {
Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype);
Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span);
if (accumulation_dtype != output_dtype) {
Expand Down
49 changes: 49 additions & 0 deletions tests/python/relay/test_to_mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,5 +537,54 @@ def test_convert_follow_node_with_integer_arguments(target_precision):
assert tvm.ir.structural_equal(expected_mod, output_mod)


def test_clip(target_precision):
data = relay.var("data", shape=[1, 10], dtype="float32")
res = relay.clip(data, a_min=-128000, a_max=128000)

mod = tvm.IRModule.from_expr(res)

mod_params = {
"data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"),
}
output_mod = verify_mixed_precision_output_close(
mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01
)

# Create expected module
if target_precision == "bfloat16":
data = relay.cast(relay.var("data", shape=[1, 10]), target_precision)
res = relay.clip(data, a_min=-128000, a_max=128000)
expected_mod = tvm.IRModule.from_expr(res)
expected_mod = InferType()(expected_mod)
assert tvm.ir.structural_equal(expected_mod, output_mod)


def test_clip_with_pre_op(target_precision):
data = relay.var("data", shape=[1, 10], dtype="float32")
const = relay.const(5, "float32")
res = relay.divide(data, const)
res = relay.clip(res, a_min=-128000, a_max=128000)

mod = tvm.IRModule.from_expr(res)

mod_params = {
"data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"),
}
output_mod = verify_mixed_precision_output_close(
mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01
)

# Create expected module
data = relay.cast(relay.var("data", shape=[1, 10]), target_precision)
const = relay.cast(relay.const(5, "float32"), target_precision)
res = relay.divide(data, const)
if target_precision == "float16":
res = relay.cast(res, "float32")
res = relay.clip(res, a_min=-128000, a_max=128000)
expected_mod = tvm.IRModule.from_expr(res)
expected_mod = InferType()(expected_mod)
assert tvm.ir.structural_equal(expected_mod, output_mod)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 0e905aa

Please sign in to comment.