Skip to content

Commit

Permalink
UT ready
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 committed Jun 24, 2020
1 parent d567617 commit a8e589e
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 346 deletions.
5 changes: 2 additions & 3 deletions python/tvm/ansor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@
# Shortcut
from .compute_dag import ComputeDAG
from .auto_schedule import SearchTask, TuneOption, HardwareParams, \
auto_schedule
from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext
from .cost_model import RandomModel
auto_schedule, EmptyPolicy
from .measure import MeasureInput, LocalBuilder, LocalRunner
from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \
load_from_file, write_measure_records_to_file
from .workload_registry import register_workload_func, \
Expand Down
13 changes: 11 additions & 2 deletions python/tvm/ansor/auto_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import tvm._ffi
from tvm.runtime import Object
from .measure import LocalBuilder, LocalRunner
from .cost_model import RandomModel
from . import _ffi_api


Expand Down Expand Up @@ -82,10 +81,20 @@ def set_verbose(self, verbose):
def run_callbacks(self, callbacks):
_ffi_api.SearchPolicyRunCallbacks(self, callbacks)


@tvm._ffi.register_object("ansor.EmptyPolicy")
class EmptyPolicy(SearchPolicy):
""" The example search policy
"""
def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy)


@tvm._ffi.register_object("ansor.SearchCallback")
class SearchCallback(Object):
"""Callback function before or after search process"""


@tvm._ffi.register_object("ansor.TuneOption")
class TuneOption(Object):
""" The options for tuning
Expand Down Expand Up @@ -164,7 +173,7 @@ def auto_schedule(workload, target=None,
"""
if isinstance(search_policy, str):
if search_policy == 'default':
search_policy = SketchSearchPolicy(RandomModel())
search_policy = EmptyPolicy()
else:
raise ValueError("Invalid search policy: " + search_policy)

Expand Down
20 changes: 0 additions & 20 deletions python/tvm/ansor/cost_model/__init__.py

This file was deleted.

46 changes: 0 additions & 46 deletions python/tvm/ansor/cost_model/cost_model.py

This file was deleted.

98 changes: 0 additions & 98 deletions python/tvm/ansor/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,104 +178,6 @@ def __init__(self,
self.__init_handle_by_constructor__(
_ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval)

@tvm._ffi.register_object("ansor.ProgramMeasurer")
class ProgramMeasurer(Object):
"""
Parameters
----------
builder : Builder
runner : Runner
callbacks : List[MeasureCallback]
verbose : Int
max_continuous_error : Float
"""

def __init__(self, builder: Builder, runner: Runner,
callbacks: List[MeasureCallback],
verbose: int, max_continuous_error: int = -1):
self.__init_handle_by_constructor__(
_ffi_api.ProgramMeasurer, builder, runner, callbacks, verbose, max_continuous_error)

@tvm._ffi.register_object("ansor.RPCRunner")
class RPCRunner(Runner):
"""
Parameters
----------
key : Str
host : Str
port : Int
priority : Int
n_parallel : Int
timeout : Int
number : Int
repeat : Int
min_repeat_ms : Int
cooldown_interval : Float
"""

def __init__(self, key, host, port, priority=1,
n_parallel=1,
timeout=10,
number=3,
repeat=1,
min_repeat_ms=0,
cooldown_interval=0.0):
self.__init_handle_by_constructor__(
_ffi_api.RPCRunner, key, host, port, priority, timeout, n_parallel,
number, repeat, min_repeat_ms, cooldown_interval)

if check_remote(key, host, port, priority, timeout):
LOGGER.info("Get devices for measurement successfully!")
else:
raise RuntimeError("Cannot get remote devices from the tracker. "
"Please check the status of tracker by "
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
"and make sure you have free devices on the queue status.")


class LocalRPCMeasureContext:
""" A context wrapper for running RPCRunner locally.
This will launch a local RPC Tracker and local RPC Server.
Parameters
----------
priority : Int
n_parallel : Int
timeout : Int
number : Int
repeat : Int
min_repeat_ms : Int
cooldown_interval : Float
"""

def __init__(self,
priority=1,
n_parallel=1,
timeout=10,
number=10,
repeat=1,
min_repeat_ms=0,
cooldown_interval=0.0):
ctx = tvm.context("cuda", 0)
if ctx.exist:
cuda_arch = "sm_" + "".join(ctx.compute_version.split('.'))
set_cuda_target_arch(cuda_arch)
host = '0.0.0.0'
self.tracker = Tracker(host, port=9000, port_end=10000, silent=True)
device_key = '$local$device$%d' % self.tracker.port
self.server = Server(host, port=self.tracker.port, port_end=10000,
key=device_key, use_popen=True, silent=True,
tracker_addr=(self.tracker.host, self.tracker.port))
self.runner = RPCRunner(device_key, host, self.tracker.port, priority,
n_parallel, timeout, number, repeat,
min_repeat_ms, cooldown_interval)
# wait for the processes to start
time.sleep(0.5)

def __del__(self):
self.server.terminate()
self.tracker.terminate()


class MeasureErrorNo(object):
"""Error type for MeasureResult"""
Expand Down
42 changes: 0 additions & 42 deletions src/ansor/measure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
TVM_REGISTER_OBJECT_TYPE(RunnerNode);
TVM_REGISTER_OBJECT_TYPE(BuilderNode);
TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode);
TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode);

