Skip to content

Commit

Permalink
fix test dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
tkonolige committed Feb 18, 2021
1 parent f69589d commit 60f8d87
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_grad_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_embed_grad():
table = relay.var("table", shape=(6, 3), dtype="float64")
indices = relay.var("indices", shape=(4,), dtype="int64")
table_nd = np.reshape(np.arange(18), (6, 3)).astype("float64")
indices_nd = np.array([0, 0, 3, 2])
indices_nd = np.array([0, 0, 3, 2]).astype("int64")
fwd_func = relay.Function([table, indices], relay.nn.embed(table, indices))
# Can't test against indices because the function is nonsmooth with respect to them.
check_grad(fwd_func, inputs=[table_nd, indices_nd], test_inputs=[table_nd])
Expand Down

0 comments on commit 60f8d87

Please sign in to comment.