diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index 8b2dc65bd42b..1247d615d596 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -165,6 +165,20 @@ def get_factors(n): ret.sort() return ret +def get_pow2s(n): + """return all power-of-two numbers that are less or equal than the integer + + Parameters + ---------- + n: int + integer for reference + + Returns + ------- + factors: list + List of all power-of-two numbers + """ + return [2**x for x in range(math.floor(math.log2(n)) + 1)] class SplitSpace(TransformSpace): """Split an axis for several times""" @@ -175,43 +189,49 @@ def __init__(self, axes, policy, **kwargs): self.policy = policy self.entities = [] - if policy == 'all': - num_outputs = kwargs["num_outputs"] - max_factor = kwargs.get("max_factor", 1 << 31) - fil = kwargs.get("filter", lambda x: True) - - length = axis.length - factors = get_factors(length) - factors = [x for x in factors if x <= max_factor] - # copy factors for every level - self.product = length - self.num_outputs = num_outputs - self.factors = [factors] * (num_outputs-1) - self._generate_space(0, [None] * (num_outputs - 1)) - self.entities = list(filter(fil, self.entities)) - self.num_output = num_outputs - elif policy == 'candidate': - self.product = axis.length - self.num_outputs = kwargs["num_outputs"] + max_factor = kwargs.get("max_factor", 1 << 31) + fil = kwargs.get("filter", lambda x: True) + self.product = axis.length + self.num_output = kwargs.get("num_outputs", 0) + assert self.num_output > 0 + + if policy == 'candidate': for size in kwargs["candidate"]: - assert len(size) == self.num_outputs - # assert np.prod(size) == self.product + assert len(size) == self.num_output self.entities.append(SplitEntity(size)) - self.num_output = self.num_outputs else: - raise RuntimeError("Invalid policy: " + policy) + if policy == 'verbose': + # Include factors and power-of-twos. May generate tails. + divisibles = get_factors(self.product) + pow2s = get_pow2s(self.product) + factors = [x for x in list(set(divisibles) | set(pow2s)) if x <= max_factor] + elif policy == 'factors': + # Include divisible factors. Guarantee no tails. + factors = [x for x in get_factors(self.product) if x <= max_factor] + elif policy == 'power2': + # Include less, equal, and round-up power-of-two numbers. May generate tails. + factors = [x for x in get_pow2s(self.product) if x <= max_factor] + else: + raise RuntimeError("Invalid policy: %s" % policy) - def _generate_space(self, now, tmp_stack): + # Enforce the product of all split factors equals to the axis length + no_tail = kwargs.get("no_tail", policy == 'factors') + + # Generate split entity by enumerating candidate factors. + self.factors = factors + self._generate_space(0, [None] * (self.num_output - 1), enforce_no_tail=no_tail) + + self.entities = list(filter(fil, self.entities)) + + def _generate_space(self, now, tmp_stack, enforce_no_tail=False): """Generate space by DFS""" - if now == self.num_outputs - 1: - size = np.prod(tmp_stack, dtype=np.int64) - if self.product % size == 0: - first = int(self.product // int(size)) - self.entities.append(SplitEntity([first] + tmp_stack[::-1])) + if now == self.num_output - 1: + if not enforce_no_tail or self.product % np.prod(tmp_stack, dtype=np.int64) == 0: + self.entities.append(SplitEntity([-1] + tmp_stack[::-1])) else: - for factor in self.factors[now]: + for factor in self.factors: tmp_stack[now] = factor - self._generate_space(now + 1, tmp_stack) + self._generate_space(now + 1, tmp_stack, enforce_no_tail) @staticmethod def get_num_output(axes, policy, **kwargs): @@ -219,7 +239,7 @@ def get_num_output(axes, policy, **kwargs): def __repr__(self): return ("Split(policy=%s, product=%d, num_outputs=%d) len=%d" % - (self.policy, self.product, self.num_outputs, len(self))) + (self.policy, self.product, self.num_output, len(self))) class SplitEntity(object): @@ -609,7 +629,7 @@ def axis(var): reduce_axis = axis - def define_split(self, name, axis, policy='all', **kwargs): + def define_split(self, name, axis, policy='factors', **kwargs): """Define a new tunable knob which splits an axis into a list of axes Parameters @@ -620,11 +640,22 @@ def define_split(self, name, axis, policy='all', **kwargs): axis to split policy: str name of policy. - If is 'all', the tuner will try all divisible factors. - If is 'candidate', try listed candidate. + If is 'factors', the tuner will try all divisible factors. + If is 'power2', the tuner will try power-of-two factors less or equal to the length. + If is 'verbose', the tuner will try all candidates in above two policies. + If is 'candidate', try given candidates. kwargs: dict extra arguments for policy - see examples below for how to use filter + max_factor: int + the maximum split factor. + filter: function(int) -> bool + see examples below for how to use filter. + num_outputs: int + the total number of axis after split. + no_tail: bool + should we only include divisible numbers as split factors. + candidate: list + (policy=candidate) manual candidate list. Examples -------- @@ -632,7 +663,7 @@ def define_split(self, name, axis, policy='all', **kwargs): >>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]]) >>> # use a filter that only accepts the split scheme whose inner most tile is less then 4 - >>> cfg.define_split('tile_y', y, policy='all', filter=lambda x: x.size[-1] <= 4) + >>> cfg.define_split('tile_y', y, policy='factors', filter=lambda x: x.size[-1] <= 4) """ axes = [axis] return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs) @@ -944,7 +975,7 @@ def fallback_split(self, name, constraints): """ space = self.space_map[name] assert isinstance(space, SplitSpace) - assert len(constraints) == space.num_outputs + assert len(constraints) == space.num_output # '-1' means no constraint constraints = [x if x != -1 else 1e10 for x in constraints] @@ -952,7 +983,7 @@ def fallback_split(self, name, constraints): entity = self._entity_map[name] now = space.product - for i in reversed(range(space.num_outputs)): + for i in reversed(range(space.num_output)): factors = get_factors(now) find = len(factors) - 1 diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index af9c5bebb998..072c187ee294 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -82,11 +82,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh ci, kh, kw = cfg.reduce_axis(CI_packed), cfg.reduce_axis(KH), cfg.reduce_axis(KW) ib, kb = cfg.reduce_axis(activation_bits), cfg.reduce_axis(weight_bits) - co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2, + co, vc = cfg.define_split('tile_co', co, num_outputs=2, filter=lambda x: x.size[-1] == 8) - oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2, + oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2, filter=lambda x: x.size[-1] >= 2) - ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2, + ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2, filter=lambda x: x.size[-1] >= 2) ci_o, ci_i = cfg.define_split("tile_ci", ci, num_outputs=2, filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16) @@ -278,13 +278,13 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec, s[data_pad].compute_inline() _, h, _, _, _, _, _ = s[data_vec].op.axis - cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32) + cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32) oh, ih = cfg["tile_ah"].apply(s, data_vec, h) s[data_vec].parallel(oh) #### Schedule kernel packing co, _, _, _, _, _ = s[kernel_vec].op.axis - cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32) + cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32) oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co) s[kernel_vec].parallel(oco) diff --git a/topi/python/topi/arm_cpu/bitserial_dense.py b/topi/python/topi/arm_cpu/bitserial_dense.py index b07930f1577d..0148cfba3f38 100644 --- a/topi/python/topi/arm_cpu/bitserial_dense.py +++ b/topi/python/topi/arm_cpu/bitserial_dense.py @@ -66,10 +66,10 @@ def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtyp x, y = cfg.axis(batch), cfg.axis(out_dim) db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(in_dim) - ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2, + ko, ki = cfg.define_split('tile_k', k, num_outputs=2, filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16) - xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2) - yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2, + xo, xi = cfg.define_split('tile_x', x, num_outputs=2) + yo, yi = cfg.define_split('tile_y', y, num_outputs=2, filter=lambda xx: xx.size[-1] == 8) cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki], diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py index 21abdf0de1ec..c04ff01eec8b 100644 --- a/topi/python/topi/nn/bitserial_conv2d.py +++ b/topi/python/topi/nn/bitserial_conv2d.py @@ -254,11 +254,11 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits, ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits) - co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2, + co, vc = cfg.define_split('tile_co', co, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16) - oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2, + oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16) - ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2, + ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16) cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll') @@ -358,11 +358,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits, ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits) - co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2, + co, vc = cfg.define_split('tile_co', co, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16) - oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2, + oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16) - ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2, + ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2, filter=lambda x: max(x.size[1:]) <= 16) cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll') cfg.define_reorder("reorder_0", diff --git a/topi/python/topi/nn/bitserial_dense.py b/topi/python/topi/nn/bitserial_dense.py index 5d47b2974a7c..b28b3a41555a 100644 --- a/topi/python/topi/nn/bitserial_dense.py +++ b/topi/python/topi/nn/bitserial_dense.py @@ -95,9 +95,9 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp ######## Search space x, y = cfg.axis(X), cfg.axis(Y) db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K) - ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2) - yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2) - xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2) + ko, ki = cfg.define_split('tile_k', k, num_outputs=2) + yo, yi = cfg.define_split('tile_y', y, num_outputs=2) + xo, xi = cfg.define_split('tile_x', x, num_outputs=2) cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi], policy='candidate', candidate=[ diff --git a/topi/python/topi/x86/bitserial_conv2d.py b/topi/python/topi/x86/bitserial_conv2d.py index 905383fdb802..536386ab2fb0 100644 --- a/topi/python/topi/x86/bitserial_conv2d.py +++ b/topi/python/topi/x86/bitserial_conv2d.py @@ -99,7 +99,7 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec, s[data_pad].compute_inline() _, _, h, _, _, _, _ = s[data_vec].op.axis - cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32) + cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32) oh, ih = cfg["tile_ah"].apply(s, data_vec, h) if cfg["tile_ah"].size[1] == 1: oaxis = oh @@ -116,7 +116,7 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec, ##### Schedule Kenerl bitpacking co, _, _, _, _, _ = s[kernel_vec].op.axis - cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32) + cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32) oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co) if cfg["tile_bco"].size[1] == 1: oaxis = oco @@ -185,13 +185,13 @@ def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec, s[data_pad].compute_inline() _, h, _, _, _, _, _ = s[data_vec].op.axis - cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32) + cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32) oh, ih = cfg["tile_ah"].apply(s, data_vec, h) s[data_vec].parallel(oh) ##### Schedule kernel packing co, _, _, _, _, _ = s[kernel_vec].op.axis - cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32) + cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32) oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co) s[kernel_vec].parallel(oco) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index d5bc5adcec85..6f134ea45781 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -95,7 +95,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): if _is_int8_hw_support(data.dtype, kernel.dtype, target): oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape ic = ic_chunk*ic_bn - assert ic == k_ic*k_ic_f*kic_s + assert ic == k_ic*k_ic_f*k_ic_s else: oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape assert ic_chunk == k_ic_chunk