Skip to content

Commit

Permalink
[Hybrid][Fix] Fix hybrid script to support array of tensors (#4494)
Browse files Browse the repository at this point in the history
* [Fix][Hybrid] Fix hybrid script to support array of tensors

* add test case

* clean up

* trigger ci
  • Loading branch information
icemelon authored and kevinthesun committed Dec 12, 2019
1 parent fb12f35 commit 123a407
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,9 +647,15 @@ def source_to_op(src, args, symbols, closure_vars):
parser = parse_python(src, args, symbols, closure_vars)

input_tensors = []
def get_input_tensors(arg):
if isinstance(arg, Tensor):
input_tensors.append(arg)
elif isinstance(arg, Array):
for i in arg:
get_input_tensors(i)

for i in args:
if isinstance(i, Tensor):
input_tensors.append(i)
get_input_tensors(i)
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
Expand Down
32 changes: 32 additions & 0 deletions tests/python/unittest/test_hybrid_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,37 @@ def add_something(a):
func, ins, outs = run_and_check(add_something, [a])
run_and_check(func, ins, outs=outs)

def test_array_inputs():
@script
def sum_array(inputs):
out = output_tensor((10,), inputs[0].dtype)
n = len(inputs)
for i in range(10):
for j in const_range(n):
out[i] += inputs[j][i]
return out
n = 5
inputs = []
for i in range(n):
inputs.append(tvm.placeholder((10,), name='t%s' % i, dtype='float32'))

out = sum_array(tvm.convert(inputs))
assert len(out.op.inputs) == n

sch = tvm.create_schedule(out.op)
mod = tvm.build(sch, inputs + [out], target='llvm')
assert mod

input_nd = []
out_ref = numpy.zeros((10,))
for _ in range(n):
arr = numpy.random.uniform(size=(10,)).astype('float32')
input_nd.append(tvm.nd.array(arr))
out_ref += arr
out_nd = tvm.nd.array(numpy.zeros((10,), 'float32'))
mod(*input_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_ref)

if __name__ == "__main__":
test_outer_product()
test_fanout()
Expand All @@ -807,5 +838,6 @@ def add_something(a):
test_const_range()
test_schedule()
test_capture()
test_array_inputs()
# TODO:
# test_inplace()

0 comments on commit 123a407

Please sign in to comment.