Skip to content

Commit

Permalink
Code clean for minimum Ansor system
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 committed Jun 24, 2020
1 parent 910964e commit d567617
Show file tree
Hide file tree
Showing 20 changed files with 10 additions and 7,049 deletions.
2 changes: 1 addition & 1 deletion python/tvm/ansor/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def apply_steps_from_state(self, state):
args : List[Tensor]
"""
state_obj = state if isinstance(state, StateObject) else state.state_object
return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj, layout_rewrite_level)
return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj)

def print_python_code_from_state(self, state):
"""
Expand Down
99 changes: 1 addition & 98 deletions python/tvm/ansor/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,103 +389,6 @@ def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel:

return results


@tvm._ffi.register_func("ansor.rpc_runner.run")
def rpc_runner_run(inputs: List[MeasureInput], build_results: List[BuildResult],
key: str, host: str, port: int, priority: int, timeout: float,
n_parallel: int, number: int, repeat: int, min_repeat_ms: int,
cooldown_interval: float, verbose: int):
global global_run_arguments
global_run_arguments = (inputs, build_results, key, host, port, priority, timeout, number,
repeat, min_repeat_ms, cooldown_interval, verbose)

assert len(inputs) == len(build_results), \
"Measure input size should be equal to build results"
pool = NoDaemonPool(n_parallel)
tuple_res = pool.map(rpc_run_worker, range(len(build_results)))
pool.terminate()
pool.join()
del pool

results = []
for res in tuple_res:
results.append(MeasureResult(*res))

if verbose >= 1:
print("")

return results


def rpc_run_worker(index):
""" ...
"""
inputs, build_results, key, host, port, priority, timeout, number, \
repeat, min_repeat_ms, cooldown_interval, verbose = global_run_arguments

MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log
inp = inputs[index]
build_res = build_results[index]

if build_res.error_no != MeasureErrorNo.NO_ERROR:
return (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, \
time.time()

def timed_func():
tic = time.time()
error_no = 0
error_msg = None
try:
# upload built module
remote = request_remote(key, host, port, priority, timeout)
remote.upload(build_res.filename)
func = remote.load_module(os.path.split(build_res.filename)[1])
ctx = remote.context(str(inp.task.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms)
except Exception:
costs = (MAX_FLOAT,)
error_no = MeasureErrorNo.COMPILE_DEVICE
error_msg = make_error_msg()

if error_no == 0:
try:
args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
build_res.args]
ctx.sync()

costs = time_f(*args).results
# clean up remote files
remote.remove(build_res.filename)
remote.remove(os.path.splitext(build_res.filename)[0] + '.so')
remote.remove('')
except Exception:
costs = (MAX_FLOAT,)
error_no = MeasureErrorNo.RUNTIME_DEVICE
error_msg = make_error_msg()

shutil.rmtree(os.path.dirname(build_res.filename))
toc = time.time()

time.sleep(cooldown_interval)
if verbose >= 1:
if error_no == MeasureErrorNo.NO_ERROR:
print("*", end="")
else:
print("*E", end="") # Run error

return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc

res = call_func_with_timeout(timeout, timed_func)

if isinstance(res, TimeoutError):
if verbose >= 1:
print("*T", end="") # Run timeout
res = (MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + \
timeout, time.time()
return res


@tvm._ffi.register_func("ansor.local_runner.run")
def local_run(inputs: List[MeasureInput], build_results: List[BuildResult],
timeout: float, number: int, repeat: int, min_repeat_ms: int,
Expand All @@ -510,7 +413,7 @@ def timed_func(inp, build_res):

if error_no == 0:
try:
args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
build_res.args]
ctx.sync()

Expand Down
1 change: 0 additions & 1 deletion src/ansor/auto_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <tvm/runtime/registry.h>
#include <string>
#include <utility>
#include "search_policy/sketch_search_policy.h"

namespace tvm {
namespace ansor {
Expand Down
Loading

0 comments on commit d567617

Please sign in to comment.