Skip to content

Commit

Permalink
[FRONTEND] [HYBRID] Augmented assign operator supported! (apache#1459)
Browse files Browse the repository at this point in the history
  • Loading branch information
were authored and tqchen committed Jul 20, 2018
1 parent 1f2abda commit 33245b8
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 26 deletions.
9 changes: 8 additions & 1 deletion python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

def list_to_block(visit, lst):
"""Convert a list of Python IR nodes to HalideIR Block"""
lst = list(map(visit, lst))
lst = [visit(i) for i in lst]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())]
if not lst:
return make_nop()
Expand Down Expand Up @@ -162,6 +162,13 @@ def visit_Name(self, node):
def visit_Num(self, node):
return _api.const(node.n)

def visit_AugAssign(self, node):
lhs = self.visit(node.target)
rhs = self.visit(node.value)
rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs)
if not isinstance(lhs, _expr.Call):
raise ValueError("The LHS of an AugAssign is supposed to be a call!")
return _make.Provide(lhs.func, 0, rhs, lhs.args)

def visit_Assign(self, node):
if len(node.targets) != 1:
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/hybrid/var_decl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, args):
self.scope_level = []
self._args = {}
self.args = args
self.aug_assign_ = False


def visit_FunctionDef(self, node):
Expand Down Expand Up @@ -48,6 +49,12 @@ def visit_Call(self, node):
self.visit(elem)


def visit_AugAssign(self, node):
self.aug_assign_ = True
self.generic_visit(node)
self.aug_assign_ = False


def visit_Name(self, node):
# If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys():
Expand All @@ -62,6 +69,8 @@ def visit_Name(self, node):
if node.id not in self.status.keys():
if not isinstance(node.ctx, ast.Store):
raise ValueError('In Python, "first store" indicates "declaration"')
if self.aug_assign_:
raise ValueError('"First store" cannot be an AugAssign')
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
Expand Down
75 changes: 50 additions & 25 deletions tests/python/unittest/test_hybrid_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def tvm_val_2_py_val(val):
module(*nd_args)

for nd, np in to_check:
numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-3, atol=1e-3)

return module


@script
Expand Down Expand Up @@ -83,7 +85,7 @@ def test_outer_product():
func = tvm.lower(ir, [n, m, a, b, c])
func = tvm.build(func)

run_and_check(outer_product, [n, m, a, b, c], [c], {n: 999, m: 1001})
run_and_check(outer_product, [n, m, a, b, c], [c], {n: 99, m: 101})

for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys()
Expand Down Expand Up @@ -165,20 +167,32 @@ def fanout(n, a, b):
run_and_check(fanout, [n, a, b], [b], {n: 10})


@script
def failure():
for i in range(1, 100):
i = 0

def test_failure():
try:
@script
def failure():
for i in range(1, 100):
i = 0
tvm.hybrid.parse(failure, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_failure is skipped by Python2 because "%s"' % str(err))
except Exception as err:
print('[Warning] Case test_failure.0 is skipped by Python2 because "%s"' % str(err))
except ValueError as err:
assert str(err) == 'You CAN NEVER overwrite a loop variable!'

try:
@tvm.hybrid.script
def augdefine():
for i in range(10):
es += 0
tvm.hybrid.parse(augdefine, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_failure.1 is skipped by Python2 because "%s"' % str(err))
except ValueError as err:
assert str(err) == '"First store" cannot be an AugAssign'



def test_looptype():
@script
Expand Down Expand Up @@ -280,7 +294,7 @@ def blur(a, b):
s = 0.0
for di in range(3):
for dj in range(3):
s = s + a[i-di, j-dj]
s += a[i-di, j-dj]
b[i-2, j-2] = s / 9.0
try:
a = tvm.placeholder((32, 32), 'float32', 'a')
Expand Down Expand Up @@ -315,29 +329,39 @@ def blur2d(a, b):

a = tvm.placeholder((32, 32), 'float32', 'a')
b = tvm.placeholder((30, 30), 'float32', 'b')

run_and_check(blur2d, [a, b], [b])

if tvm.gpu().exist:
@tvm.hybrid.script
def share_vec_add(a, b, c):
shared = allocate((256, ), 'float32', 'shared')
for i in bind("threadIdx.x", 256):
shared[i] = a[i]
local = allocate((256, ), 'float32', 'local')
for i in bind("threadIdx.x", 256):
local[i] = b[i]
for i in bind("threadIdx.x", 256):
c[i] = shared[i] + local[i]

a = tvm.placeholder((256, ), dtype='float32', name='a')
b = tvm.placeholder((256, ), dtype='float32', name='b')
c = tvm.placeholder((256, ), dtype='float32', name='c')
run_and_check(share_vec_add, [a, b, c], [c], target='cuda')
def shared_gemm(a, b, c):
for io in bind('blockIdx.x', 8):
for ii in bind('blockIdx.y', 8):
shared_b = allocate((64, 64), 'float32', 'shared')
for k in range(64):
shared_b[io * 8 + ii, k] = b[io * 8 + ii, k]
for jo in bind('threadIdx.y', 8):
for ji in bind('threadIdx.x', 8):
for k in range(64):
c[io*8+ii, jo*8+ji] += a[io*8+ii, k] * shared_b[k, jo*8+ji]

a = tvm.placeholder((64, 64), dtype='float32', name='a')
b = tvm.placeholder((64, 64), dtype='float32', name='b')
c = tvm.placeholder((64, 64), dtype='float32', name='c')
module = run_and_check(shared_gemm, [a, b, c], [c], target='cuda')
assert "__syncthreads()" in module.imported_modules[0].get_source()
else:
print('[Warning] No GPU found! Skip shared mem test!')


def test_augassign():
@tvm.hybrid.script
def augassign(a):
for i in range(a.shape[0]):
a[i] += 1.0
a = tvm.placeholder((16, ), dtype='float32', name='a')
run_and_check(augassign, [a], [a])


if __name__ == "__main__":
test_outer_product()
test_fanout()
Expand All @@ -348,4 +372,5 @@ def share_vec_add(a, b, c):
test_math_intrin()
test_non_zero()
test_allocate()
test_augassign()

0 comments on commit 33245b8

Please sign in to comment.