From 779d9b3987bbaaaf6883747b63ca0962435ab883 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 9 Feb 2023 13:59:51 +0300 Subject: [PATCH] fix axis --- python/tvm/topi/cuda/scatter_elements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 8bacea5e29f7c..3a436b212e656 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -67,7 +67,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): if not isinstance(axis, int): axis = get_const_int(axis) - def gen_ir(data, indices, updates, out): + def gen_ir(data, indices, updates, out, axis): ib = tir.ir_builder.create() data_ptr = ib.buffer_ptr(data) @@ -173,7 +173,7 @@ def gen_ir(data, indices, updates, out): return te.extern( [data.shape], [data, indices, updates], - lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], axis), dtype=data.dtype, out_buffers=[out_buf], name="scatter_elements_cuda",