[AUTOTVM] Improve tutorial and logging (apache#1544)
merrymercy authored and tqchen committed Aug 3, 2018
1 parent 3360674 commit 136061d
Showing 17 changed files with 200 additions and 116 deletions.
2 changes: 1 addition & 1 deletion python/tvm/autotvm/measure/__init__.py
@@ -1,7 +1,7 @@
"""Distributed executor infrastructure to scale up the tuning"""

from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
from .measure_methods import request_remote, create_measure_batch, use_rpc
from .measure_methods import request_remote, check_remote, create_measure_batch, use_rpc

from .local_executor import LocalExecutor
from .executor import Future, Executor
48 changes: 44 additions & 4 deletions python/tvm/autotvm/measure/measure_methods.py
@@ -9,6 +9,7 @@
import os
import time
from random import getrandbits
import threading

import numpy as np

@@ -23,6 +24,7 @@
from .measure import MeasureResult, MeasureErrorNo
from .local_executor import LocalExecutor

logger = logging.getLogger('autotvm')

class HashMismatchError(ValueError):
"""Raised when the code hash of a submitted config doesn't match that on the
@@ -42,9 +44,9 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
If is none, will use environment variable "TVM_TRACKER_HOST"
and "TVM_TRACKER_PORT"
priority: int, optional
priority of this request, larger is more prior
The priority of this request, larger is more prior
timeout: float, optional
timeout of this session (units: seconds)
The timeout of this session (units: seconds)
Returns
------
@@ -63,6 +65,33 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
session_timeout=timeout)
return remote
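
As a usage note, request_remote can be called directly to grab a session from the tracker; a minimal sketch (the device key and tracker address below are placeholders, not values from this commit):

    from tvm.autotvm.measure import request_remote

    # Ask the tracker for a session on a device registered under 'rk3399';
    # falls back to TVM_TRACKER_HOST / TVM_TRACKER_PORT when tracker_addr is None.
    remote = request_remote('rk3399', tracker_addr=('127.0.0.1', 9190),
                            priority=1, timeout=60)
    ctx = remote.cpu(0)  # open a device context on the remote board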

def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
    """
    Check the availability of a remote device.

    Parameters
    ----------
    target: Target
        The wanted compilation target
    device_key: string
        device key of registered device in tracker
    tracker_addr: Tuple(string, int), optional
        The address of rpc tracker in (host, port) format.
        If is none, will use environment variable "TVM_TRACKER_HOST"
        and "TVM_TRACKER_PORT"
    priority: int, optional
        The priority of this request, larger is more prior
    timeout: float, optional
        The timeout of this check (units: seconds).
        If the check does not finish within the timeout, the device is reported as unavailable.
    """
    def _check():
        remote = request_remote(device_key, tracker_addr, priority)
        remote.context(str(target))
    t = threading.Thread(target=_check,)
    t.start()
    t.join(timeout)
    return not t.is_alive()
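
A minimal usage sketch for the new helper, assuming a tracker is already running; the target string, device key, and tracker address are placeholders:

    from tvm.autotvm.measure import check_remote

    # True only if a session could be opened and a device context created
    # before the timeout expires; otherwise the checker thread is still alive.
    if not check_remote('cuda', '1080ti', ('127.0.0.1', 9190), timeout=10):
        raise RuntimeError("no free device registered under key '1080ti'")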

def create_measure_batch(task, option):
"""Get a standard measure_batch function.
@@ -115,6 +144,17 @@ def create_measure_batch(task, option):
build_func = default_build_func
build_kwargs['use_ndk'] = True

    # check the availability of remote devices
    if hasattr(measure_func, 'rpc_info'):
        rpc_info = measure_func.rpc_info
        if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])):
            logger.info("Get devices for measurement successfully!")
        else:
            raise RuntimeError("Cannot get remote devices from the tracker. "
                               "Please check the status of tracker by "
                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
                               "and make sure you have free devices on the queue status.")

# add device info of cuda and opencl target
if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \
and hasattr(measure_func, 'rpc_info'):
@@ -313,7 +353,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
continue
except InstantiationError as e:
tstamp = time.time()
res_pack.append(MeasureResult((e,),
res_pack.append(MeasureResult((InstantiationError(str(e)),),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
continue
@@ -346,7 +386,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
if ref_output:
for expected, real in zip(ref_output, args):
if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
logging.warning("Wrong Answer!")
logger.warning("Wrong Answer!")
errno = MeasureErrorNo.WRONG_ANSWER
except TVMError as exc:
msg = str(exc)
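Note that the new availability check only fires when measure_func carries rpc_info, i.e. when measurement goes through the RPC tracker. A rough sketch of that setup follows; the exact keyword names of measure_option and use_rpc in this version are an assumption, and the key/host/port values are placeholders:

    from tvm import autotvm

    # Assumed keyword names -- check the measure_option / use_rpc signatures in your tree.
    measure_option = autotvm.measure_option(
        measure_func=autotvm.use_rpc('1080ti', host='127.0.0.1', port=9190),
        number=4,     # runs averaged per measurement
        timeout=20)   # seconds before a single measurement is abandoned
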
13 changes: 7 additions & 6 deletions python/tvm/autotvm/record.py
@@ -18,6 +18,7 @@
from .measure import MeasureInput, MeasureResult

AUTOTVM_LOG_VERSION = 0.1
logger = logging.getLogger('autotvm')

try: # convert unicode to str for python2
_unicode = unicode
@@ -181,10 +182,10 @@ def split_workload(in_file, clean=True):
tic = time.time()
lines = list(open(in_file).readlines())

logging.info("start converting...")
logger.info("start converting...")
pool = multiprocessing.Pool()
lines = pool.map(decode, lines)
logging.info("map done %.2f", time.time() - tic)
logger.info("map done %.2f", time.time() - tic)

wkl_dict = OrderedDict()
for inp, res in lines:
@@ -206,13 +207,13 @@ def split_workload(in_file, clean=True):
cleaned.append([inp, res])

# write to file
logging.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
logger.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
with open(args.i + ".%03d.wkl" % i, 'w') as fout:
for inp, res in cleaned:
fout.write(encode(inp, res) + '\n')
else:
for i, (k, v) in enumerate(wkl_dict.items()):
logging.info("Key: %s\tNum: %d", k, len(v))
logger.info("Key: %s\tNum: %d", k, len(v))
with open(args.i + ".%03d.wkl" % i, 'w') as fout:
for inp, res in v:
fout.write(encode(inp, res) + '\n')
@@ -238,7 +239,7 @@ def pick_best(in_file, out_file):
for v in best_context.best_by_targetkey.values():
best_set.add(measure_str_key(v[0]))

logging.info("Extract %d best records from the %s", len(best_set), in_file)
logger.info("Extract %d best records from the %s", len(best_set), in_file)
fout = open(out_file, 'w') if isinstance(out_file, str) else out_file

for inp, res in load_from_file(in_file):
@@ -270,7 +271,7 @@ def pick_best(in_file, out_file):
parser.add_argument("--code", action='store_true')

args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
logger.basicConfig(level=logger.INFO)

if args.mode == 'pick':
args.o = args.o or args.i + ".best.log"
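For reference, the helpers touched above can also be driven from Python instead of this command-line block; a small sketch with placeholder file names:

    from tvm.autotvm import record

    # Keep only the best record per workload/target in a smaller log file.
    record.pick_best("tune.log", "tune_best.log")

    # Stream (MeasureInput, MeasureResult) pairs back out of a log file.
    for inp, res in record.load_from_file("tune_best.log"):
        print(inp.task, res.error_no)
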
6 changes: 5 additions & 1 deletion python/tvm/autotvm/task/dispatcher.py
@@ -10,6 +10,8 @@
- During search, we can use it to pass the current proposal from tuner.
- During evaluation, we can use it to set pick the best policy.
"""
# pylint: disable=invalid-name

from __future__ import absolute_import as _abs

import logging
@@ -19,6 +21,8 @@

from tvm import target as _target

logger = logging.getLogger('autotvm')

class DispatchContext(object):
"""
Base class of dispatch context.
@@ -216,7 +220,7 @@ def load(self, records):
best_by_model[key] = (inp, res)
break

logging.debug("Finish loading %d records", counter)
logger.debug("Finish loading %d records", counter)

def query(self, target, workload):
if target is None:
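The load/query pair above belongs to the history-best dispatch context; a hedged sketch of how it is usually entered, assuming autotvm.apply_history_best is the public wrapper in this version and using a placeholder log name:

    from tvm import autotvm

    # Inside the context, query(target, workload) returns the best config
    # recorded in the log for each workload met during compilation.
    with autotvm.apply_history_best("tune_best.log"):
        pass  # build the tuned operators / network here
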
6 changes: 4 additions & 2 deletions python/tvm/autotvm/tophub.py
@@ -4,6 +4,7 @@
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you create the target for the first time.
"""
# pylint: disable=invalid-name

import logging
import os
@@ -16,6 +17,7 @@

AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")

logger = logging.getLogger('autotvm')

def _alias(name):
"""convert alias for some packages"""
@@ -79,7 +81,7 @@ def download_package(backend):
os.mkdir(path)

backend = _alias(backend)
logging.info("Download pre-tuned parameters for %s", backend)
logger.info("Download pre-tuned parameters for %s", backend)
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/%s.log" % backend,
os.path.join(rootpath, backend + ".log"), True, verbose=0)

@@ -110,7 +112,7 @@ def list_packages():
"""
path = tempdir()
filename = path.relpath("info.json")
logging.info("Download meta info for pre-tuned parameters")
logger.info("Download meta info for pre-tuned parameters")
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/info.json",
filename, True, verbose=0)

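A small sketch of calling the two helpers shown above; the backend name passed to download_package is a placeholder, and both calls hit the tvm-distro URLs hard-coded in this file:

    from tvm.autotvm import tophub

    print(tophub.list_packages())     # metadata from tophub/info.json
    tophub.download_package("cuda")   # fetches tophub/cuda.log into ~/.tvm/tophub
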
19 changes: 11 additions & 8 deletions python/tvm/autotvm/tuner/callback.py
@@ -2,11 +2,13 @@
"""Namespace of callback utilities of AutoTVM"""
import sys
import time
import logging

import numpy as np

from .. import record

logger = logging.getLogger('autotvm')

def log_to_file(file_out, protocol='json'):
"""Log the tuning records into file.
@@ -90,7 +92,7 @@ def progress_bar(total, prefix=''):
prefix: str
The prefix of output message
"""
class _Context:
class _Context(object):
"""Context to store local variables"""
def __init__(self):
self.best_flops = 0
@@ -112,13 +114,14 @@ def _callback(tuner, inputs, results):
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)

ctx.cur_flops = flops
ctx.best_flops = tuner.best_flops
if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
ctx.cur_flops = flops
ctx.best_flops = tuner.best_flops

sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
'| %.2f s' %
(prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total,
time.time() - tic))
sys.stdout.flush()
sys.stdout.write('%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
'| %.2f s\r' %
(prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total,
time.time() - tic))
sys.stdout.flush()

return _callback
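
Both callbacks in this file are meant to be handed to Tuner.tune; a short sketch of wiring them together, with the tuner and measure_option assumed to be set up elsewhere and a placeholder log name:

    from tvm import autotvm

    def tune_with_callbacks(tuner, measure_option, n_trial=1000):
        """Attach a progress bar and a JSON record writer to a tuning run."""
        tuner.tune(n_trial=n_trial,
                   measure_option=measure_option,
                   callbacks=[autotvm.callback.progress_bar(n_trial, prefix="conv2d"),
                              autotvm.callback.log_to_file("tune.log")])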
20 changes: 11 additions & 9 deletions python/tvm/autotvm/tuner/sa_model_optimizer.py
@@ -1,4 +1,4 @@
# pylint: disable=consider-using-enumerate
# pylint: disable=consider-using-enumerate, invalid-name
"""
Cost model optimizer based on simulated annealing
"""
@@ -12,6 +12,8 @@
from ..util import sample_ints
from .model_based_tuner import ModelOptimizer, knob2point, point2knob

logger = logging.getLogger('autotvm')

class SimulatedAnnealingOptimizer(ModelOptimizer):
"""parallel simulated annealing optimization algorithm
@@ -103,16 +105,16 @@ def find_maximums(self, model, num, exclusive):

if log_interval and k % log_interval == 0:
t_str = "%.2f" % t
logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
"elapsed: %.2f",
k, k_last_modify, heap_items[0][0],
np.max([v for v, _ in heap_items]), t_str,
time.time() - tic)
logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
"elapsed: %.2f",
k, k_last_modify, heap_items[0][0],
np.max([v for v, _ in heap_items]), t_str,
time.time() - tic)

heap_items.sort(key=lambda item: -item[0])
logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
logging.debug("SA Maximums: %s", heap_items)
logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
logger.debug("SA Maximums: %s", heap_items)

if self.persistent:
self.points = points
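Since every module in this change logs through the shared 'autotvm' logger, one level/handler setting controls all of it, including the SA debug traces above; a minimal configuration sketch:

    import logging
    import sys

    logger = logging.getLogger('autotvm')
    logger.setLevel(logging.DEBUG)                        # show the SA iteration traces
    logger.addHandler(logging.StreamHandler(sys.stdout))  # route them to stdout
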
26 changes: 18 additions & 8 deletions python/tvm/autotvm/tuner/tuner.py
@@ -4,11 +4,12 @@

import numpy as np

from ..measure import MeasureInput
from ..measure import create_measure_batch
from ..measure import MeasureInput, create_measure_batch

from ..env import GLOBAL_SCOPE

logger = logging.getLogger('autotvm')

class Tuner(object):
"""Base class for tuners
@@ -86,9 +87,10 @@ def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()):
measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1)
early_stopping = early_stopping or 1e9
old_level = logger.level

GLOBAL_SCOPE.in_tuning = True
i = 0
i = error_ct = 0
while i < n_trial:
if not self.has_next():
break
@@ -103,17 +105,20 @@
config = inp.config
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
error_ct = 0
else:
flops = 0
error_ct += 1

if flops > self.best_flops:
self.best_flops = flops
self.best_config = config
self.best_measure_pair = (inp, res)
self.best_iter = i + k

logging.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
res, config)
logger.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
res, config)

i += len(results)

@@ -123,11 +128,16 @@ def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()):
callback(self, inputs, results)

if i > self.best_iter + early_stopping:
logging.debug("Early stopped. Best iter: %d.", self.best_iter)
logger.debug("Early stopped. Best iter: %d.", self.best_iter)
break

GLOBAL_SCOPE.in_tuning = False
if error_ct > 50:
logger.warning("Too many errors happen in the tuning. Now is in debug mode")
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(old_level)

GLOBAL_SCOPE.in_tuning = False
del measure_batch

def reset(self):
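To summarize the loop above: tuning stops once n_trial configurations are measured or once early_stopping trials pass without improving the best result, and after more than 50 consecutive errors the 'autotvm' logger is switched to DEBUG (and back once measurements succeed again). A compact usage sketch, with the tuner and measure_option assumed from elsewhere:

    def run_tuning(tuner, measure_option, n_trial=2000, early_stopping=400):
        """Hedged sketch: run the tuning loop and report the best schedule found."""
        tuner.tune(n_trial=n_trial,
                   measure_option=measure_option,
                   early_stopping=early_stopping,
                   callbacks=())
        return tuner.best_config, tuner.best_flops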