From 497bc3d13e4764cc49f1ddd8ee7fd9e0a76055e6 Mon Sep 17 00:00:00 2001 From: eqy Date: Mon, 24 Dec 2018 11:28:36 -0800 Subject: [PATCH] [RELAY][AUTOTVM] Extract tuning tasks from Relay programs (#2181) --- python/tvm/autotvm/task/__init__.py | 1 + python/tvm/autotvm/task/nnvm_integration.py | 231 +++--------------- python/tvm/autotvm/task/relay_integration.py | 200 +++++++++++++++ python/tvm/autotvm/task/topi_integration.py | 192 ++++++++++++++- .../relay/test_autotvm_task_extraction.py | 56 +++++ topi/python/topi/x86/conv2d.py | 2 +- topi/python/topi/x86/depthwise_conv2d.py | 2 +- 7 files changed, 477 insertions(+), 207 deletions(-) create mode 100644 python/tvm/autotvm/task/relay_integration.py create mode 100644 tests/python/relay/test_autotvm_task_extraction.py diff --git a/python/tvm/autotvm/task/__init__.py b/python/tvm/autotvm/task/__init__.py index 04bcec92fd57..f6ea07c272d0 100644 --- a/python/tvm/autotvm/task/__init__.py +++ b/python/tvm/autotvm/task/__init__.py @@ -14,3 +14,4 @@ from .topi_integration import register_topi_compute, register_topi_schedule from .nnvm_integration import extract_from_graph, extract_from_multiple_graph +from .relay_integration import extract_from_program, extract_from_multiple_program diff --git a/python/tvm/autotvm/task/nnvm_integration.py b/python/tvm/autotvm/task/nnvm_integration.py index 6a07194a594d..cd7337586519 100644 --- a/python/tvm/autotvm/task/nnvm_integration.py +++ b/python/tvm/autotvm/task/nnvm_integration.py @@ -7,208 +7,13 @@ import logging -from ... import tensor, placeholder, create_schedule, target as _target +from ... import target as _target -from ..util import get_const_tuple -from .task import create, register +from .task import create +from .topi_integration import TaskExtractEnv logger = logging.getLogger('autotvm') -def serialize_args(args): - """serialize arguments of a topi function to a hashable tuple. - - Parameters - ---------- - args: list of hashable or Tensor - """ - ret = [] - for t in args: - if isinstance(t, tensor.Tensor): - ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype)) - else: - ret.append(t) - return tuple(ret) - - -def deserialize_args(args): - """The inverse function of :code:`serialize_args`. - - Parameters - ---------- - args: list of hashable or Tensor - """ - ret = [] - for t in args: - if isinstance(t, tuple) and t[0] == 'TENSOR': - ret.append(placeholder(shape=t[1], dtype=t[2])) - else: - ret.append(t) - return ret - - -# Task extractor for nnvm graph -class TaskExtractEnv: - """Global environment for extracting tuning tasks from nnvm graph""" - current = None - - def __init__(self): - import topi - import nnvm - - # NOTE: To add more symbols, you only need to change the following lists - # nnvm symbol -> topi compute - self.symbol2topi = { - nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw, - topi.nn.group_conv2d_nchw], - nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], - nnvm.sym.dense: [topi.nn.dense], - } - - # topi compute -> autotvm task name - self.topi_to_task = { - topi.nn.conv2d: "topi_nn_conv2d", - topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw", - topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw", - topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw", - topi.nn.dense: "topi_nn_dense", - } - - self.topi_to_schedule = { - topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw, - topi.generic.schedule_conv2d_nhwc], - topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw, - topi.generic.schedule_depthwise_conv2d_nhwc], - topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw], - topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw], - topi.nn.dense: [topi.generic.schedule_dense], - } - - self._register_tracing() - self._register_topi_task() - self.task_collection = [] - self.wanted_topi_funcs = list(self.topi_to_task.keys()) - - def _register_tracing(self): - """Register tracing function to track the topi function call""" - # register topi compute for "tracing" target - for topi_compute in self.topi_to_task: - def _local_scope(compute_func): - """start a scope to hold the local function in for loop""" - - @compute_func.register("tracing", ) - def _tracing_topi_compute(*args, **kwargs): - assert not kwargs, "Do not support extracting tuning tasks when" \ - "kwargs is used in TOPI function call." \ - "Please modify it to use only positional args." - - if compute_func in self.wanted_topi_funcs: # record this call - key = (self.topi_to_task[compute_func], serialize_args(args)) - if key not in self.task_collection: - self.task_collection.append(key) - - return compute_func.fdefault(*args) - _local_scope(topi_compute) - - # register topi schedule for "tracing" target - for topi_compute in self.topi_to_task: - for topi_schedule in self.topi_to_schedule[topi_compute]: - def _local_scope_(schedule_func): - """start a scope to hold the local function in for loop""" - - @schedule_func.register("tracing", ) - def _tracing_topi_compute(outs): - outs = [outs] if isinstance(outs, tensor.Tensor) else outs - return create_schedule([x.op for x in outs]) - _local_scope_(topi_schedule) - - def _register_topi_task(self): - """register tuning wrapper for topi function""" - import topi - - # Tuning wrapper for topi functions - @register("topi_nn_conv2d") - def _topi_nn_conv2d(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - layout = args[-2] - assert layout == 'NCHW', "only support NCHW currently" - C = topi.nn.conv2d(*args, **kwargs) - s = topi.generic.schedule_conv2d_nchw([C]) - return s, [A, W, C] - - @register("topi_nn_depthwise_conv2d_nchw") - def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs) - s = topi.generic.schedule_depthwise_conv2d_nchw([C]) - return s, [A, W, C] - - @register("topi_nn_group_conv2d_nchw") - def _topi_nn_group_conv2d_nchw(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - C = topi.nn.group_conv2d_nchw(*args, **kwargs) - s = topi.generic.schedule_group_conv2d_nchw([C]) - return s, [A, W, C] - - @register("topi_nn_conv2d_transpose_nchw") - def _topi_nn_conv2d_transpose_nchw(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - C = topi.nn.conv2d_transpose_nchw(*args, **kwargs) - s = topi.generic.schedule_conv2d_transpose_nchw([C]) - return s, [A, W, C] - - @register("topi_nn_dense") - def _topi_nn_dense(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - data, weight, bias = args - C = topi.nn.dense(*args, **kwargs) - s = topi.generic.schedule_dense([C]) - if bias is not None: - return s, [data, weight, bias, C] - return s, [data, weight, C] - - def reset(self, wanted_topi_funcs): - """Reset task collections - - Parameters - ---------- - wanted_topi_funcs: List of function - The topi function to be extracted - """ - self.task_collection = [] - self.wanted_topi_funcs = wanted_topi_funcs - - def get_tasks(self): - """Get collected tasks - - Returns - ------- - tasks: List of tuple(name, args) - A list of tasks extracted from the nnvm graph - """ - return self.task_collection - - @staticmethod - def get(): - """Get the single instance of TaskExtractEnv - - Returns - ------- - env: TaskExtractEnv - The single instance of TaskExtractEnv - """ - if not TaskExtractEnv.current: - TaskExtractEnv.current = TaskExtractEnv() - return TaskExtractEnv.current - def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): """ Extract tuning tasks from a nnvm graph. @@ -237,13 +42,24 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): collected tasks """ import nnvm.compiler + import nnvm + import topi env = TaskExtractEnv.get() + #NOTE: To add more symbols, you only need to change the following lists + #nnvm symbol -> topi compute + SYMBOL2TOPI = { + nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw, + topi.nn.group_conv2d_nchw], + nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], + nnvm.sym.dense: [topi.nn.dense], + } + topi_funcs = [] for sym_name in symbols: - if sym_name in env.symbol2topi: - topi_funcs.extend(env.symbol2topi[sym_name]) + if sym_name in SYMBOL2TOPI: + topi_funcs.extend(SYMBOL2TOPI[sym_name]) else: warnings.warn("Symbol %s is not tunable, ignored" % sym_name) @@ -297,13 +113,24 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_ collected tasks """ import nnvm.compiler + import nnvm + import topi env = TaskExtractEnv.get() + #NOTE: To add more symbols, you only need to change the following lists + #nnvm symbol -> topi compute + SYMBOL2TOPI = { + nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw, + topi.nn.group_conv2d_nchw], + nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], + nnvm.sym.dense: [topi.nn.dense], + } + topi_funcs = [] for sym_name in symbols: - if sym_name in env.symbol2topi: - topi_funcs.extend(env.symbol2topi[sym_name]) + if sym_name in SYMBOL2TOPI: + topi_funcs.extend(SYMBOL2TOPI[sym_name]) else: warnings.warn("Symbol %s is not tunable, ignored" % sym_name) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py new file mode 100644 index 000000000000..21acf257f9ac --- /dev/null +++ b/python/tvm/autotvm/task/relay_integration.py @@ -0,0 +1,200 @@ +# pylint: disable=unused-variable,invalid-name +""" +Decorator and utilities for the integration with TOPI and Relay +99.9% copy-paste of implementation by @MerryMercy + +""" +import threading +import warnings +import logging + + +from ... import tensor, placeholder, target as _target + +from .task import create +from .topi_integration import TaskExtractEnv + +logger = logging.getLogger('autotvm') + + +def serialize_args(args): + """serialize arguments of a topi function to a hashable tuple. + + Parameters + ---------- + args: list of hashable or Tensor + """ + ret = [] + for t in args: + if isinstance(t, tensor.Tensor): + ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype)) + else: + ret.append(t) + return tuple(ret) + + +def deserialize_args(args): + """The inverse function of :code:`serialize_args`. + + Parameters + ---------- + args: list of hashable or Tensor + """ + ret = [] + for t in args: + if isinstance(t, tuple) and t[0] == 'TENSOR': + ret.append(placeholder(shape=t[1], dtype=t[2])) + else: + ret.append(t) + return ret + + +def extract_from_program(func, params, ops, target, target_host=None): + """ Extract tuning tasks from a relay program. + + This function collects tuning tasks by building the program + with a "tracing" target and tracing all the calls to topi. + + Parameters + ---------- + func: relay.expr.Function + The func to tune + params: dict of str to numpy array + The associated parameters of the program + ops: List of relay op + List of relay ops to be tuned + dtype: str or dict of str to str + The input types to the program + target: tvm.target.Target + The compilation target + target_host: tvm.target.Target + The host compilation target + + Returns + ------- + task: Array of autotvm.task.Task + collected tasks + """ + env = TaskExtractEnv.get() + import tvm.relay.op + from tvm import relay + import topi + + # NOTE: To add more ops, you only need to change the following lists + # relay op -> topi compute + OP2TOPI = { + tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw, + topi.nn.group_conv2d_nchw], + tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], + tvm.relay.op.nn.dense: [topi.nn.dense], + } + + topi_funcs = [] + for op_name in ops: + if op_name in OP2TOPI: + topi_funcs.extend(OP2TOPI[op_name]) + else: + warnings.warn("Op %s is not tunable, ignored" % op_name) + + # run compiler to collect all TOPI calls during compilation + env.reset(topi_funcs) + + # disable logger temporarily + old_state = logger.disabled + logger.disabled = True + + # use a "tracing" target to do a fake compile for collecting topi calls + tracing_target = _target.create("llvm -device=tracing") + relay.backend.compile_engine.get().clear() + # wrap build call in thread to avoid multiprocessing problems + build_thread = threading.Thread(target=relay.build, args=(func, + tracing_target, + target_host, + params)) + build_thread.start() + build_thread.join() + logger.disabled = old_state + + # create tasks for target + tasks = [] + for task_name, args in env.get_tasks(): + tasks.append(create(task_name, args, + target=target, target_host=target_host, + template_key='direct')) + + return tasks + + +def extract_from_multiple_program(funcs, params, ops, target, target_host=None): + """ Extract tuning tasks from multiple relay programs. + + This function is the multiple program version of extract_from_program + + Parameters + ---------- + funcs: List of relay.expr.Function + The list of functions to tune + params: List of dict of str to numpy array + The associated parameters of the programs + ops: List of relay op + List of relay ops to be tuned + target: tvm.target.Target + The compilation target + target_host: tvm.target.Target + The host compilation target + + Returns + ------- + task: Array of autotvm.task.Task + collected tasks + """ + env = TaskExtractEnv.get() + import tvm.relay.op + from tvm import relay + import topi + + # NOTE: To add more ops, you only need to change the following lists + # relay op -> topi compute + OP2TOPI = { + tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw, + topi.nn.group_conv2d_nchw], + tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], + tvm.relay.op.nn.dense: [topi.nn.dense], + } + + topi_funcs = [] + for op_name in ops: + if op_name in OP2TOPI: + topi_funcs.extend(OP2TOPI[op_name]) + else: + warnings.warn("Op %s is not tunable, ignored" % op_name) + + # run compiler to collect all TOPI calls during compilation + env.reset(topi_funcs) + + # disable logger temporarily + old_state = logger.disabled + logger.disabled = True + + # use a "tracing" target to do a fake compile for collecting topi calls + tracing_target = _target.create("llvm -device=tracing") + + for func, param in zip(funcs, params): + # wrap build call in thread to avoid multiprocessing problems + build_thread = threading.Thread(target=relay.build, args=(func, + tracing_target, + target_host, + params)) + build_thread.start() + build_thread.join() + + logger.disabled = old_state + + # create tasks for target + tasks = [] + for task_name, args in env.get_tasks(): + tasks.append(create(task_name, args, + target=target, target_host=target_host, + template_key='direct')) + + return tasks diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index f005ee0c9a54..412d7ae0e40b 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -11,16 +11,202 @@ See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. """ -from ... import _api_internal, tensor - -from .task import args_to_workload, dispatcher +from ... import _api_internal, tensor, placeholder, create_schedule +from .task import args_to_workload, dispatcher, register +from ..util import get_const_tuple # A table that records all registered dispatcher for all targets _REGISTED_DISPATHCER = { } +def serialize_args(args): + """serialize arguments of a topi function to a hashable tuple. + + Parameters + ---------- + args: list of hashable or Tensor + """ + ret = [] + for t in args: + if isinstance(t, tensor.Tensor): + ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype)) + else: + ret.append(t) + return tuple(ret) + + +def deserialize_args(args): + """The inverse function of :code:`serialize_args`. + + Parameters + ---------- + args: list of hashable or Tensor + """ + ret = [] + for t in args: + if isinstance(t, tuple) and t[0] == 'TENSOR': + ret.append(placeholder(shape=t[1], dtype=t[2])) + else: + ret.append(t) + return ret + + +# Task extractor for nnvm graph, relay program +class TaskExtractEnv: + """Global environment for extracting tuning tasks from nnvm graph""" + current = None + + def __init__(self): + import topi + + # topi compute -> autotvm task name + self.topi_to_task = { + topi.nn.conv2d: "topi_nn_conv2d", + topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw", + topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw", + topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw", + topi.nn.dense: "topi_nn_dense", + } + + self.topi_to_schedule = { + topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw, + topi.generic.schedule_conv2d_nhwc], + topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw, + topi.generic.schedule_depthwise_conv2d_nhwc], + topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw], + topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw], + topi.nn.dense: [topi.generic.schedule_dense], + } + + self._register_tracing() + self._register_topi_task() + self.task_collection = [] + self.wanted_topi_funcs = list(self.topi_to_task.keys()) + + def _register_tracing(self): + """Register tracing function to track the topi function call""" + # register topi compute for "tracing" target + for topi_compute in self.topi_to_task: + def _local_scope(compute_func): + """start a scope to hold the local function in for loop""" + + @compute_func.register("tracing", ) + def _tracing_topi_compute(*args, **kwargs): + assert not kwargs, "Do not support extracting tuning tasks when" \ + "kwargs is used in TOPI function call." \ + "Please modify it to use only positional args." + + if compute_func in self.wanted_topi_funcs: # record this call + key = (self.topi_to_task[compute_func], serialize_args(args)) + if key not in self.task_collection: + self.task_collection.append(key) + + return compute_func.fdefault(*args) + _local_scope(topi_compute) + + # register topi schedule for "tracing" target + for topi_compute in self.topi_to_task: + for topi_schedule in self.topi_to_schedule[topi_compute]: + def _local_scope_(schedule_func): + """start a scope to hold the local function in for loop""" + + @schedule_func.register("tracing", ) + def _tracing_topi_compute(outs): + outs = [outs] if isinstance(outs, tensor.Tensor) else outs + return create_schedule([x.op for x in outs]) + _local_scope_(topi_schedule) + + def _register_topi_task(self): + """register tuning wrapper for topi function""" + import topi + + # Tuning wrapper for topi functions + @register("topi_nn_conv2d") + def _topi_nn_conv2d(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + layout = args[-2] + assert layout == 'NCHW', "only support NCHW currently" + C = topi.nn.conv2d(*args, **kwargs) + s = topi.generic.schedule_conv2d_nchw([C]) + return s, [A, W, C] + + @register("topi_nn_depthwise_conv2d_nchw") + def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs) + s = topi.generic.schedule_depthwise_conv2d_nchw([C]) + return s, [A, W, C] + + @register("topi_nn_group_conv2d_nchw") + def _topi_nn_group_conv2d_nchw(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + C = topi.nn.group_conv2d_nchw(*args, **kwargs) + s = topi.generic.schedule_group_conv2d_nchw([C]) + return s, [A, W, C] + + @register("topi_nn_conv2d_transpose_nchw") + def _topi_nn_conv2d_transpose_nchw(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + C = topi.nn.conv2d_transpose_nchw(*args, **kwargs) + s = topi.generic.schedule_conv2d_transpose_nchw([C]) + return s, [A, W, C] + + @register("topi_nn_dense") + def _topi_nn_dense(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + data, weight, bias = args + C = topi.nn.dense(*args, **kwargs) + s = topi.generic.schedule_dense([C]) + if bias is not None: + return s, [data, weight, bias, C] + return s, [data, weight, C] + + def reset(self, wanted_topi_funcs): + """Reset task collections + + Parameters + ---------- + wanted_topi_funcs: List of function + The topi function to be extracted + """ + self.task_collection = [] + self.wanted_topi_funcs = wanted_topi_funcs + + def get_tasks(self): + """Get collected tasks + + Returns + ------- + tasks: List of tuple(name, args) + A list of tasks extracted from the nnvm graph + """ + return self.task_collection + + @staticmethod + def get(): + """Get the single instance of TaskExtractEnv + + Returns + ------- + env: TaskExtractEnv + The single instance of TaskExtractEnv + """ + if not TaskExtractEnv.current: + TaskExtractEnv.current = TaskExtractEnv() + return TaskExtractEnv.current + + def register_topi_compute(topi_compute, target_keys, template_keys, func=None): """Register a tunable template for a topi compute function. diff --git a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py new file mode 100644 index 000000000000..8c93e4a56642 --- /dev/null +++ b/tests/python/relay/test_autotvm_task_extraction.py @@ -0,0 +1,56 @@ +"""Test task extraction for autotvm""" +import tvm.relay.testing +from tvm import relay +from tvm import autotvm + +def get_network(name, batch_size): + """Get the symbol definition and random weight of a network""" + input_shape = (batch_size, 3, 224, 224) + + if name == 'resnet-18': + net, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=batch_size) + elif name == 'mobilenet': + net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size) + elif name == 'dcgan': + net, params = relay.testing.dcgan.get_workload(batch_size=batch_size) + input_shape = (batch_size, 100) + else: + raise ValueError("Unsupported network: " + name) + + return net, params, input_shape + +def test_task_extraction(): + target = 'llvm' + + net, params, input_shape = get_network('resnet-18', batch_size=1) + tasks = autotvm.task.extract_from_program(net, target=target, + params=params, + ops=(relay.op.nn.conv2d,)) + assert len(tasks) == 12 + + net, params, input_shape = get_network('resnet-18', batch_size=1) + tasks = autotvm.task.extract_from_program(net, target=target, + params=params, + ops=(relay.op.nn.dense,)) + assert len(tasks) == 1 + + net, params, input_shape = get_network('resnet-18', batch_size=1) + tasks = autotvm.task.extract_from_program(net, target=target, + params=params, + ops=(relay.op.nn.conv2d, relay.op.nn.dense)) + assert len(tasks) == 13 + + net, params, input_shape = get_network('mobilenet', batch_size=1) + tasks = autotvm.task.extract_from_program(net, target=target, + params=params, + ops=(relay.op.nn.conv2d, relay.op.nn.dense)) + assert len(tasks) == 20 + + net, params, input_shape = get_network('dcgan', batch_size=1) + tasks = autotvm.task.extract_from_program(net, target=target, + params=params, + ops=(relay.op.nn.conv2d_transpose,)) + assert len(tasks) == 4 + +if __name__ == '__main__': + test_task_extraction() diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 1a73736264bd..fe38b38d38e0 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -2,7 +2,7 @@ """Conv2D schedule on x86""" import tvm from tvm import autotvm -from tvm.autotvm.task.nnvm_integration import deserialize_args +from tvm.autotvm.task.topi_integration import deserialize_args from tvm.autotvm.task import get_config from .. import generic, tag from .. import nn diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index 8f37a0316229..64858df91cdc 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -4,7 +4,7 @@ from tvm import autotvm from tvm.autotvm.task import get_config from tvm.autotvm.task.space import SplitEntity -from tvm.autotvm.task.nnvm_integration import deserialize_args +from tvm.autotvm.task.topi_integration import deserialize_args from .. import generic, tag from ..nn.pad import pad from ..util import get_const_tuple