Skip to content

Commit

Permalink
[ARITH] cleanup the indexmod/div on python side (apache#4028)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Sep 28, 2019
1 parent 9b58279 commit 89be933
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions python/vta/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,9 @@ def inject_dma_intrin(stmt_in):
Transformed statement
"""
env = get_env()
idxd = tvm.indexdiv
idxm = tvm.indexmod

def _check_compact(buf):
ndim = len(buf.shape)
size = tvm.const(1, buf.shape[0].dtype)
Expand Down Expand Up @@ -369,7 +372,7 @@ def _fold_buffer_dim(buf, scope, elem_block):
x_size = 1
x_stride = buf.strides[ndim - base]
next_base = base
if not util.equal_const_int(x_stride % elem_block, 0):
if not util.equal_const_int(idxm(x_stride, elem_block), 0):
raise RuntimeError(
"scope %s need to have block=%d, shape=%s, strides=%s" % (
scope, elem_block, buf.shape, buf.strides))
Expand All @@ -394,7 +397,7 @@ def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
raise RuntimeError("Expect buffer type to be %s instead of %s" %
(dtype, buf.dtype))
shape, strides = buf.shape, buf.strides
if not util.equal_const_int(buf.elem_offset % elem_block, 0):
if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
if allow_fold:
shape, strides = _fold_buffer_dim(buf, scope, elem_block)
Expand All @@ -421,23 +424,23 @@ def raise_error():
x_size = 1
x_stride = 1
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(strides[-2] - elem_block, 0):
raise_error()

if ndim == 2:
x_size = shape[-2]
x_stride = shape[-2]
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
if not util.equal_const_int(strides[-3] % elem_block, 0):
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
raise_error()

if ndim == 3:
x_size = shape[-2]
x_stride = strides[-3] / elem_block
x_stride = idxd(strides[-3], elem_block)
y_size = shape[-3]
return x_size, y_size, x_stride, buf.elem_offset / elem_block
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)

else:
if not util.equal_const_int(strides[-1], 1):
Expand All @@ -451,23 +454,23 @@ def raise_error():
x_size = 1
x_stride = 1
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(strides[-3], elem_block):
raise_error()

if ndim == 3:
x_size = shape[-3]
x_stride = shape[-3]
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
if not util.equal_const_int(strides[-4] % elem_block, 0):
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
raise_error()

if ndim == 4:
x_size = shape[-3]
x_stride = strides[-4] / elem_block
x_stride = idxd(strides[-4], elem_block)
y_size = shape[-4]
return x_size, y_size, x_stride, buf.elem_offset / elem_block
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)

raise_error()

Expand Down Expand Up @@ -765,6 +768,8 @@ def inject_alu_intrin(stmt_in):
Transformed statement
"""
env = get_env()
idxm = tvm.indexmod

def _do_fold(stmt):
def _equal(x, y):
return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0)
Expand Down Expand Up @@ -910,10 +915,10 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
assert len(extents) != 0
assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify(
src_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify(
dst_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal(src_coeff[-2], 1)
assert tvm.ir_pass.Equal(dst_coeff[-2], 1)
if env.BATCH > 1:
Expand Down

0 comments on commit 89be933

Please sign in to comment.