Skip to content

Commit

Permalink
[AutoScheduler] Refactor task interface for tuning single operators (a…
Browse files Browse the repository at this point in the history
…pache#7028)

* [AutoScheduler] Refactor task interface

* updae tutorials and tests

* update

* fix lint

* fix lint

* update

* fix test
  • Loading branch information
merrymercy authored Dec 4, 2020
1 parent 91c905d commit 75afcd7
Show file tree
Hide file tree
Showing 23 changed files with 456 additions and 125 deletions.
4 changes: 2 additions & 2 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,15 @@ class ComputeDAGNode : public Object {
* This is an optimization to rewrite the layout of input tensors according to the schedule we get.
*/
enum class LayoutRewriteOption : int {
/*! \brief Do not process layout rewrite. */
/*! \brief Do not perform layout rewrite. */
NoRewrite = 0,
/*! \brief Insert layout transformation stages for input placeholders in the compute DAG */
InsertTransformStage = 1,
/*!
* \brief Do not insert layout transformation stages and assume the input placeholders
* are pre-transformed.
* \note The lowered function with this option does not accept the origial input shapes,
* so this option must be used along with a layout conversion pass in Relay.
* so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay.
*/
RewriteForPreTransformed = 2,
};
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
from . import workload_registry

# Shortcut
from .auto_schedule import TuningOptions, HardwareParams, create_task, auto_schedule
from .compute_dag import ComputeDAG
from .compute_dag import ComputeDAG, LayoutRewriteOption
from .cost_model import RandomModel, XGBModel
from .dispatcher import DispatchContext, ApplyHistoryBest
from .measure import (
Expand All @@ -43,14 +42,14 @@
RPCRunner,
LocalRPCMeasureContext,
)
from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records
from .relay_integration import (
extract_tasks,
remove_index_check,
rewrite_compute_body,
is_auto_scheduler_enabled,
)
from .search_task import SearchTask
from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .task_scheduler import TaskScheduler
from .workload_registry import register_workload, make_workload_key
23 changes: 16 additions & 7 deletions python/tvm/auto_scheduler/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@
from .workload_registry import workload_key_to_tensors


class LayoutRewriteOption:
"""Options for applying layout rewrite."""

# Do not perform layout rewrite
NO_REWRITE = 0
# Insert layout transformation stages for input placeholders in the compute DAG
INSERT_TRANSFORM_STAGE = 1
# Do not insert layout transformation stages and assume the input placeholders
# are pre-transformed.
# Note: The lowered function with this option does not accept the origial input shapes,
# so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay.
REWRITE_FOR_PRE_TRANSFORMED = 2


@tvm._ffi.register_object("auto_scheduler.ComputeDAG")
class ComputeDAG(Object):
"""
Expand All @@ -52,11 +66,6 @@ class ComputeDAG(Object):
Input/output tensors or workload key for a compute declaration.
"""

# Layout Rewrite Options
NoRewrite = 0
InsertTransformStage = 1
RewriteForPreTransformed = 2

def __init__(self, compute_or_sche):
if isinstance(compute_or_sche, str):
compute = workload_key_to_tensors(compute_or_sche)
Expand Down Expand Up @@ -92,7 +101,7 @@ def get_init_state(self):
"""
return State(self.init_state, self)

def apply_steps_from_state(self, state, layout_rewrite=NoRewrite):
def apply_steps_from_state(self, state, layout_rewrite=LayoutRewriteOption.NO_REWRITE):
"""
Apply the history transform steps from a State to get a TVM schedule.
Expand All @@ -101,7 +110,7 @@ def apply_steps_from_state(self, state, layout_rewrite=NoRewrite):
state : Union[State, StateObject]
The state from which we get transform steps.
layout_rewrite: Bool
layout_rewrite: LayoutRewriteOption = NoRewrite
Rewrite the layout of placeholders specified by "layout_free_placeholders" attr
to make it most friendly for the generated schedule to read from.
Expand Down
18 changes: 10 additions & 8 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@
make_traceback_info,
request_remote,
)
from .compute_dag import ComputeDAG
from .search_task import SearchTask
from .compute_dag import LayoutRewriteOption
from .workload_registry import (
serialize_workload_registry_entry,
deserialize_workload_registry_entry,
Expand Down Expand Up @@ -178,13 +177,15 @@ def recover_measure_input(inp, rebuild_state=False):
new_input: MeasureInput
The fully recovered MeasureInput with all fields rebuilt.
"""
# pylint: disable=import-outside-toplevel
from .search_task import SearchTask # lazily import to avoid recursive dependency

task = inp.task
new_task = SearchTask(
ComputeDAG(task.workload_key),
task.workload_key,
task.target,
task.target_host,
task.hardware_params,
workload_key=task.workload_key,
target=task.target,
target_host=task.target_host,
hardware_params=task.hardware_params,
)

if rebuild_state:
Expand Down Expand Up @@ -521,6 +522,7 @@ def __del__(self):
# Close the tracker and server before exit
self.tracker.terminate()
self.server.terminate()
time.sleep(0.5)


class MeasureErrorNo(object):
Expand Down Expand Up @@ -549,7 +551,7 @@ def _timed_func(inp_serialized, build_func, verbose):

try:
sch, args = task.compute_dag.apply_steps_from_state(
inp.state, layout_rewrite=ComputeDAG.RewriteForPreTransformed
inp.state, layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED
)
# pylint: disable=broad-except
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/measure_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def save_records(filename, inputs, results):
_ffi_api.SaveRecords(filename, inputs, results)


def load_best(filename, workload_key=None, target=None):
def load_best_record(filename, workload_key=None, target=None):
"""Return the best measurement pair form a log file. This may return none results if
there is no legal measure pair with the specified workload_key/target found from the log file.
Expand Down
9 changes: 8 additions & 1 deletion python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,14 @@ def extract_tasks(
weights = []
for wkl_key, ccache_key in env.wkl_key_to_ccache_key.items():
dag = ComputeDAG(wkl_key)
tasks.append(SearchTask(dag, wkl_key, target, target_host, hardware_params))
tasks.append(
SearchTask(
workload_key=wkl_key,
target=target,
target_host=target_host,
hardware_params=hardware_params,
)
)
weights.append(use_count_dict[ccache_key] + 1)

# clean the cached lowering results
Expand Down
Loading

0 comments on commit 75afcd7

Please sign in to comment.