[ARITH] cleanup the indexmod/div on python side #4028

Merged: 5 commits, Sep 28, 2019
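This PR replaces implicit `/` and `%` on integer expressions with explicitly named operators (`tvm.indexdiv`/`tvm.indexmod`, alongside the existing `truncdiv`/`truncmod` and `floordiv`/`floormod`) across the Python tests and TOPI code. As a hedged refresher on the distinction (a sketch; the simplifier details are an assumption, not part of this diff):

import tvm

x = tvm.var("x")

# truncdiv/truncmod: C semantics, rounding toward zero.
t = tvm.truncdiv(x, 4)
# floordiv/floormod: Python semantics, rounding toward negative infinity.
f = tvm.floordiv(x, 4)
# indexdiv/indexmod: floor semantics plus the assumption that both
# operands are non-negative, which lets index expressions simplify
# more aggressively.
d = tvm.indexdiv(x, 4)
m = tvm.indexmod(x, 4)

# On concrete negative values the two families differ:
#   trunc: -5 / 4  -> -1, remainder -1
#   floor: -5 // 4 -> -2, remainder  3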
4 changes: 3 additions & 1 deletion python/tvm/autotvm/task/task.py
@@ -350,7 +350,9 @@ def _count_flop(exp):
             return _count_flop(exp.value)
         if isinstance(exp, expr.Var):
             return 0
-        if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
+        if isinstance(exp, (expr.Add, expr.Sub, expr.Mul,
+                            expr.Div, expr.Mod,
+                            expr.FloorDiv, expr.FloorMod,
                             expr.Max, expr.Min,
                             expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
                             expr.And, expr.Or, expr.Not)):
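With `expr.FloorDiv`/`expr.FloorMod` in the dispatch tuple, the autotvm FLOP counter treats floor division like any other binary op instead of bailing out with an unsupported-operator error. A minimal sketch (the `compute_flop` import path is assumed from the unit test further down):

import tvm
from tvm.autotvm.task.task import compute_flop  # import path assumed

N = 1024
A = tvm.placeholder((N,), name='A', dtype='int32')
# An elementwise floordiv in the compute body; before this patch the
# isinstance dispatch had no FloorDiv case, so counting would fail.
B = tvm.compute((N,), lambda i: tvm.floordiv(A[i], 4), name='B')
s = tvm.create_schedule(B.op)
print(compute_flop(s))  # expected: N, one counted op per output element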
20 changes: 10 additions & 10 deletions python/tvm/expr.py
@@ -72,23 +72,23 @@ def __rmul__(self, other):
         return _generic.multiply(other, self)

     def __div__(self, other):
-        # if _dtype_is_int(self) and _dtype_is_int(other):
-        #     raise div_ambiguity_error()
+        if _dtype_is_int(self) and _dtype_is_int(other):
+            raise div_ambiguity_error()
         return _generic.divide(self, other)

     def __rdiv__(self, other):
-        # if _dtype_is_int(self) and _dtype_is_int(other):
-        #     raise div_ambiguity_error()
+        if _dtype_is_int(self) and _dtype_is_int(other):
+            raise div_ambiguity_error()
         return _generic.divide(other, self)

     def __truediv__(self, other):
-        # if _dtype_is_int(self) and _dtype_is_int(other):
-        #     raise div_ambiguity_error()
+        if _dtype_is_int(self) and _dtype_is_int(other):
+            raise div_ambiguity_error()
         return _generic.divide(self, other)

     def __rtruediv__(self, other):
-        # if _dtype_is_int(self) and _dtype_is_int(other):
-        #     raise div_ambiguity_error()
+        if _dtype_is_int(self) and _dtype_is_int(other):
+            raise div_ambiguity_error()
         return _generic.divide(other, self)

     def __floordiv__(self, other):
@@ -100,8 +100,8 @@ def __rfloordiv__(self, other):
         return _generic.divide(other, self)

     def __mod__(self, other):
-        # raise div_ambiguity_error()
-        return _make._OpMod(self, other)
+        raise div_ambiguity_error()
+        # return _make._OpMod(self, other)

     def __neg__(self):
         neg_one = _api_internal._const(-1, self.dtype)
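With the checks enabled, `/` and `%` on two integer expressions raise `div_ambiguity_error` instead of silently picking one division convention; callers now have to spell out what they mean. A hedged migration sketch:

import tvm

n = tvm.var("n")

# Before this patch, n / 4 and n % 4 implicitly chose a division
# convention. After it, both raise div_ambiguity_error for int dtypes.
# Write the intent explicitly instead:
q1 = tvm.indexdiv(n, 4)   # non-negative index arithmetic
r1 = tvm.indexmod(n, 4)
q2 = tvm.truncdiv(n, 4)   # C-style truncated division
q3 = tvm.floordiv(n, 4)   # rounds toward negative infinity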
6 changes: 4 additions & 2 deletions src/pass/rewrite_unsafe_select.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- *
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -64,6 +64,8 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
   bool VisitExpr_(const Mul* op) final { return BinaryOp(op); }
   bool VisitExpr_(const Div* op) final { return BinaryOp(op); }
   bool VisitExpr_(const Mod* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); }
   bool VisitExpr_(const Min* op) final { return BinaryOp(op); }
   bool VisitExpr_(const Max* op) final { return BinaryOp(op); }
   bool VisitExpr_(const EQ* op) final { return BinaryOp(op); }
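RewriteUnsafeSelect turns a Select into tvm_if_then_else when eagerly evaluating both arms could trap (for example an out-of-bounds load). Giving UnsafeExprDetector explicit FloorDiv/FloorMod visitors lets it recurse into those nodes like the other binary ops rather than hitting the conservative default. A hedged Python-side sketch mirroring the updated unit test below (that the pass keeps this select intact is an assumption):

import tvm

i = tvm.var("i")
# The floordiv condition contains no loads, so the detector should be
# able to classify the expression as safe rather than bailing out.
sel = tvm.expr.Select(tvm.floordiv(i, 4) > 10,
                      tvm.const(1.0, "float32"),
                      tvm.const(0.1, "float32"))
stmt = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(sel))
print(stmt)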
14 changes: 8 additions & 6 deletions tests/python/relay/test_op_level3.py
@@ -373,6 +373,8 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None):
         yy = run_infer_type(y.astuple())
         assert yy.checked_type == ret_type

+    idxd = tvm.indexdiv
+
     d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
     axis = tvm.var("axis")
     verify_split((5, 5, 2, 2), 5,
@@ -393,15 +395,15 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None):
                  axis=0)
     verify_split((d1, d2, d3, d4), 4,
                  relay.ty.TupleType(tvm.convert([
-                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
-                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
-                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
-                     relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])),
+                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
+                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
+                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
+                     relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32")])),
                  axis=2)
     verify_split((d1, d2, d3, d4), 2,
                  relay.ty.TupleType(tvm.convert([
-                     relay.ty.TensorType((d1/2, d2, d3, d4), "float32"),
-                     relay.ty.TensorType((d1/2, d2, d3, d4), "float32")])),
+                     relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"),
+                     relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32")])),
                  axis=0)
     verify_split((d1, d2, d3, d4), (2, 4, 7),
                  relay.ty.TupleType(tvm.convert([
3 changes: 2 additions & 1 deletion tests/python/relay/test_op_level5.py
@@ -487,8 +487,9 @@ def verify_yolo_reorg(shape, stride, out_shape):
         assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")

     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
+    idxd = tvm.indexdiv
     verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
-    verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2))
+    verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2)))

 def test_yolo_reorg():
     def verify_yolo_reorg(shape, stride):
6 changes: 3 additions & 3 deletions tests/python/unittest/test_autotvm_flop_calculator.py
@@ -60,14 +60,14 @@ def test_pack_gemm():
         k = tvm.reduce_axis((0, L))

         bn = 4
-        fld = tvm.floordiv
-        flm = tvm.floormod
+        idxd = tvm.indexdiv
+        idxm = tvm.indexmod

         A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
         B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
         C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
                              tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
-        C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)])
+        C = tvm.compute((N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)])

         s = tvm.create_schedule([C.op])
         assert compute_flop(s) == 2 * N * L * M
5 changes: 3 additions & 2 deletions tests/python/unittest/test_codegen_cuda.py
@@ -37,7 +37,7 @@ def check_cuda(dtype, n, lanes):
             print("skip because gpu does not support int8")
             return
         A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
-        B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B')
+        B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
         s = tvm.create_schedule(B.op)
         xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
         s[B].bind(xo, bx)
@@ -165,9 +165,10 @@ def test_cuda_shuffle():
         print("skip because cuda is not enabled..")
         return

+    idxm = tvm.indexmod
     a = tvm.placeholder((64, ), 'int32')
     b = tvm.placeholder((64, ), 'int32')
-    c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)])
+    c = tvm.compute((64, ), lambda x: a[x] + b[x - idxm(x, 4) + (3 - idxm(x, 4))])
     sch = tvm.create_schedule(c.op)
     x = c.op.axis[0]
     xo, xi = sch[c].split(x, 4)
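The rewritten index `x - idxm(x, 4) + (3 - idxm(x, 4))` reverses each aligned group of four lanes, which is the access pattern the CUDA shuffle is expected to implement. A quick plain-Python check of the permutation:

# x - x % 4 is the start of x's group of four; 3 - x % 4 reverses the
# position within the group.
perm = [x - (x % 4) + (3 - x % 4) for x in range(8)]
print(perm)  # [3, 2, 1, 0, 7, 6, 5, 4]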
5 changes: 3 additions & 2 deletions tests/python/unittest/test_ir_builder.py
@@ -109,14 +109,15 @@ def test_gpu():
     dtype = "float32"
     A = tvm.placeholder((n,), name='A')
     B = tvm.placeholder((n,), name='B')
-    fld = tvm.floordiv
+    idxd = tvm.indexdiv
+
     def test_device_ir(A, B, C):
         n = A.shape[0]
         max_threads = 32
         ib = tvm.ir_builder.create()
         bx = tvm.thread_axis("blockIdx.x")
         tx = tvm.thread_axis("threadIdx.x")
-        ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads))
+        ib.scope_attr(bx, "thread_extent", idxd(n+max_threads-1, max_threads))
         ib.scope_attr(tx, "thread_extent", max_threads)
         idx = bx.var * max_threads + tx.var
         Aptr = ib.buffer_ptr(A)
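The thread extent uses the usual ceiling-division idiom: for non-negative `n` and a positive block size, `idxd(n + max_threads - 1, max_threads)` equals `ceil(n / max_threads)`, so a partial final block is still launched. A tiny sanity check in plain Python:

# Ceiling division: (n + d - 1) // d == ceil(n / d) for n >= 0, d > 0.
def ceil_div(n, d):
    return (n + d - 1) // d

print(ceil_div(64, 32))  # 2
print(ceil_div(65, 32))  # 3, one extra block covers the remainder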
28 changes: 14 additions & 14 deletions tests/python/unittest/test_lang_buffer.py
@@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod():
     def assert_simplified_equal(index_simplified, index_direct):
         assert tvm.ir_pass.Equal(index_simplified, index_direct),\
             "index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
-    idxdiv = tvm.indexdiv
-    idxmod = tvm.indexmod
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
     # Test Case1
     index_simplified = A_stride.vload(
-        (idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1))
+        (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1))
     index_direct = A_stride.vload((0, k0))
     assert_simplified_equal(index_simplified, index_direct)

     # Test Case2
-    index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
-                                idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1)))
-    index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s))))
+    index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
+                                idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)))
+    index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case3
-    index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
-                                idxdiv(idxmod(k0, idxdiv(k1, s)), n),
-                                idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
-                                idxmod(idxmod(k0, idxdiv(k1, s)), n)))
+    index_simplified = A.vload((idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
+                                idxd(idxm(k0, idxd(k1, s)), n),
+                                idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
+                                idxm(idxm(k0, idxd(k1, s)), n)))
     index_direct = A.vload((0, k0))
     assert_simplified_equal(index_simplified, index_direct)
     # Test Case4 (not able to simplify)
-    index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
-                                idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))
-    index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n +
-                            (idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))))
+    index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
+                                idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
+    index_direct = A.vload((0, idxd(idxm(k0, idxd(k1, s)), n) * n +
+                            (idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))))
     assert_simplified_equal(index_simplified, index_direct)
2 changes: 1 addition & 1 deletion tests/python/unittest/test_pass_rewrite_unsafe_select.py
@@ -28,7 +28,7 @@ def test_rewrite_Select():
                       tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
     zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value

-    a = tvm.expr.Select(i>10, y, z)
+    a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z)
     aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
     assert yy.name == "tvm_if_then_else"
     assert zz.name == "tvm_if_then_else"
9 changes: 5 additions & 4 deletions tests/python/unittest/test_schedule_tensorize.py
@@ -221,14 +221,15 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor):
 # This tests whether algorithm and intrinsics expressions are simplified
 # as much as possible first and then checked for equality. See Issue #696
 def test_tensorize_op():
-    tdiv = tvm.truncdiv
-    tmod = tvm.truncmod
+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
     def op_intrin():
         bh = 9
         bw = 9
         x = tvm.placeholder((5, 5), name='A')
         y = tvm.compute((bh, bw),
-                        lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)])
+                        lambda i, j: x[idxd(j,3) + idxm(i,3), idxm(j,3)+ idxd(i,3)])

         def intrin_func(ins, outs):
             xx, = ins
@@ -239,7 +239,7 @@ def intrin_func(ins, outs):
         return tvm.decl_tensor_intrin(y.op, intrin_func)

     A = tvm.placeholder((5, 5), name='A')
-    B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)])
+    B = tvm.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
     bt = op_intrin()
     s = tvm.create_schedule(B.op)
14 changes: 11 additions & 3 deletions topi/python/topi/arm_cpu/bitserial_conv2d.py
@@ -70,6 +70,9 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
     OW = (PAD_W - KW) // WSTR + 1
     oshape = (1, OH, OW, CO)

+    idxd = tvm.indexdiv
+    idxm = tvm.indexmod
+
     # Pad input channels of weights and data when it is not a multiple of 8
     if CI_packed % 8 != 0:
         CI_PAD = CI_packed % 8
@@ -106,7 +109,8 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
     data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')

     kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4)
-    if kernel_vec.shape[-1] % 8 != 0 and CI_PAD != 0:
+    idxm = tvm.indexmod
+    if idxm(kernel_vec.shape[-1], 8) != 0 and CI_PAD != 0:
         kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD])

     N, H, W, IB, CI = data_q.shape
@@ -147,8 +151,12 @@ def _unipolar_conv(n, h, w, co, vh, vw, vc):
     else:
         conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar')

-    conv = tvm.compute(oshape, lambda n, h, w, co:
-                       conv_vec[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),
+
+    conv = tvm.compute(oshape,
+                       lambda n, h, w, co:
+                       conv_vec[n,
+                                idxd(h, VH), idxd(w, VW), idxd(co, VC),
+                                idxm(h, VH), idxm(w, VW), idxm(co, VC)].astype(out_dtype),
                        name='conv', tag='spatial_bitserial_conv_nhwc')

     return conv
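The unpack step recovers `(h, w, co)` from the packed `conv_vec` layout with div/mod pairs. Because `indexdiv`/`indexmod` assume non-negative operands, the simplifier can fold the split-and-recombine round trip. A hedged sketch (that Simplify folds this back to `h` is an assumption):

import tvm

VH = 4                 # illustrative tile size
h = tvm.var("h")
idxd = tvm.indexdiv
idxm = tvm.indexmod

# Splitting h into (outer, inner) and recombining is the identity for
# non-negative h, which is exactly what indexdiv/indexmod assert.
expr = idxd(h, VH) * VH + idxm(h, VH)
print(tvm.ir_pass.Simplify(expr))  # expected: h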