Refine porting x86 NCHWc conv to AutoTVM (apache#1993)
yzhliu authored and Wei Chen committed Feb 19, 2019
1 parent aff9c79 commit 74165ce
Showing 15 changed files with 497 additions and 813 deletions.
39 changes: 15 additions & 24 deletions nnvm/python/nnvm/top/nn.py
@@ -170,41 +170,32 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
kh, kw = attrs.get_int_tuple('kernel_size')
groups = attrs.get_int("groups")
channels = attrs.get_int("channels")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
out_dtype = attrs.get_string("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "not support dilate now"
with tvm.target.create(attrs.get_string("target")):
if groups == 1:
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels,
(kh, kw), strides, padding, layout,
out_layout)
# pylint: enable=assignment-from-no-return
else:
raise ValueError("not support arbitrary group number > 1 for now")
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.add(out, bias)
return out
if groups == 1:
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
layout, out_layout, out_dtype)
# pylint: enable=assignment-from-no-return
else:
raise ValueError("not support arbitrary group number > 1 for now")
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.add(out, bias)
return out

@reg.register_schedule("_contrib_conv2d_NCHWc")
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of conv2d NCHWc"""
groups = attrs.get_int("groups")
kh, kw = attrs.get_int_tuple('kernel_size')
oc = attrs.get_int("channels")
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
with tvm.target.create(target):
if groups == 1:
return topi.generic.schedule_conv2d_NCHWc(
oc, (kh, kw), strides, padding, layout, out_layout, outs)
return topi.generic.schedule_conv2d_NCHWc(outs)
else:
raise ValueError("not support group number > 1 for now")

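For reference, a minimal sketch of how the simplified TOPI entry points used above are called; the shapes, blocked layouts, and the llvm target are assumptions for illustration, not part of this patch:

```python
import tvm
import topi

# data in NCHW8c and kernel in OIHW8i8o blocked layouts (shapes are assumed)
data = tvm.placeholder((1, 1, 56, 56, 8), name="data")
kernel = tvm.placeholder((2, 1, 3, 3, 8, 8), name="kernel")

with tvm.target.create("llvm"):
    # channels and kernel_size are now inferred from `kernel`, so the compute
    # only takes strides, padding, layouts and out_dtype explicitly
    out = topi.nn.conv2d_NCHWc(data, kernel, (1, 1), (1, 1),
                               "NCHW8c", "NCHW8c", "float32")
    # the generic schedule helper likewise takes only the output tensors
    s = topi.generic.schedule_conv2d_NCHWc([out])
```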
113 changes: 81 additions & 32 deletions python/tvm/autotvm/task/dispatcher.py
@@ -60,6 +60,53 @@ def query(self, target, workload):
ret = self._old_ctx.query(target, workload)
return ret

def update(self, target, workload, cfg):
"""
Update context with a specific config.
Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
cfg : ConfigSpace
The specific configuration.
Note
----
This interface is for cases when TVM decides to replace an operator in the graph.
For example, the `AlterOpLayout` pass (enabled when `opt_level = 3`) replaces `NCHW`
convolution with an `NCHW[x]c` implementation on x86 CPUs.
Thus in TOPI we first query the schedule using the original `NCHW` workload,
then update the dispatcher with the new `NCHW[x]c` workload, so that later on
the `NCHW[x]c` convolution can get its schedule from the dispatcher using
its own workload directly.
.. code-block:: python
@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo):
workload = get_conv2d_workload(...)
dispatch_ctx = autotvm.task.DispatchContext.current
target = tvm.target.current_target()
config = dispatch_ctx.query(target, workload)
# Get conv2d_NCHWc workload from config
# new_workload = ...
# new_inputs = ...
# new_attrs = ...
# Store altered operator's config
dispatch_ctx.update(target, new_workload, config)
return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)
We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc`
share the same schedule parameters.
One can construct a new `ConfigEntity` if this is not the case.
"""
raise NotImplementedError()

def _query_inside(self, target, workload):
"""
Query the context to get the specific config for a template.
@@ -179,6 +226,11 @@ def _query_inside(self, target, workload):
self.workload = workload
return self._config

def update(self, target, workload, cfg):
"""Override update"""
self.workload = workload
self._config = cfg


class ApplyHistoryBest(DispatchContext):
"""
@@ -197,6 +249,7 @@ def __init__(self, records):

self.best_by_targetkey = {}
self.best_by_model = {}
self._best_user_defined = {}

if records:
self.load(records)
@@ -264,17 +317,32 @@ def _query_inside(self, target, workload):
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_model:
return self.best_by_model[key][0].config

# then try matching by target key
for k in target.keys:
key = (k, workload)
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_targetkey:
return self.best_by_targetkey[key][0].config

return None

def update(self, target, workload, cfg):
for opt in target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
self._best_user_defined[key] = cfg

for k in target.keys:
key = (k, workload)
self._best_user_defined[key] = cfg


class FallbackContext(DispatchContext):
"""
@@ -324,6 +392,10 @@ def clear_cache(self, target, workload):
if key in self.memory:
del self.memory[key]

def update(self, target, workload, cfg):
key = (str(target), workload)
self.memory[key] = cfg

DispatchContext.current = FallbackContext()

def clear_fallback_cache(target, workload):
@@ -391,37 +463,14 @@ def _query_inside(self, target, workload):
cfg : ConfigSpace
The specific configuration.
"""
cfg = self._records[self._counter][0].config
self._counter += 1
return cfg

def query_global_dict(self, key):
"""
Query the context to get config from global
config dictionary.
Parameters
----------
key : str
Key to query the config.
Returns
-------
cfg : ConfigSpace
The specific configuration.
"""
if self._counter < len(self._records):
cfg = self._records[self._counter][0].config
self._counter += 1
self.update(target, workload, cfg)
return cfg
key = (str(target), workload)
return self._global_cfg_dict[key]

def update_global_dict(self, key, val):
"""
Update the global config dictionary.
Parameters
----------
key : str
Key of config.
val : ConfigSpace
Value of config.
"""
self._global_cfg_dict[key] = val
def update(self, target, workload, cfg):
key = (str(target), workload)
self._global_cfg_dict[key] = cfg
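A hedged sketch of the new `update()`/`query()` flow across dispatch contexts; the target and the workload tuples below are made-up placeholders, not real conv2d workloads:

```python
import tvm
from tvm import autotvm

dispatch_ctx = autotvm.task.DispatchContext.current   # a FallbackContext by default
target = tvm.target.create("llvm")

# placeholder workload tuples for the original and the transformed operator
nchw_workload = ("conv2d", (1, 3, 224, 224, "float32"), (16, 3, 3, 3, "float32"))
nchwc_workload = ("conv2d_NCHWc", (1, 1, 224, 224, 3, "float32"),
                  (2, 1, 3, 3, 3, 8, "float32"))

cfg = dispatch_ctx.query(target, nchw_workload)       # config (or fallback) for NCHW
dispatch_ctx.update(target, nchwc_workload, cfg)      # store it for the NCHW[x]c form
assert dispatch_ctx.query(target, nchwc_workload) is cfg
```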
14 changes: 13 additions & 1 deletion python/tvm/autotvm/task/space.py
@@ -1,5 +1,5 @@
# pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ
# pylint: disable=consider-using-enumerate
# pylint: disable=consider-using-enumerate,too-many-lines
"""
Template configuration space.
@@ -996,5 +996,17 @@ def fallback_with_reference_log(self, ref_log):
if not isinstance(self.space_map[knob_name], SplitSpace):
self._entity_map[knob_name] = best_match_cfg[knob_name]

def __setitem__(self, name, entity):
"""set the entity(knob) of by name
Parameters
----------
name: str
name of the entity
entity: SplitEntity, ReorderEntity, AnnotateEntity, OtherOptionEntity
value of the entity
"""
self._entity_map[name] = entity

def __repr__(self):
return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
15 changes: 9 additions & 6 deletions python/tvm/autotvm/task/task.py
@@ -182,7 +182,7 @@ def create(func_name, args, target, target_host=None, template_key=None):

return ret

def args_to_workload(x):
def args_to_workload(x, topi_compute_func=None):
"""Convert argument list to hashable workload tuple.
This function converts lists to tuples, TVM nodes to Python values, and
flattens tvm.tensor.Tensor to a tuple of its shape and dtype.
@@ -191,25 +191,28 @@ def args_to_workload(x):
----------
x: primitive hashable types or tensor.Tensor
The original value
topi_compute_func: topi compute function
The function name will be added as the first element of the workload tuple
Returns
-------
ret: hashable
The hashable value
"""
if isinstance(x, tensor.Tensor):
return get_const_tuple(x.shape) + (x.dtype, )
workload = get_const_tuple(x.shape) + (x.dtype, )
elif isinstance(x, (tuple, list, container.Array)):
return tuple([args_to_workload(a) for a in x])
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
return x
workload = x
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
workload = x.value
elif x is None:
return 0
workload = 0
else:
raise RuntimeError('Do not support type "%s" in argument. Consider using '
'primitive types only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload

def template(func):
"""
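To illustrate the new `topi_compute_func` argument, a hedged sketch of the workload tuple it now produces; the placeholder shapes and the choice of `topi.nn.conv2d` are assumptions:

```python
import tvm
import topi
from tvm.autotvm.task.task import args_to_workload

data = tvm.placeholder((1, 3, 224, 224), name="data", dtype="float32")
kernel = tvm.placeholder((16, 3, 3, 3), name="kernel", dtype="float32")

wkl = args_to_workload([data, kernel, (1, 1), (1, 1), "NCHW", "float32"],
                       topi.nn.conv2d)
# the topi function name is prepended, yielding something like:
# ('conv2d', (1, 3, 224, 224, 'float32'), (16, 3, 3, 3, 'float32'),
#  (1, 1), (1, 1), 'NCHW', 'float32')
```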
17 changes: 7 additions & 10 deletions python/tvm/autotvm/task/topi_integration.py
@@ -1,4 +1,4 @@
# pylint: disable=unused-variable,invalid-name
# pylint: disable=unused-variable,invalid-name,unused-argument
"""
Decorators for registering tunable templates to TOPI.
@@ -13,7 +13,6 @@

from ... import _api_internal, tensor

from ..util import get_func_name
from .task import args_to_workload, dispatcher


@@ -55,8 +54,6 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
--------
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
fname = get_func_name(topi_compute)

def _decorator(f):
targets = [target_keys] if isinstance(target_keys, str) else target_keys
for target_key in targets:
@@ -68,7 +65,7 @@ def _decorator(f):
def config_dispatcher(*args, **kwargs):
"""override topi call as a config dispatcher"""
assert not kwargs, "Do not support kwargs in template function call"
return (fname, ) + args_to_workload(args)
return args_to_workload(args, topi_compute)
_REGISTED_DISPATHCER[target_key][topi_compute] = config_dispatcher

config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_compute]
Expand All @@ -88,7 +85,7 @@ def template_call(cfg, *args, **kwargs):
attrs = {}
for k, v in node.op.attrs.items():
attrs[k] = v
attrs['workload'] = (fname, ) + args_to_workload(args)
attrs['workload'] = args_to_workload(args, topi_compute)
if isinstance(op, tensor.ComputeOp):
op = _api_internal._ComputeOp(
op.name, op.tag, attrs, op.axis, op.body)
@@ -153,7 +150,7 @@ def _decorator(f):
if topi_schedule not in _REGISTED_DISPATHCER[target_key]:
@topi_schedule.register(target_key)
@dispatcher
def config_dispatcher(outs):
def config_dispatcher(outs, *args, **kwargs):
"""override topi call as a workload dispatcher"""
def traverse(tensors):
"""traverse all ops to find attached workload"""
@@ -179,11 +176,11 @@ def traverse(tensors):
config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_schedule]

@config_dispatcher.register(template_keys)
def template_call(cfg, outs):
def template_call(cfg, outs, *args, **kwargs):
"""call the schedule func"""
if f == topi_schedule.fdefault:
return f(outs)
return f(cfg, outs)
return f(outs, *args, **kwargs)
return f(cfg, outs, *args, **kwargs)

return f

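For context, a hedged sketch of how a tunable compute is registered through this decorator; the target key, template key, and the pass-through body are assumptions for illustration (real templates define knobs in `cfg` and build the compute themselves):

```python
from tvm import autotvm
import topi

@autotvm.register_topi_compute(topi.nn.conv2d, 'cpu', 'demo_direct')
def conv2d_demo(cfg, *args, **kwargs):
    # `cfg` is injected by the dispatcher; the workload attached to the result
    # now starts with the topi function name via args_to_workload(args, topi_compute)
    return topi.nn.conv2d.fdefault(*args, **kwargs)
```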
22 changes: 2 additions & 20 deletions topi/python/topi/generic/nn.py
@@ -55,33 +55,15 @@ def schedule_conv2d_nhwc(outs):


@tvm.target.generic_func
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides,
padding, layout, out_layout, outs):
def schedule_conv2d_NCHWc(outs):
"""Schedule for conv2d_NCHW[x]c
Parameters
----------
num_filter : int
The number of filter, i.e., the output channel.
kernel_size : tuple of int
(kernel_height, kernel_width)
strides : tuple of int
(stride_of_height, stride_of_width)
padding : tuple of int
(pad_of_height, pad_of_width)
layout : str
Input data layout
out_layout : str
Output data layout
outs : Array of Tensor
The computation graph description of conv2d_NCHWc
in the format of an array of tensors.
The number of filter, i.e., the output channel.
Returns
-------