Skip to content

Commit

Permalink
[AutoTVM] Enhance tuning space of split (#3949)
Browse files Browse the repository at this point in the history
* Refine policies for define_split

- Rename policy "all" to "factors"
- Add policy "verbose" and "power2"

* Refine search space

* add doc
  • Loading branch information
comaniac authored and vinx13 committed Sep 15, 2019
1 parent e35e1cc commit da03979
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 60 deletions.
107 changes: 69 additions & 38 deletions python/tvm/autotvm/task/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,20 @@ def get_factors(n):
ret.sort()
return ret

def get_pow2s(n):
"""return all power-of-two numbers that are less or equal than the integer
Parameters
----------
n: int
integer for reference
Returns
-------
factors: list
List of all power-of-two numbers
"""
return [2**x for x in range(math.floor(math.log2(n)) + 1)]

class SplitSpace(TransformSpace):
"""Split an axis for several times"""
Expand All @@ -175,51 +189,57 @@ def __init__(self, axes, policy, **kwargs):
self.policy = policy
self.entities = []

if policy == 'all':
num_outputs = kwargs["num_outputs"]
max_factor = kwargs.get("max_factor", 1 << 31)
fil = kwargs.get("filter", lambda x: True)

length = axis.length
factors = get_factors(length)
factors = [x for x in factors if x <= max_factor]
# copy factors for every level
self.product = length
self.num_outputs = num_outputs
self.factors = [factors] * (num_outputs-1)
self._generate_space(0, [None] * (num_outputs - 1))
self.entities = list(filter(fil, self.entities))
self.num_output = num_outputs
elif policy == 'candidate':
self.product = axis.length
self.num_outputs = kwargs["num_outputs"]
max_factor = kwargs.get("max_factor", 1 << 31)
fil = kwargs.get("filter", lambda x: True)
self.product = axis.length
self.num_output = kwargs.get("num_outputs", 0)
assert self.num_output > 0

if policy == 'candidate':
for size in kwargs["candidate"]:
assert len(size) == self.num_outputs
# assert np.prod(size) == self.product
assert len(size) == self.num_output
self.entities.append(SplitEntity(size))
self.num_output = self.num_outputs
else:
raise RuntimeError("Invalid policy: " + policy)
if policy == 'verbose':
# Include factors and power-of-twos. May generate tails.
divisibles = get_factors(self.product)
pow2s = get_pow2s(self.product)
factors = [x for x in list(set(divisibles) | set(pow2s)) if x <= max_factor]
elif policy == 'factors':
# Include divisible factors. Guarantee no tails.
factors = [x for x in get_factors(self.product) if x <= max_factor]
elif policy == 'power2':
# Include less, equal, and round-up power-of-two numbers. May generate tails.
factors = [x for x in get_pow2s(self.product) if x <= max_factor]
else:
raise RuntimeError("Invalid policy: %s" % policy)

def _generate_space(self, now, tmp_stack):
# Enforce the product of all split factors equals to the axis length
no_tail = kwargs.get("no_tail", policy == 'factors')

# Generate split entity by enumerating candidate factors.
self.factors = factors
self._generate_space(0, [None] * (self.num_output - 1), enforce_no_tail=no_tail)

self.entities = list(filter(fil, self.entities))

def _generate_space(self, now, tmp_stack, enforce_no_tail=False):
"""Generate space by DFS"""
if now == self.num_outputs - 1:
size = np.prod(tmp_stack, dtype=np.int64)
if self.product % size == 0:
first = int(self.product // int(size))
self.entities.append(SplitEntity([first] + tmp_stack[::-1]))
if now == self.num_output - 1:
if not enforce_no_tail or self.product % np.prod(tmp_stack, dtype=np.int64) == 0:
self.entities.append(SplitEntity([-1] + tmp_stack[::-1]))
else:
for factor in self.factors[now]:
for factor in self.factors:
tmp_stack[now] = factor
self._generate_space(now + 1, tmp_stack)
self._generate_space(now + 1, tmp_stack, enforce_no_tail)

@staticmethod
def get_num_output(axes, policy, **kwargs):
return kwargs["num_outputs"]

def __repr__(self):
return ("Split(policy=%s, product=%d, num_outputs=%d) len=%d" %
(self.policy, self.product, self.num_outputs, len(self)))
(self.policy, self.product, self.num_output, len(self)))


class SplitEntity(object):
Expand Down Expand Up @@ -609,7 +629,7 @@ def axis(var):

reduce_axis = axis

def define_split(self, name, axis, policy='all', **kwargs):
def define_split(self, name, axis, policy='factors', **kwargs):
"""Define a new tunable knob which splits an axis into a list of axes
Parameters
Expand All @@ -620,19 +640,30 @@ def define_split(self, name, axis, policy='all', **kwargs):
axis to split
policy: str
name of policy.
If is 'all', the tuner will try all divisible factors.
If is 'candidate', try listed candidate.
If is 'factors', the tuner will try all divisible factors.
If is 'power2', the tuner will try power-of-two factors less or equal to the length.
If is 'verbose', the tuner will try all candidates in above two policies.
If is 'candidate', try given candidates.
kwargs: dict
extra arguments for policy
see examples below for how to use filter
max_factor: int
the maximum split factor.
filter: function(int) -> bool
see examples below for how to use filter.
num_outputs: int
the total number of axis after split.
no_tail: bool
should we only include divisible numbers as split factors.
candidate: list
(policy=candidate) manual candidate list.
Examples
--------
>>> # use custom candidates
>>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]])
>>> # use a filter that only accepts the split scheme whose inner most tile is less then 4
>>> cfg.define_split('tile_y', y, policy='all', filter=lambda x: x.size[-1] <= 4)
>>> cfg.define_split('tile_y', y, policy='factors', filter=lambda x: x.size[-1] <= 4)
"""
axes = [axis]
return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)
Expand Down Expand Up @@ -944,15 +975,15 @@ def fallback_split(self, name, constraints):
"""
space = self.space_map[name]
assert isinstance(space, SplitSpace)
assert len(constraints) == space.num_outputs
assert len(constraints) == space.num_output

# '-1' means no constraint
constraints = [x if x != -1 else 1e10 for x in constraints]

entity = self._entity_map[name]
now = space.product

for i in reversed(range(space.num_outputs)):
for i in reversed(range(space.num_output)):
factors = get_factors(now)

find = len(factors) - 1
Expand Down
10 changes: 5 additions & 5 deletions topi/python/topi/arm_cpu/bitserial_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
ci, kh, kw = cfg.reduce_axis(CI_packed), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(activation_bits), cfg.reduce_axis(weight_bits)

co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
co, vc = cfg.define_split('tile_co', co, num_outputs=2,
filter=lambda x: x.size[-1] == 8)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
filter=lambda x: x.size[-1] >= 2)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
filter=lambda x: x.size[-1] >= 2)
ci_o, ci_i = cfg.define_split("tile_ci", ci, num_outputs=2,
filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16)
Expand Down Expand Up @@ -278,13 +278,13 @@ def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec,
s[data_pad].compute_inline()

_, h, _, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
s[data_vec].parallel(oh)

#### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
s[kernel_vec].parallel(oco)

Expand Down
6 changes: 3 additions & 3 deletions topi/python/topi/arm_cpu/bitserial_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtyp
x, y = cfg.axis(batch), cfg.axis(out_dim)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(in_dim)

ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2,
ko, ki = cfg.define_split('tile_k', k, num_outputs=2,
filter=lambda xx: xx.size[-1] == 8 or xx.size[-1] == 16)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2,
xo, xi = cfg.define_split('tile_x', x, num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, num_outputs=2,
filter=lambda xx: xx.size[-1] == 8)

cfg.define_reorder('reorder_0', [yo, xo, ko, xi, wb, db, yi, ki],
Expand Down
12 changes: 6 additions & 6 deletions topi/python/topi/nn/bitserial_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,11 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)

co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
co, vc = cfg.define_split('tile_co', co, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')

Expand Down Expand Up @@ -358,11 +358,11 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits)

co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2,
co, vc = cfg.define_split('tile_co', co, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2,
oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2,
ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2,
filter=lambda x: max(x.size[1:]) <= 16)
cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll')
cfg.define_reorder("reorder_0",
Expand Down
6 changes: 3 additions & 3 deletions topi/python/topi/nn/bitserial_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtyp
######## Search space
x, y = cfg.axis(X), cfg.axis(Y)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
ko, ki = cfg.define_split('tile_k', k, policy='all', num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, policy='all', num_outputs=2)
xo, xi = cfg.define_split('tile_x', x, policy='all', num_outputs=2)
ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
xo, xi = cfg.define_split('tile_x', x, num_outputs=2)

cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
policy='candidate', candidate=[
Expand Down
8 changes: 4 additions & 4 deletions topi/python/topi/x86/bitserial_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
s[data_pad].compute_inline()

_, _, h, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
if cfg["tile_ah"].size[1] == 1:
oaxis = oh
Expand All @@ -116,7 +116,7 @@ def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,

##### Schedule Kenerl bitpacking
co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
if cfg["tile_bco"].size[1] == 1:
oaxis = oco
Expand Down Expand Up @@ -185,13 +185,13 @@ def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
s[data_pad].compute_inline()

_, h, _, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_ah", cfg.axis(h), num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
s[data_vec].parallel(oh)

##### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
cfg.define_split("tile_bco", cfg.axis(co), num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
s[kernel_vec].parallel(oco)

Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/x86/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
if _is_int8_hw_support(data.dtype, kernel.dtype, target):
oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
ic = ic_chunk*ic_bn
assert ic == k_ic*k_ic_f*kic_s
assert ic == k_ic*k_ic_f*k_ic_s
else:
oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
assert ic_chunk == k_ic_chunk
Expand Down

0 comments on commit da03979

Please sign in to comment.