Skip to content

Commit

Permalink
Changes by black.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rishabh Jain committed Sep 21, 2020
1 parent 997e0a2 commit 8bff88c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
5 changes: 3 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,8 +1245,9 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
super_diag_right_align = align[:5] == "RIGHT"
sub_diag_right_align = align[-5:] == "RIGHT"

return _make.matrix_set_diag(data, diagonal, k_one, k_two, super_diag_right_align,
sub_diag_right_align)
return _make.matrix_set_diag(
data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align
)


def adv_index(inputs):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,8 +874,8 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
sub_diag_right_align = align[-5:] == "RIGHT"

return cpp.matrix_set_diag(
data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align
)
data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align
)


def adv_index(data, indices):
Expand Down
9 changes: 5 additions & 4 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,11 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
out_relay = intrp.evaluate(func)(input_np, diagonal_np)
tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)

_verify((2, 2), (2,), 'float32')
_verify((4, 3, 3), (4, 3), 'int32')
_verify((2, 3, 4), (2, 3), 'float32', 1)
_verify((2, 3, 4), (2, 4, 3), 'int32', (-1, 2), "LEFT_RIGHT")
_verify((2, 2), (2,), "float32")
_verify((4, 3, 3), (4, 3), "int32")
_verify((2, 3, 4), (2, 3), "float32", 1)
_verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "LEFT_RIGHT")


if __name__ == "__main__":
test_adaptive_pool()
Expand Down
3 changes: 2 additions & 1 deletion tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT
input = te.placeholder(shape=input_shape, name="input", dtype=dtype)
diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype)
matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align)

def check_device(device, ctx):
ctx = tvm.context(device, 0)
print("Running on target: %s" % device)
Expand Down Expand Up @@ -1161,7 +1162,7 @@ def test_sparse_to_dense():

@tvm.testing.uses_gpu
def test_matrix_set_diag():
for dtype in ['float32', 'int32']:
for dtype in ["float32", "int32"]:
verify_matrix_set_diag((2, 2), (2,), dtype)
verify_matrix_set_diag((4, 3, 3), (4, 3), dtype)
verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1)
Expand Down

0 comments on commit 8bff88c

Please sign in to comment.