diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c9636a7517cd..fb6d57e4618e 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1245,8 +1245,9 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): super_diag_right_align = align[:5] == "RIGHT" sub_diag_right_align = align[-5:] == "RIGHT" - return _make.matrix_set_diag(data, diagonal, k_one, k_two, super_diag_right_align, - sub_diag_right_align) + return _make.matrix_set_diag( + data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align + ) def adv_index(inputs): diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index c223fda67d8a..7d77115a9060 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -874,8 +874,8 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): sub_diag_right_align = align[-5:] == "RIGHT" return cpp.matrix_set_diag( - data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align - ) + data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align + ) def adv_index(data, indices): diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 026f1e644da9..577d1c78d8e2 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -526,10 +526,11 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): out_relay = intrp.evaluate(func)(input_np, diagonal_np) tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) - _verify((2, 2), (2,), 'float32') - _verify((4, 3, 3), (4, 3), 'int32') - _verify((2, 3, 4), (2, 3), 'float32', 1) - _verify((2, 3, 4), (2, 4, 3), 'int32', (-1, 2), "LEFT_RIGHT") + _verify((2, 2), (2,), "float32") + _verify((4, 3, 3), (4, 3), "int32") + _verify((2, 3, 4), (2, 3), "float32", 1) + _verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "LEFT_RIGHT") + if __name__ == "__main__": test_adaptive_pool() diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index b7448b2e8c5c..77e25407fc70 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -719,6 +719,7 @@ def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT input = te.placeholder(shape=input_shape, name="input", dtype=dtype) diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype) matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align) + def check_device(device, ctx): ctx = tvm.context(device, 0) print("Running on target: %s" % device) @@ -1161,7 +1162,7 @@ def test_sparse_to_dense(): @tvm.testing.uses_gpu def test_matrix_set_diag(): - for dtype in ['float32', 'int32']: + for dtype in ["float32", "int32"]: verify_matrix_set_diag((2, 2), (2,), dtype) verify_matrix_set_diag((4, 3, 3), (4, 3), dtype) verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1)