[CUDA][CodeGen] Fix cuda codegen's fp16 inf literal (#12581)
* Fix cuda codegen's fp16 inf literal

* add relay testcase
wrongtest-intellif authored Aug 25, 2022
1 parent 21db1eb commit bb00a15
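
Why the fix is needed (background, not part of the commit message): the old fp16 branch streamed the value with `std::scientific` and appended an `f` suffix. For a non-finite value an iostream renders `inf` or `nan`, so the generated CUDA source contained the invalid token `inff`. A minimal C++ sketch of that behavior; the helper name `show_old_fp16_literal` is illustrative, not TVM code:

```cpp
// Reproduces the old literal-printing behavior in isolation.
// An iostream prints a non-finite double as "inf"/"nan", so the
// trailing 'f' suffix produces a token CUDA cannot parse.
#include <iostream>
#include <limits>
#include <sstream>
#include <string>

std::string show_old_fp16_literal(double value) {
  std::ostringstream os;
  os << "__float2half_rn" << '(' << std::scientific << value << 'f' << ')';
  return os.str();
}

int main() {
  // __float2half_rn(1.500000e+00f)  -- fine
  std::cout << show_old_fp16_literal(1.5) << "\n";
  // __float2half_rn(inff)           -- not valid CUDA C++
  std::cout << show_old_fp16_literal(std::numeric_limits<double>::infinity()) << "\n";
}
```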
Showing 2 changed files with 16 additions and 6 deletions.
6 changes: 4 additions & 2 deletions src/target/source/codegen_cuda.cc
@@ -1197,8 +1197,10 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p)
       break;
     }
     case 16: {
-      os << "__float2half_rn";
-      os << '(' << std::scientific << op->value << 'f' << ')';
+      os << "__float2half_rn" << '(';
+      FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
+      PrintConst(const_f32.get(), os, p);
+      os << ')';
       break;
     }
     default:
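
The fix delegates to the float32 path of PrintConst, which already special-cases non-finite values. A sketch of the resulting flow, assuming the 32-bit branch emits CUDART_INF_F / CUDART_NAN_F from math_constants.h (helper names here are illustrative, not TVM code):

```cpp
// Sketch of the fixed flow: route the fp16 value through a float32
// printer so the existing inf/nan handling kicks in, instead of
// streaming the raw value.
#include <cmath>
#include <iostream>
#include <sstream>
#include <string>

void print_f32_const(double value, std::ostream& os) {
  if (std::isinf(value)) {
    os << (value < 0 ? "-" : "") << "CUDART_INF_F";  // from math_constants.h
  } else if (std::isnan(value)) {
    os << "CUDART_NAN_F";
  } else {
    os << std::scientific << value << 'f';
  }
}

std::string show_new_fp16_literal(double value) {
  std::ostringstream os;
  os << "__float2half_rn" << '(';
  print_f32_const(value, os);  // reuse the 32-bit path, as the fix does
  os << ')';
  return os.str();
}

int main() {
  // __float2half_rn(CUDART_INF_F)  -- a literal CUDA can parse
  std::cout << show_new_fp16_literal(INFINITY) << "\n";
}
```

A side benefit of the delegation is that ordinary fp16 constants now share one formatting path with fp32 constants.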
16 changes: 12 additions & 4 deletions tests/python/relay/test_op_level3.py
@@ -1344,7 +1344,7 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0, indices_dtype="int32"):
     verify_gather_nd((2, 2, 2), (2, 2, 1), [[[1], [0]], [[0], [1]]], 1, indices_dtype="uint32")
 
 
-def _verify_infiniteness_ops(relay_op, ref_op):
+def _verify_infiniteness_ops(relay_op, ref_op, target="llvm", dev=None):
     for dtype in ["float32", "float16", "float16", "int32", "int16"]:
         shape = (2, 8, 8)
         x = relay.var("x", relay.TensorType(shape, dtype))
@@ -1359,17 +1359,25 @@ def _verify_infiniteness_ops(relay_op, ref_op):
         ] = np.infty
         data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan
 
-        op_res = create_executor().evaluate(y, {x: data})
+        op_res = create_executor(target=target, device=dev).evaluate(y, {x: data})
         ref_res = ref_op(data)
         np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01)
 
 
+@tvm.testing.requires_gpu
 def test_isfinite():
-    _verify_infiniteness_ops(relay.isfinite, np.isfinite)
+    for target, dev in tvm.testing.enabled_targets():
+        if target not in ["llvm", "cuda"]:
+            continue
+        _verify_infiniteness_ops(relay.isfinite, np.isfinite, target=target, dev=dev)
 
 
+@tvm.testing.requires_gpu
 def test_isinf():
-    _verify_infiniteness_ops(relay.isinf, np.isinf)
+    for target, dev in tvm.testing.enabled_targets():
+        if target not in ["llvm", "cuda"]:
+            continue
+        _verify_infiniteness_ops(relay.isinf, np.isinf, target=target, dev=dev)
 
 
 def test_unravel_index(target, dev, executor_kind):
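
A note on the test change: `_verify_infiniteness_ops` previously always evaluated through the default `llvm` executor, so the broken fp16 literal never reached the CUDA codegen. The tests now loop over `tvm.testing.enabled_targets()`, filtered to `llvm` and `cuda`, so the float16 `inf`/`nan` constants are compiled and checked end to end on CUDA; the `@tvm.testing.requires_gpu` marker skips these tests on machines without a GPU.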
