Commit

[AutoTVM] Suppress the warning messages when compile engine selects impls (apache#5821)

icemelon authored and Trevor Morris committed Jun 30, 2020
1 parent ad262b5 commit 1278e34
Showing 5 changed files with 28 additions and 12 deletions.
1 change: 1 addition & 0 deletions python/tvm/autotvm/env.py
@@ -25,5 +25,6 @@ def __init__(self):

self.cuda_target_arch = None
self.in_tuning = False
self.silent = False

GLOBAL_SCOPE = AutotvmGlobalScope()
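
The new silent flag on the global AutoTVM scope gives callers one switch for the fallback warnings. Below is a minimal usage sketch, an assumption based on how the test files later in this commit use the flag rather than code from the diff; run_untuned_workloads is a hypothetical stand-in for whatever triggers the config queries.

from tvm import autotvm

def run_untuned_workloads():
    # Hypothetical placeholder for compilation or test code whose workloads
    # have no tuning records, so every config query would otherwise warn.
    pass

autotvm.GLOBAL_SCOPE.silent = True    # suppress "Cannot find config ..." warnings
run_untuned_workloads()
autotvm.GLOBAL_SCOPE.silent = False   # restore normal warning behaviour
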
13 changes: 6 additions & 7 deletions python/tvm/autotvm/task/dispatcher.py
@@ -35,6 +35,7 @@
import numpy as np

from .space import FallbackConfigEntity
from .. import env as _env

logger = logging.getLogger('autotvm')

@@ -47,6 +48,8 @@ class DispatchContext(object):
specific dispatch mechanism for templates.
"""
current = None
# a set to prevent printing duplicated messages
warning_messages = set()

def __init__(self):
self._old_ctx = DispatchContext.current
@@ -295,21 +298,17 @@ class FallbackContext(DispatchContext):
def __init__(self):
super(FallbackContext, self).__init__()
self.memory = {}
self.silent = False

# a set to prevent print duplicated message
self.messages = set()

def _query_inside(self, target, workload):
key = (str(target), workload)
if key in self.memory:
return self.memory[key]

if not self.silent:
if not _env.GLOBAL_SCOPE.silent:
msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
"is used, which may bring great performance regression." % (target, workload)
if msg not in self.messages:
self.messages.add(msg)
if msg not in DispatchContext.warning_messages:
DispatchContext.warning_messages.add(msg)
logger.warning(msg)
cfg = FallbackConfigEntity()
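
Moving the deduplication set from each FallbackContext instance to the DispatchContext class means the "Cannot find config" message for a given target and workload is logged at most once per process, even across separate dispatch contexts. A rough sketch of the expected behaviour follows; it is an assumption, not part of the diff, and the workload tuple is made up purely to trigger the fallback path.

import logging
from tvm import autotvm

logging.basicConfig(level=logging.WARNING)

# Made-up workload key, only used here to hit the fallback path twice.
workload = ("hypothetical_conv2d", (1, 3, 224, 224), "float32")

first = autotvm.task.FallbackContext().query("llvm", workload)
# A second, independent context misses the same config, but the message is already
# in DispatchContext.warning_messages, so it should not be logged a second time.
second = autotvm.task.FallbackContext().query("llvm", workload)
assert first.is_fallback and second.is_fallback
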

18 changes: 16 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
@@ -30,7 +30,7 @@
from . import _backend

logger = logging.getLogger('compile_engine')

autotvm_logger = logging.getLogger('autotvm')

@tvm._ffi.register_object("relay.LoweredOutput")
class LoweredOutput(Object):
@@ -190,24 +190,38 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
return best_plevel_impl, outs

outputs = {}
workloads = {}
best_autotvm_impl = None
best_cfg = None
dispatch_ctx = autotvm.task.DispatchContext.current
autotvm.GLOBAL_SCOPE.silent = True
for impl in all_impls:
outs = impl.compute(attrs, inputs, out_type)
outputs[impl] = outs
workload = autotvm.task.get_workload(outs)
workloads[impl] = workload
if workload is None:
# Not an AutoTVM tunable implementation
continue
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback:
# It's a fallback config
# Skip fallback config
continue
if best_cfg is None or best_cfg.cost > cfg.cost:
best_autotvm_impl = impl
best_cfg = cfg
autotvm.GLOBAL_SCOPE.silent = False
if best_autotvm_impl:
# The best autotvm implementation definitely doesn't use fallback config
return best_autotvm_impl, outputs[best_autotvm_impl]
# Use the implementation with highest plevel
if workloads[best_plevel_impl] is not None:
msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
"is used, which may bring great performance regression." \
% (target, workloads[best_plevel_impl])
if msg not in autotvm.task.DispatchContext.warning_messages:
autotvm.task.DispatchContext.warning_messages.add(msg)
autotvm_logger.warning(msg)
return best_plevel_impl, outputs[best_plevel_impl]
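
The net effect in the compile engine: candidate implementations are probed with AutoTVM silenced, and a single deduplicated warning is emitted only when the selected implementation actually relies on a fallback config. A hedged end-to-end sketch of the expected user-visible behaviour (assumed, not taken from this commit): building a small untuned Relay model should now log at most one fallback warning per workload instead of one per probed implementation.

import tvm
from tvm import relay

# Hypothetical tiny model with a single conv2d and no tuning records available.
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
net = relay.nn.conv2d(data, weight, padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], net))

with tvm.transform.PassContext(opt_level=3):
    # Expect a single "Cannot find config ..." warning for the conv2d workload,
    # not one warning per candidate implementation that was considered.
    lib = relay.build(mod, target="llvm")
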


5 changes: 3 additions & 2 deletions tests/python/integration/test_winograd_nnpack.py
@@ -106,7 +106,7 @@ def test_conv2d_nchw():
skip("nnpack is not available")

devices = ['llvm -device=arm_cpu']
autotvm.DispatchContext.current.silent = True
autotvm.GLOBAL_SCOPE.silent = True
with WinogradFallback():
# resnet 18 workloads
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, devices=devices)
@@ -137,8 +137,9 @@ def test_conv2d_nchw():
# weird workloads
verify_conv2d_nchw(1, 3, 3, 3, 3, 1, 1, devices=devices)
verify_conv2d_nchw(1, 13, 71, 59, 3, 1, 1, devices=devices)
autotvm.GLOBAL_SCOPE.silent = False


if __name__ == "__main__":
import pytest
pytest.main()
pytest.main([__file__])
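
Both test files in this commit turn the global flag on before the noisy verification calls and back off at the end. A small variant, sketched here only as an illustration and not part of the diff, restores the flag even if a verification raises, so a failing test would not leave warnings suppressed for the rest of the session.

from contextlib import contextmanager

from tvm import autotvm

@contextmanager
def silent_autotvm():
    # Temporarily silence AutoTVM fallback warnings and restore the previous
    # value even when the wrapped code raises.
    previous = autotvm.GLOBAL_SCOPE.silent
    autotvm.GLOBAL_SCOPE.silent = True
    try:
        yield
    finally:
        autotvm.GLOBAL_SCOPE.silent = previous

A test body could then be wrapped in "with silent_autotvm():" instead of toggling the flag manually.
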
3 changes: 2 additions & 1 deletion topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py
@@ -61,7 +61,7 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte
break

ic_block = 8
autotvm.DispatchContext.current.silent = True
autotvm.GLOBAL_SCOPE.silent = True
A = te.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A', dtype='uint8')
W = te.placeholder((num_filter//oc_block, in_channel//ic_block//groups, kernel, kernel, ic_block//4, oc_block, 4), name='W', dtype='int8')

@@ -103,6 +103,7 @@ def check_device(device):
for device in ["llvm -mcpu=skylake-avx512"]:
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
autotvm.GLOBAL_SCOPE.silent = False

@pytest.mark.skip
def test_conv2d_NCHWc():