Skip to content

Commit

Permalink
[AutoScheduler] Tutorial on auto-scheduling a network for GPU (apache…
Browse files Browse the repository at this point in the history
…#6882)

* add a tutorial on auto-scheduling a network for cuda

* fix typo

* fix training time printing

* fix lint

* fix

* upload logs

* fix

* use weighted sum as the default objective function

* update ci logs

* fix the bug in kill_child_processes

* fix test

* address comments

* add early stopping in task scheduler & fix a stuck issue in measurement

* fix lint

* trigger CI

* fix early stopping
  • Loading branch information
merrymercy authored Nov 13, 2020
1 parent f952fa7 commit 050a836
Show file tree
Hide file tree
Showing 20 changed files with 631 additions and 128 deletions.
3 changes: 3 additions & 0 deletions include/tvm/auto_scheduler/measure.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace tvm {
Expand Down Expand Up @@ -436,6 +437,8 @@ class ProgramMeasurerNode : public Object {
std::unordered_map<std::string, State> best_state;
/*! \brief Workload key to best state's count index map. */
std::unordered_map<std::string, int> best_ct;
/*! \brief The set of workloads that have at least one valid schedule */
std::unordered_set<std::string> has_valid;
/*! \brief The ProgramBuilder to build each program. */
ProgramBuilder builder;
/*! \brief The ProgramRunner to measure each program. */
Expand Down
4 changes: 0 additions & 4 deletions python/tvm/auto_scheduler/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import multiprocessing
import logging
from collections import defaultdict
import time

import numpy as np

Expand Down Expand Up @@ -138,7 +137,6 @@ def update(self, inputs, results):
if len(inputs) <= 0:
return
assert len(inputs) == len(results)
tic = time.time()

self.inputs.extend(inputs)
self.results.extend(results)
Expand Down Expand Up @@ -178,8 +176,6 @@ def update(self, inputs, results):
],
)

logger.info("XGBModel Training time: %.2f s", time.time() - tic)

def predict(self, task, states):
"""Predict the scores of states
Parameters
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def query(self, target, workload_key):

if not self.silent:
msg = (
"Cannot find tuned schedule for target=%s, workload_key=%s. "
"Cannot find tuned schedules for target=%s, workload_key=%s. "
"A fallback schedule is used, "
"which may bring great performance regression." % (target, workload_key)
)
Expand Down
59 changes: 36 additions & 23 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import os
import time
import shutil
import traceback
import tempfile
import multiprocessing

Expand All @@ -48,10 +47,11 @@
from . import _ffi_api
from .loop_state import StateObject
from .utils import (
get_const_tuple,
call_func_with_timeout,
request_remote,
check_remote,
get_const_tuple,
make_traceback_info,
request_remote,
)
from .compute_dag import ComputeDAG
from .search_task import SearchTask
Expand All @@ -60,8 +60,6 @@
deserialize_workload_registry_entry,
)

# The maximum length of error message
MAX_ERROR_MSG_LEN = 512

# The time cost for measurements with errors
# We use 1e10 instead of sys.float_info.max for better readability in log
Expand Down Expand Up @@ -536,16 +534,6 @@ class MeasureErrorNo(object):
UNKNOWN_ERROR = 8 # Unknown error


def make_error_msg():
    """Return the formatted traceback of the currently handled exception.

    When the message is longer than ``MAX_ERROR_MSG_LEN`` characters, the
    middle portion is elided so only the first and last halves of the limit
    are kept (separated by an ellipsis line).
    """
    msg = str(traceback.format_exc())
    if len(msg) <= MAX_ERROR_MSG_LEN:
        return msg
    half = MAX_ERROR_MSG_LEN // 2
    return msg[:half] + "\n...\n" + msg[-half:]


def _timed_func(inp_serialized, build_func, verbose):
tic = time.time()
inp = MeasureInput.deserialize(inp_serialized)
Expand All @@ -560,14 +548,13 @@ def _timed_func(inp_serialized, build_func, verbose):
# pylint: disable=broad-except
except Exception:
error_no = MeasureErrorNo.INSTANTIATION_ERROR
error_msg = make_error_msg()
error_msg = make_traceback_info()

if error_no == 0:
dirname = tempfile.mkdtemp()
filename = os.path.join(dirname, "tmp_func." + build_func.output_format)

try:
# TODO(merrymercy): Port the unroll pass.
with transform.PassContext():
func = build_module.build(
sch, args, target=task.target, target_host=task.target_host
Expand All @@ -576,7 +563,7 @@ def _timed_func(inp_serialized, build_func, verbose):
# pylint: disable=broad-except
except Exception:
error_no = MeasureErrorNo.COMPILE_HOST
error_msg = make_error_msg()
error_msg = make_traceback_info()
else:
filename = ""

Expand All @@ -585,6 +572,7 @@ def _timed_func(inp_serialized, build_func, verbose):
print(".", end="")
else:
print(".E", end="") # Build error

return filename, args, error_no, error_msg, time.time() - tic


Expand Down Expand Up @@ -615,6 +603,10 @@ def local_build_worker(args):
if verbose >= 1:
print(".T", end="") # Build timeout
res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
elif isinstance(res, Exception):
if verbose >= 1:
print(".E", end="") # Build error
res = None, [], MeasureErrorNo.COMPILE_HOST, str(res), timeout

return res

Expand Down Expand Up @@ -703,7 +695,7 @@ def _timed_eval_func(
except Exception:
costs = (MAX_FLOAT,)
error_no = MeasureErrorNo.COMPILE_DEVICE
error_msg = make_error_msg()
error_msg = make_traceback_info()

if error_no == 0:
try:
Expand All @@ -718,7 +710,7 @@ def _timed_eval_func(
except Exception:
costs = (MAX_FLOAT,)
error_no = MeasureErrorNo.RUNTIME_DEVICE
error_msg = make_error_msg()
error_msg = make_traceback_info()

shutil.rmtree(os.path.dirname(build_res.filename))
toc = time.time()
Expand Down Expand Up @@ -825,6 +817,17 @@ def local_run(
build_res.time_cost + timeout,
time.time(),
)
elif isinstance(res, Exception):
if verbose >= 1:
print("*E", end="") # Run error
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUNTIME_DEVICE,
str(res),
build_res.time_cost + timeout,
time.time(),
)

measure_results.append(MeasureResult(*res))

if verbose >= 1:
Expand Down Expand Up @@ -876,7 +879,7 @@ def _timed_rpc_run(
except Exception:
costs = (MAX_FLOAT,)
error_no = MeasureErrorNo.COMPILE_DEVICE
error_msg = make_error_msg()
error_msg = make_traceback_info()

if error_no == 0:
try:
Expand All @@ -900,7 +903,7 @@ def _timed_rpc_run(
except Exception:
costs = (MAX_FLOAT,)
error_no = MeasureErrorNo.RUNTIME_DEVICE
error_msg = make_error_msg()
error_msg = make_traceback_info()

shutil.rmtree(os.path.dirname(build_res.filename))
toc = time.time()
Expand Down Expand Up @@ -939,7 +942,6 @@ def _rpc_run_worker(args):
)

res = call_func_with_timeout(timeout, _timed_rpc_run, args=args)

if isinstance(res, TimeoutError):
if verbose >= 1:
print("*T", end="") # Run timeout
Expand All @@ -950,6 +952,17 @@ def _rpc_run_worker(args):
build_res.time_cost + timeout,
time.time(),
)
elif isinstance(res, Exception):
if verbose >= 1:
print("*E", end="") # Run error
res = (
(MAX_FLOAT,),
MeasureErrorNo.RUNTIME_DEVICE,
str(res),
build_res.time_cost + timeout,
time.time(),
)

return res


Expand Down
81 changes: 78 additions & 3 deletions python/tvm/auto_scheduler/measure_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, pointless-string-statement

""" Serialization and other I/O support for measurement records (tuning logs). """
import argparse
import logging
import os
import itertools

import numpy as np

Expand All @@ -24,6 +29,8 @@
from .measure import MeasureErrorNo, MeasureCallback
from . import _ffi_api

logger = logging.getLogger("auto_scheduler")


@tvm._ffi.register_object("auto_scheduler.RecordToFile")
class RecordToFile(MeasureCallback):
Expand All @@ -36,7 +43,7 @@ class RecordToFile(MeasureCallback):
File name for this callback to write log to.
"""

def __init__(self, filename="auto_scheduler_tuning.json"):
def __init__(self, filename):
self.__init_handle_by_constructor__(_ffi_api.RecordToFile, filename)


Expand All @@ -47,11 +54,11 @@ class RecordReader(Object):
Parameters
----------
filename : str = "auto_scheduler_tuning.json"
filename : str
File name for this reader to load log from.
"""

def __init__(self, filename="auto_scheduler_tuning.json"):
def __init__(self, filename):
self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename)

def read_lines(self, max_lines=None, skip_lines=0):
Expand Down Expand Up @@ -173,3 +180,71 @@ def load_best(filename, workload_key=None, target=None):
best_res = res

return best_inp, best_res


def distill_record_file(in_file, out_file):
    """
    Pick the best entries from a record file and store them to another file.
    This function distills the useful log entries from a large log file.
    If out_file already exists, the best entries from both
    in_file and out_file will be saved.

    Parameters
    ----------
    in_file: str
        The filename of input
    out_file: str or file
        The filename of output
    """
    # pylint: disable=import-outside-toplevel
    from .dispatcher import ApplyHistoryBest

    context = load_records(in_file)
    if os.path.isfile(out_file):
        # Merge with the existing output so previously distilled entries
        # compete with the new ones.
        out_context = load_records(out_file)
        context = itertools.chain(context, out_context)
    # tee the iterator: one copy feeds ApplyHistoryBest, the other is
    # re-scanned below to recover the full (input, result) pairs.
    context, context_clone = itertools.tee(context)
    best_context = ApplyHistoryBest(context)
    best_set = set()

    def measure_input_str_key(inp):
        # Serialized MeasureInput acts as a stable identity key for a record.
        return _ffi_api.SerializeMeasureInput(inp)

    for v in best_context.best_by_model.values():
        best_set.add(measure_input_str_key(v[0]))

    for v in best_context.best_by_targetkey.values():
        best_set.add(measure_input_str_key(v[0]))

    inputs = []
    results = []
    for inp, res in context_clone:
        # Serialize once per record instead of up to twice per iteration.
        key = measure_input_str_key(inp)
        if key in best_set:
            inputs.append(inp)
            results.append(res)
            # Remove the key so duplicates of a best record are kept only once.
            best_set.remove(key)

    # Truncate (or create) the output file before saving. Using a context
    # manager fixes the leaked file handle of the original bare
    # `open(out_file, "w")` call, which relied on GC to close the file.
    with open(out_file, "w"):
        pass
    save_records(out_file, inputs, results)
    logger.info("Extract %d best records from %s to %s", len(inputs), in_file, out_file)


"""
Usage:
* Distill the best entries from a large log file
e.g. python -m tvm.auto_scheduler.measure_record --mode distill --i input.json
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["distill"], required=True)
parser.add_argument("--i", type=str, help="input file")
parser.add_argument("--o", type=str, default=None, help="output file")

args = parser.parse_args()
logging.basicConfig()
logger.setLevel(logging.INFO)

if args.mode == "distill":
args.o = args.o or args.i + ".best.json"
distill_record_file(args.i, args.o)
Loading

0 comments on commit 050a836

Please sign in to comment.