Skip to content

Commit

Permalink
Onnx mod, bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Jul 28, 2020
1 parent 4d0fa8b commit 6751c29
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,10 +530,11 @@ class Mod(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "Mod op take 2 inputs, {} given".format(len(inputs))
if attr['fmod'] == 1:
if attr['fmod'] == 0:
op_name = "floor_mod"
else:
op_name = "mod"

return AttrCvt(op_name)(inputs, {}, params)


Expand Down
29 changes: 13 additions & 16 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2374,17 +2374,11 @@ def test_pooling():
auto_pad='SAME_UPPER')


def verify_mod(x_shape, y_shape, fmod, dtype='float32'):
x_np = np.random.uniform(size=x_shape).astype(dtype)
y_np = np.random.uniform(size=y_shape).astype(dtype)
def verify_mod(x_shape, y_shape, fmod, out_shape, dtype='float32'):
x_np = np.random.uniform(-100.0, 100.0, x_shape).astype(dtype)
y_np = np.random.uniform(-100.0, 100.0, y_shape).astype(dtype)
y_np = np.where(y_np==0, 1, y_np) #remove 0's to avoid division by zero error

if fmod:
np_out = np.fmod(x_np, y_np)
else:
np_out = np.mod(x_np, y_np)

out_shape = np_out.shape
mod_node = helper.make_node("Mod",
inputs=["x", "y"],
outputs=["z"],
Expand All @@ -2401,22 +2395,25 @@ def verify_mod(x_shape, y_shape, fmod, dtype='float32'):
onnx_dtype, list(out_shape))])
model = helper.make_model(graph, producer_name='mod_test')

onnx_out = get_onnxruntime_output(model, [x_np, y_np], dtype)[0]

for target, ctx in ctx_list():
tvm_out = get_tvm_output(
model, [x_np, y_np], target, ctx, out_shape)
tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)


def test_mod():
# Mod
verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=0)

verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, dtype="int32")
verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, out_shape=(1, 32, 32), dtype="int32")
verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=0, out_shape=(1, 32, 32, 32), dtype="int32")

# fmod
verify_mod(x_shape=[1, 1, 32], y_shape=[1, 32, 32], fmod=1)

verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, dtype="int32")
verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, out_shape=(1, 32, 32), dtype="int32")
verify_mod(x_shape=[1, 1, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32))
verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 1, 32, 32], fmod=1, out_shape=(1, 32, 32, 32))
verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32), dtype="int32")
verify_mod(x_shape=[1, 32, 32, 32], y_shape=[1, 32, 32, 32], fmod=1, out_shape=(1, 32, 32, 32))


def verify_xor(x_shape, y_shape):
Expand Down

0 comments on commit 6751c29

Please sign in to comment.