diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c9636a7517cd0..fb6d57e4618ef 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1245,8 +1245,9 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): super_diag_right_align = align[:5] == "RIGHT" sub_diag_right_align = align[-5:] == "RIGHT" - return _make.matrix_set_diag(data, diagonal, k_one, k_two, super_diag_right_align, - sub_diag_right_align) + return _make.matrix_set_diag( + data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align + ) def adv_index(inputs): diff --git a/python/tvm/topi/testing/matrix_set_diag.py b/python/tvm/topi/testing/matrix_set_diag.py index b2753046efb51..81a8f6cccafe1 100644 --- a/python/tvm/topi/testing/matrix_set_diag.py +++ b/python/tvm/topi/testing/matrix_set_diag.py @@ -18,6 +18,7 @@ """MatrixSetDiag in Python""" import numpy as np + def matrix_set_diag(input_np, diagonal, k=0, align="RIGHT_LEFT"): """matrix_set_diag operator implemented in numpy. diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 3bcf383184cac..7d77115a9060f 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -805,6 +805,7 @@ def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0 return cpp.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) + def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): """ Returns a tensor with the diagonals of input tensor replaced with the provided diagonal values. @@ -872,8 +873,9 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): super_diag_right_align = align[:5] == "RIGHT" sub_diag_right_align = align[-5:] == "RIGHT" - return cpp.matrix_set_diag(data, diagonal, k_one, k_two, super_diag_right_align, - sub_diag_right_align) + return cpp.matrix_set_diag( + data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align + ) def adv_index(data, indices): diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 026f1e644da95..577d1c78d8e2e 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -526,10 +526,11 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): out_relay = intrp.evaluate(func)(input_np, diagonal_np) tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) - _verify((2, 2), (2,), 'float32') - _verify((4, 3, 3), (4, 3), 'int32') - _verify((2, 3, 4), (2, 3), 'float32', 1) - _verify((2, 3, 4), (2, 4, 3), 'int32', (-1, 2), "LEFT_RIGHT") + _verify((2, 2), (2,), "float32") + _verify((4, 3, 3), (4, 3), "int32") + _verify((2, 3, 4), (2, 3), "float32", 1) + _verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "LEFT_RIGHT") + if __name__ == "__main__": test_adaptive_pool() diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 3de7d73fc5566..77e25407fc70c 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -714,10 +714,12 @@ def check_device(device, ctx): for device, ctx in tvm.testing.enabled_targets(): check_device(device, ctx) + def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): input = te.placeholder(shape=input_shape, name="input", dtype=dtype) diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype) matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align) + def check_device(device, ctx): ctx = tvm.context(device, 0) print("Running on target: %s" % device) @@ -1160,7 +1162,7 @@ def test_sparse_to_dense(): @tvm.testing.uses_gpu def test_matrix_set_diag(): - for dtype in ['float32', 'int32']: + for dtype in ["float32", "int32"]: verify_matrix_set_diag((2, 2), (2,), dtype) verify_matrix_set_diag((4, 3, 3), (4, 3), dtype) verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1)