Skip to content

Commit

Permalink
[Ansor][AutoTVM v2.0] Phase 2: Basic GPU Sketch Search Policy (apache…
Browse files Browse the repository at this point in the history
…#6269)

* Add PreloadMeasuredStates & Split search_policy.py

* Add GPU sketch rule

* Update

* Bug fix for log record

* Lint fix

* Update tutorial

* Update

* UT fix

* Remove tutorial

* Update

* Update

* Update UT

* Lint fix

* Update

* Update
  • Loading branch information
jcf94 authored and Trevor Morris committed Aug 26, 2020
1 parent d87c961 commit dd82c8e
Show file tree
Hide file tree
Showing 18 changed files with 1,387 additions and 316 deletions.
35 changes: 35 additions & 0 deletions include/tvm/auto_scheduler/search_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,35 @@ class SearchCallback : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
};

/*! \brief Preload measured states from a log file.
* This can resume the state of the search policy */
class PreloadMeasuredStatesNode : public SearchCallbackNode {
public:
/*! \brief The name of the record log file. */
String filename;

void Callback(SearchPolicyNode* policy) final;

static constexpr const char* _type_key = "auto_scheduler.PreloadMeasuredStates";
TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode);
};

/*!
* \brief Managed reference to PreloadMeasuredStatesNode.
* \sa PreloadMeasuredStatesNode
*/
class PreloadMeasuredStates : public SearchCallback {
public:
/*!
* \brief The constructor.
* \param filename The name of the record log file.
*/
explicit PreloadMeasuredStates(String filename);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback,
PreloadMeasuredStatesNode);
};

