Skip to content

Commit

Permalink
Changes by black.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rishabh Jain committed Sep 21, 2020
1 parent 4b1dc99 commit 6b71ecc
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 9 deletions.
5 changes: 3 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,8 +1245,9 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
super_diag_right_align = align[:5] == "RIGHT"
sub_diag_right_align = align[-5:] == "RIGHT"

return _make.matrix_set_diag(data, diagonal, k_one, k_two, super_diag_right_align,
sub_diag_right_align)
return _make.matrix_set_diag(
data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align
)


def adv_index(inputs):
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/testing/matrix_set_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""MatrixSetDiag in Python"""
import numpy as np


def matrix_set_diag(input_np, diagonal, k=0, align="RIGHT_LEFT"):
"""matrix_set_diag operator implemented in numpy.
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,7 @@ def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0

return cpp.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value)


def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
"""
Returns a tensor with the diagonals of input tensor replaced with the provided diagonal values.
Expand Down Expand Up @@ -872,8 +873,9 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
super_diag_right_align = align[:5] == "RIGHT"
sub_diag_right_align = align[-5:] == "RIGHT"

return cpp.matrix_set_diag(data, diagonal, k_one, k_two, super_diag_right_align,
sub_diag_right_align)
return cpp.matrix_set_diag(
data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align
)


def adv_index(data, indices):
Expand Down
9 changes: 5 additions & 4 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,11 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
out_relay = intrp.evaluate(func)(input_np, diagonal_np)
tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)

_verify((2, 2), (2,), 'float32')
_verify((4, 3, 3), (4, 3), 'int32')
_verify((2, 3, 4), (2, 3), 'float32', 1)
_verify((2, 3, 4), (2, 4, 3), 'int32', (-1, 2), "LEFT_RIGHT")
_verify((2, 2), (2,), "float32")
_verify((4, 3, 3), (4, 3), "int32")
_verify((2, 3, 4), (2, 3), "float32", 1)
_verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "LEFT_RIGHT")


if __name__ == "__main__":
test_adaptive_pool()
Expand Down
4 changes: 3 additions & 1 deletion tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,10 +714,12 @@ def check_device(device, ctx):
for device, ctx in tvm.testing.enabled_targets():
check_device(device, ctx)


def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
input = te.placeholder(shape=input_shape, name="input", dtype=dtype)
diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype)
matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align)

def check_device(device, ctx):
ctx = tvm.context(device, 0)
print("Running on target: %s" % device)
Expand Down Expand Up @@ -1160,7 +1162,7 @@ def test_sparse_to_dense():

@tvm.testing.uses_gpu
def test_matrix_set_diag():
for dtype in ['float32', 'int32']:
for dtype in ["float32", "int32"]:
verify_matrix_set_diag((2, 2), (2,), dtype)
verify_matrix_set_diag((4, 3, 3), (4, 3), dtype)
verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1)
Expand Down

0 comments on commit 6b71ecc

Please sign in to comment.