const char* ErrorNoToStr[] = {
"NoError",
Expand Down Expand Up @@ -127,38 +125,6 @@ Array<BuildResult> LocalBuilderNode::Build(const Array<MeasureInput>& inputs,
return Array<BuildResult>();
}

// RPC Runner
RPCRunner::RPCRunner(const std::string& key, const std::string& host, int port,
int priority, int timeout, int n_parallel, int number,
int repeat, int min_repeat_ms, double cooldown_interval) {
auto node = make_object<RPCRunnerNode>();
node->key = key;
node->host = host;
node->port = port;
node->priority = priority;
node->timeout = timeout;
node->n_parallel = n_parallel;
node->number = number;
node->repeat = repeat;
node->min_repeat_ms = min_repeat_ms;
node->cooldown_interval = cooldown_interval;
data_ = std::move(node);
}

Array<MeasureResult> RPCRunnerNode::Run(const Array<MeasureInput>& inputs,
const Array<BuildResult>& build_results,
int verbose) {
if (const auto* f = runtime::Registry::Get("ansor.rpc_runner.run")) {
Array<MeasureResult> results = (*f)(
inputs, build_results, key, host, port, priority, timeout, n_parallel,
number, repeat, min_repeat_ms, cooldown_interval, verbose);
return results;
} else {
LOG(FATAL) << "ansor.rpc_runner.run is not registered";
}
return Array<MeasureResult>();
}

// Local Runner
LocalRunner::LocalRunner(int timeout, int number, int repeat,
int min_repeat_ms, double cooldown_interval) {
Expand Down Expand Up @@ -379,14 +345,6 @@ TVM_REGISTER_GLOBAL("ansor.LocalRunner")
return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval);
});

TVM_REGISTER_GLOBAL("ansor.RPCRunner")
.set_body_typed([](const std::string& key, const std::string& host, int port,
int priority, int timeout, int n_parallel, int number,
int repeat, int min_repeat_ms, double cooldown_interval){
return RPCRunner(key, host, port, priority, timeout, n_parallel, number,
repeat, min_repeat_ms, cooldown_interval);
});

TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer")
.set_body_typed([](Builder builder, Runner runner,
Array<MeasureCallback> callbacks, int verbose,
Expand Down
36 changes: 0 additions & 36 deletions src/ansor/measure.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,42 +219,6 @@ class LocalBuilder: public Builder {
TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, Builder, LocalBuilderNode);
};

/*! \brief RPCRunner that uses RPC call to measures the time cost of programs
* on remote devices */
class RPCRunnerNode : public RunnerNode {
public:
std::string key;
std::string host;
int port;
int priority;
int n_parallel;
int number;
int repeat;
int min_repeat_ms;
double cooldown_interval;

/*! \biref Run measurement and return results */
Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
const Array<BuildResult>& build_results,
int verbose) final;

static constexpr const char* _type_key = "ansor.RPCRunner";
TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, RunnerNode);
};

/*!
* \brief Managed reference to RPCRunnerNode.
* \sa RPCRunnerNode
*/
class RPCRunner : public Runner {
public:
RPCRunner(const std::string& key, const std::string& host, int port,
int priority, int timeout, int n_parallel, int number,
int repeat, int min_repeat_ms, double cooldown_interval);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, Runner, RPCRunnerNode);
};

/*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */
class LocalRunnerNode: public RunnerNode {
public:
Expand Down
Loading

0 comments on commit a8e589e

Please sign in to comment.