/*! \brief Attribute keys of ops used for SearchPolicy. */
struct SearchPolicyKey {
/*! \brief Always apply unroll to the inner most iterator of the specificed iterators. */
Expand Down Expand Up @@ -141,6 +170,12 @@ class SearchPolicyNode : public Object {
virtual State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
ProgramMeasurer measurer) = 0;

/*!
* \brief Preload measured states from a log file to resume the state of the search policy.
* \param log_file The name of the record log file.
*/
void PreloadMeasuredStates(const String& log_file);

/*!
* \brief Call SearchCallback with the current SearchPolicyNode
* \param callbacks SearchCallback to be called.
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@

# Shortcut
from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
auto_schedule, EmptyPolicy, SketchPolicy
auto_schedule
from .compute_dag import ComputeDAG
from .cost_model import RandomModel, XGBModel
from .measure import MeasureInput, MeasureResult, LocalBuilder, LocalRunner, RPCRunner, \
LocalRPCMeasureContext
from .measure_record import RecordToFile, RecordReader, load_best, \
load_records, save_records
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .workload_registry import register_workload, make_workload_key
122 changes: 1 addition & 121 deletions python/tvm/auto_scheduler/auto_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@
Candidate schedules are measured against the specific hardware target.
"""

import random

import tvm._ffi
from tvm.runtime import Object
from .measure import LocalBuilder, LocalRunner
from .cost_model import RandomModel
from .search_policy import EmptyPolicy
from . import _ffi_api


Expand Down Expand Up @@ -82,124 +80,6 @@ def __init__(self, dag, workload_key, target, target_host=None,
hardware_params)


@tvm._ffi.register_object("auto_scheduler.SearchPolicy")
class SearchPolicy(Object):
""" The base class of search policies. """


@tvm._ffi.register_object("auto_scheduler.EmptyPolicy")
class EmptyPolicy(SearchPolicy):
""" This is an example empty search policy which will always generate
the init state of ComputeDAG.
Parameters
----------
task : SearchTask
The SearchTask for the computation declaration.
init_search_callbacks : Optional[List[SearchCallback]]
Callback functions called before the search process.
"""
def __init__(self, task, init_search_callbacks=None):
self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy, task, init_search_callbacks)


@tvm._ffi.register_object("auto_scheduler.SketchPolicy")
class SketchPolicy(SearchPolicy):
""" The search policy that searches in a hierarchical search space defined by sketches.
The policy randomly samples programs from the space defined by sketches
and use evolutionary search to fine-tune them.
Parameters
----------
task : SearchTask
The SearchTask for the computation declaration.
schedule_cost_model : CostModel = RandomModel()
The cost model to estimate the complete schedules.
params : Optional[Dict[str, Any]]
Parameters of the search policy.
See `src/auto_scheduler/search_policy/sketch_search_policy.h` for the definitions.
See `DEFAULT_PARAMS` below to find the default values.
seed : Optional[int]
Random seed.
verbose : int = 1
Verbosity level. 0 for silent, 1 to output information during schedule search.
init_search_callbacks : Optional[List[SearchCallback]]
Callback functions called before the search process, usually used to do extra
initializations.
Possible callbacks:
- auto_scheduler.PreloadMeasuredStates
- auto_scheduler.PreloadCustomSketchRule
TODO(jcf94): Add these search callback implementations.
"""

DEFAULT_PARAMS = {
"eps_greedy": 0.05,

'evolutionary_search_population': 2048,
"evolutionary_search_use_measured_ratio": 0.2,

'cpu_multi_level_tiling_structure': 'SSRSRS',
'gpu_multi_level_tiling_structure': 'SSSRRSRS',

'max_innermost_split_factor': 16,
'max_vectorize_size': 16,

'disable_change_compute_location': 0,
}

def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=None, verbose=1,
init_search_callbacks=None):
if params is None:
params = SketchPolicy.DEFAULT_PARAMS
else:
for key, value in SketchPolicy.DEFAULT_PARAMS.items():
if key not in params:
params[key] = value

self.__init_handle_by_constructor__(
_ffi_api.SketchPolicy, task, schedule_cost_model, params,
seed or random.randint(1, 1 << 30), verbose, init_search_callbacks)

def generate_sketches(self, print_for_debug=False):
""" Generate the sketches.
This python interface is mainly used for debugging and testing.
The actual search is all doen in c++.
Parameters
----------
print_for_debug : bool = False
Whether print out the sketches for debug.
Returns
-------
sketches : List[State]
The generated sketches of this search task.
"""
sketches = _ffi_api.SketchPolicyGenerateSketches(self)
if print_for_debug:
for i, s in enumerate(sketches):
print("=" * 20 + " %d " % i + "=" * 20)
print(s)
return sketches

def sample_initial_population(self, pop_size):
"""Sample initial population.
This python interface is mainly used for debugging and testing.
The actual search is all doen in c++.
Parameters
----------
pop_size : int
The size of sampled population
Returns
-------
states: List[State]
The sampled states
"""
states = _ffi_api.SketchPolicySampleInitialPopulation(self, pop_size)
return states

@tvm._ffi.register_object("auto_scheduler.TuningOptions")
class TuningOptions(Object):
""" This controls the options of performance tuning.
Expand Down
9 changes: 8 additions & 1 deletion python/tvm/auto_scheduler/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ class Iterator(Object):
class Stage(Object):
""" A stage in the compute declaration. Similar to tvm.te.schedule.Stage. """

# Static trans table for compute_at location
# This is used to transform the compute_at location to C++ enum
COMPUTE_AT_TRANS_TABLE = {
"root": 0,
"inlined": 1,
"iter": 2
}

@tvm._ffi.register_object("auto_scheduler.State")
class StateObject(Object):
Expand Down Expand Up @@ -85,7 +92,7 @@ class State:
This is a wrapper class of StateObject to deal with copy-on-write property
"""

# Static trans table for thread bind
# Static trans table for thread bind and annotation
# This is used to transform the annotation name to C++ enum
ANNOTATION_TRANS_TABLE = {
"none": 0,
Expand Down
Loading

0 comments on commit dd82c8e

Please sign in to comment.