From 0b36b638661d1015639b716e662989634211f733 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 5 Nov 2020 10:42:39 -0800 Subject: [PATCH] fix test dtype --- tests/python/relay/test_op_grad_level1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py index b20f74edebefb..7b79f0d63ef63 100644 --- a/tests/python/relay/test_op_grad_level1.py +++ b/tests/python/relay/test_op_grad_level1.py @@ -154,7 +154,7 @@ def test_embed_grad(): table = relay.var("table", shape=(6, 3), dtype="float64") indices = relay.var("indices", shape=(4,), dtype="int64") table_nd = np.reshape(np.arange(18), (6, 3)).astype("float64") - indices_nd = np.array([0, 0, 3, 2]) + indices_nd = np.array([0, 0, 3, 2]).astype("int64") fwd_func = relay.Function([table, indices], relay.nn.embed(table, indices)) # Can't test against indices because the function is nonsmooth with respect to them. check_grad(fwd_func, inputs=[table_nd, indices_nd], test_inputs=[table_nd])