From 2e9941bb2b9e083513ed08a65ffa0a5b32e4c8cb Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 30 Jan 2023 21:33:03 -0800 Subject: [PATCH] Fix initializer for CumSum. (#9) --- python/tvm/relax/frontend/onnx_frontend.py | 3 +-- tests/python/relax/frontend/test_onnx_frontend.py | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/onnx_frontend.py b/python/tvm/relax/frontend/onnx_frontend.py index 70b9f9b8ea61..23d2aca5a685 100644 --- a/python/tvm/relax/frontend/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx_frontend.py @@ -454,8 +454,7 @@ class CumSum(OnnxOpConverter): def _impl_v13(cls, bb, inputs, attr): assert getattr(attr, "reverse", 0) == 0, "reverse is not supported yet" if len(inputs) > 1: - # axis = int(infer_value(inputs[1], params).numpy()) - axis = inputs[1] + axis = int(inputs[1].data.numpy()) else: axis = None return bb.emit_te( diff --git a/tests/python/relax/frontend/test_onnx_frontend.py b/tests/python/relax/frontend/test_onnx_frontend.py index cfeeb0d78774..d35355bde78a 100644 --- a/tests/python/relax/frontend/test_onnx_frontend.py +++ b/tests/python/relax/frontend/test_onnx_frontend.py @@ -37,7 +37,7 @@ def generate_random_inputs( - model: ModelProto, inputs: Dict[str, np.array] = None + model: ModelProto, inputs: Optional[Dict[str, np.array]] = None ) -> Dict[str, np.array]: input_values = {} # Iterate through model inputs and extract their shape. @@ -559,13 +559,13 @@ def test_cumsum(): "cumsum_test", inputs=[ helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), - helper.make_tensor_value_info("axis", TensorProto.INT64, ()), ], + initializer=[helper.make_tensor("axis", TensorProto.INT64, (), [1])], outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], ) model = helper.make_model(graph, producer_name="cumsum_test") - check_correctness(model, {"axis": [1]}) + check_correctness(model) if __name__ == "__main__": @@ -585,6 +585,7 @@ def test_cumsum(): test_conv() test_pow() test_erf() + test_cumsum() # TODO, still has issues # test_reshape() @@ -594,4 +595,3 @@ def test_cumsum(): test_transpose() test_unsqueeze() # test_shape() - # test_cumsum() # need axis as int