Skip to content

Commit

Permalink
Fix serialization of inf float value (#5912)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan authored Jun 24, 2020
1 parent 11815b8 commit fcaba98
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/node/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,15 @@ class JSONAttrSetter : public AttrVisitor {
template <typename T>
void ParseValue(const char* key, T* value) const {
std::istringstream is(GetValue(key));
is >> *value;
if (is.fail()) {
LOG(FATAL) << "Wrong value format for field " << key;
if (is.str() == "inf") {
*value = std::numeric_limits<T>::infinity();
} else if (is.str() == "-inf") {
*value = -std::numeric_limits<T>::infinity();
} else {
is >> *value;
if (is.fail()) {
LOG(FATAL) << "Wrong value format for field " << key;
}
}
}
void Visit(const char* key, double* value) final { ParseValue(key, value); }
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_node_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ def test_const_saveload_json():
zz = tvm.ir.load_json(json_str)
tvm.ir.assert_structural_equal(zz, z, map_free_vars=True)

def _test_infinity_value(value, dtype):
x = tvm.tir.const(value, dtype)
json_str = tvm.ir.save_json(x)
tvm.ir.assert_structural_equal(x, tvm.ir.load_json(json_str))

def test_infinity_value():
_test_infinity_value(float("inf"), 'float64')
_test_infinity_value(float("-inf"), 'float64')
_test_infinity_value(float("inf"), 'float32')
_test_infinity_value(float("-inf"), 'float32')

def test_make_smap():
# save load json
Expand Down Expand Up @@ -145,3 +155,4 @@ def test_dict():
test_make_sum()
test_pass_config()
test_dict()
test_infinity_value()

0 comments on commit fcaba98

Please sign in to comment.