Skip to content

Commit

Permalink
[AutoTVM][BugFix] Fix autotvm on the conv2d_nchw_winograd.mali operat…
Browse files Browse the repository at this point in the history
…or (apache#6130)

* [AutoTVM] Fix conv2d_nchw_winograd.mali

* Fix pylint error

Co-authored-by: Yanming Wang <[email protected]>
  • Loading branch information
2 people authored and Trevor Morris committed Aug 26, 2020
1 parent c8e833f commit a9e8288
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
10 changes: 7 additions & 3 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,19 +216,23 @@ def __call__(self, *args, **kwargs):
def _default_func(self, *args, **kwargs):
assert callable(self.fcompute) and callable(self.fschedule)
out = self.fcompute(*args, **kwargs)
arg_bufs = [out] + self.get_inputs(out)
arg_bufs = [out] + self._get_inputs(out)
s = self.fschedule([out])
return s, arg_bufs

def get_inputs(self, out):
@staticmethod
def _get_inputs(out):
inputs = []
queue = [out]
hash_set = set()
while queue:
t = queue.pop(0)
if isinstance(t.op, tensor.PlaceholderOp):
inputs.append(t)
else:
queue.extend(t.op.input_tensors)
input_tensors = [t for t in t.op.input_tensors if t not in hash_set]
queue.extend(input_tensors)
hash_set.update(input_tensors)
return inputs

def _register_task_compute(name, func=None):
Expand Down
3 changes: 1 addition & 2 deletions topi/python/topi/mali/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til
[(b*bnb+bb) % nW * m + nu], tvm.tir.const(0, data_pad.dtype)), name='d')

if autotvm.GLOBAL_SCOPE.in_tuning:
VC = cfg['tile_k'].size[-1]
kvshape = (KH + tile_size - 1, KW + tile_size - 1, tvm.tir.indexdiv(CO, VC), CI, VC)
kvshape = (alpha, alpha, CO // bna, CI, bna)
U = tvm.te.placeholder(kvshape, kernel.dtype, name="U")
else:
# transform kernel
Expand Down

0 comments on commit a9e8288

Please sign in to comment.