[AutoScheduler] Misc update to hardware parameter and task scheduler (apache#7020)

* [AutoScheduler] Misc update to hardware parameter and task scheduler

* update

* update

* update

* update

* fix

* fix

* update

* improve warning message

* update

* lint

* update

* update

* fix

* Apply suggestions from code review

* trigger CI
merrymercy authored Dec 3, 2020
1 parent b06b64d commit e6c1baf
Showing 12 changed files with 121 additions and 50 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
@@ -251,6 +251,7 @@
"tune_relay_mobile_gpu.py",
],
"auto_scheduler": ["tune_matmul_x86.py", "tune_conv2d_layer_cuda.py"],
"dev": ["low_level_custom_pass.py", "use_pass_infra.py", "bring_your_own_datatypes.py"],
}


26 changes: 16 additions & 10 deletions include/tvm/auto_scheduler/search_task.h
@@ -44,17 +44,16 @@ class HardwareParamsNode : public Object {
int cache_line_bytes;

// GPU-related parameters obtained from the device query API

/*! \brief The max shared memory per block. */
int max_shared_memory_per_block{INT32_MAX};
/*! \brief The max register memory per block. */
int max_registers_per_block{INT32_MAX};
/*! \brief The max threads per block. */
int max_threads_per_block{INT32_MAX};
/*! \brief The max shared memory per block in bytes. */
int max_shared_memory_per_block;
/*! \brief The max number of registers per block. */
int max_registers_per_block;
/*! \brief The max number of threads per block. */
int max_threads_per_block;
/*! \brief The max vthread extent. */
int max_vthread_extent{INT32_MAX};
int max_vthread_extent;
/*! \brief The number of threads in a warp. */
int warp_size{INT32_MAX};
int warp_size;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("num_cores", &num_cores);
@@ -90,8 +89,15 @@ class HardwareParams : public ObjectRef {
* \param num_cores The number of cores.
* \param vector_unit_bytes The width of vector units in bytes.
* \param cache_line_bytes The size of cache line in bytes.
* \param max_shared_memory_per_block The max amount of shared memory per block for GPU.
* \param max_registers_per_block The max number of registers per block for GPU.
* \param max_threads_per_block The max number of threads per block for GPU.
* \param max_vthread_extent The max extent of vthread for GPU.
* \param warp_size The warp size for GPU.
*/
HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes);
HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes,
int max_shared_memory_per_block, int max_registers_per_block,
int max_threads_per_block, int max_vthread_extent, int warp_size);

TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode);
7 changes: 6 additions & 1 deletion python/tvm/auto_scheduler/__init__.py
@@ -44,7 +44,12 @@
LocalRPCMeasureContext,
)
from .measure_record import RecordToFile, RecordReader, load_best, load_records, save_records
from .relay_integration import extract_tasks, remove_index_check, rewrite_compute_body
from .relay_integration import (
extract_tasks,
remove_index_check,
rewrite_compute_body,
is_auto_scheduler_enabled,
)
from .search_task import SearchTask
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .task_scheduler import TaskScheduler
32 changes: 30 additions & 2 deletions python/tvm/auto_scheduler/auto_schedule.py
@@ -46,11 +46,39 @@ class HardwareParams(Object):
The width of vector units in bytes.
cache_line_bytes : int
The size of cache line in bytes.
max_shared_memory_per_block : int
The max shared memory per block in bytes.
max_registers_per_block : int
The max number of registers per block.
max_threads_per_block : int
The max number of threads per block.
max_vthread_extent : int
The max vthread extent.
warp_size : int
The number of threads in a warp.
"""

def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes):
def __init__(
self,
num_cores,
vector_unit_bytes,
cache_line_bytes,
max_shared_memory_per_block,
max_registers_per_block,
max_threads_per_block,
max_vthread_extent,
warp_size,
):
self.__init_handle_by_constructor__(
_ffi_api.HardwareParams, num_cores, vector_unit_bytes, cache_line_bytes
_ffi_api.HardwareParams,
num_cores,
vector_unit_bytes,
cache_line_bytes,
max_shared_memory_per_block,
max_registers_per_block,
max_threads_per_block,
max_vthread_extent,
warp_size,
)


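For illustration, a minimal sketch of calling the extended Python constructor (the argument names are from this diff; the numeric values are hypothetical placeholders, not defaults shipped by the patch):

    from tvm import auto_scheduler

    # Hypothetical GPU-like values; real ones should come from a device query,
    # as GetDefaultHardwareParams does in search_task.cc below.
    hw_params = auto_scheduler.HardwareParams(
        num_cores=-1,  # unused for GPU targets in this patch
        vector_unit_bytes=16,
        cache_line_bytes=64,
        max_shared_memory_per_block=48 * 1024,
        max_registers_per_block=64 * 1024,
        max_threads_per_block=1024,
        max_vthread_extent=8,
        warp_size=32,
    )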
12 changes: 12 additions & 0 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -28,6 +28,7 @@

import tvm
from tvm import autotvm, te, transform
from tvm.ir.transform import PassContext
from tvm.runtime import convert_to_object
from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import expr as _expr
@@ -342,3 +343,14 @@ def rewrite_compute_body(compute_tensor, new_layout):
num = op_node.num_outputs
outputs = tuple(op_node.output(i) for i in range(num))
return outputs[0] if num == 1 else outputs


def is_auto_scheduler_enabled():
    """Return whether the auto-scheduler is enabled.

    Returns
    -------
    enabled: bool
        Whether the auto-scheduler is enabled.
    """
    return PassContext.current().config.get("relay.backend.use_auto_scheduler", False)
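A usage sketch for the new helper (the config key is the one read above; the PassContext setup is illustrative):

    import tvm
    from tvm.auto_scheduler import is_auto_scheduler_enabled

    assert not is_auto_scheduler_enabled()  # key absent, defaults to False
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        assert is_auto_scheduler_enabled()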
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/search_policy.py
@@ -151,7 +151,7 @@ class SketchPolicy(SearchPolicy):
"sample_init_min_population": 50,
"sample_init_use_measured_ratio": 0.2,
"evolutionary_search_population": 2048,
"evolutionary_search_num_iters": 3,
"evolutionary_search_num_iters": 4,
"evolutionary_search_mutation_prob": 0.85,
"cpu_multi_level_tiling_structure": "SSRSRS",
"gpu_multi_level_tiling_structure": "SSSRRSRS",
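Callers that prefer the old iteration count can override the new default through the params argument; a sketch, assuming task is an existing auto_scheduler.SearchTask:

    from tvm import auto_scheduler

    policy = auto_scheduler.SketchPolicy(
        task,  # assumed: a previously constructed SearchTask
        params={"evolutionary_search_num_iters": 3},  # restore the old default
    )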
8 changes: 6 additions & 2 deletions python/tvm/auto_scheduler/task_scheduler.py
@@ -181,7 +181,7 @@ class TaskScheduler:
The parameter used for 'gradient' strategy
callbacks: Optional[List[TaskSchedulerCallback]]
The task scheduler callbacks that will be called before and after tuning a task.
If None, then PrintTableInfo callback will be used.
If None, the PrintTableInfo and LogEstimatedLatency callbacks will be used.
"""

def __init__(
@@ -214,7 +214,11 @@ def __init__(
self.beta = beta
self.gamma = gamma
self.backward_window_size = backward_window_size
self.callbacks = callbacks if callbacks is not None else [PrintTableInfo()]
self.callbacks = (
callbacks
if callbacks is not None
else [PrintTableInfo(), LogEstimatedLatency("total_latency.tsv")]
)

assert len(self.tasks) != 0, "No tasks"
assert self.strategy in ["round-robin", "gradient"]
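The new default is equivalent to passing both callbacks explicitly; a sketch, assuming tasks and task_weights come from auto_scheduler.extract_tasks:

    from tvm import auto_scheduler
    from tvm.auto_scheduler.task_scheduler import LogEstimatedLatency, PrintTableInfo

    tuner = auto_scheduler.TaskScheduler(
        tasks,
        task_weights,
        callbacks=[PrintTableInfo(), LogEstimatedLatency("total_latency.tsv")],
    )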
9 changes: 3 additions & 6 deletions python/tvm/relay/op/strategy/cuda.py
@@ -18,7 +18,7 @@
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
from tvm import topi
import tvm
from tvm.ir.transform import PassContext
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm.contrib import nvcc
from tvm._ffi import get_global_func
@@ -230,10 +230,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
)

# register auto-scheduler implementations
use_auto_scheduler = PassContext.current().config.get(
"relay.backend.use_auto_scheduler", False
)
if use_auto_scheduler and judge_winograd_auto_scheduler:
if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
naive_schedule, # this implementation should never be picked by autotvm
@@ -460,7 +457,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target):
name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda",
)

if PassContext.current().config.get("relay.backend.use_auto_scheduler", False):
if is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc_without_weight_transform),
naive_schedule, # this implementation should never be picked by autotvm
17 changes: 13 additions & 4 deletions python/tvm/relay/op/strategy/x86.py
@@ -20,6 +20,7 @@

import re
from tvm import topi
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm.relay.ty import is_dynamic
from .generic import *
@@ -117,14 +118,17 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
if not is_auto_scheduler_enabled():
logger.warning("conv2d NHWC layout is not optimized for x86 with autotvm.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc),
name="conv2d_nhwc.x86",
)
elif layout == "HWCN":
assert kernel_layout == "HWIO"
logger.warning("conv2d HWCN layout is not optimized for x86.")
if not is_auto_scheduler_enabled():
logger.warning("conv2d HWCN layout is not optimized for x86 with autotvm.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_hwcn),
wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn),
@@ -157,7 +161,10 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.")
if not is_auto_scheduler_enabled():
logger.warning(
"depthwise_conv2d NHWC layout is not optimized for x86 with autotvm."
)
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
@@ -168,15 +175,17 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
else: # group_conv2d
if layout == "NCHW":
assert kernel_layout == "OIHW"
logger.warning("group_conv2d is not optimized for x86.")
if not is_auto_scheduler_enabled():
logger.warning("group_conv2d is not optimized for x86 with autotvm.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw),
name="group_conv2d_nchw.generic",
)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
logger.warning("group_conv2d is not optimized for x86.")
if not is_auto_scheduler_enabled():
logger.warning("group_conv2d is not optimized for x86 with autotvm.")
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.group_conv2d_nhwc, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_nhwc),
51 changes: 29 additions & 22 deletions src/auto_scheduler/search_task.cc
@@ -35,54 +35,57 @@ namespace auto_scheduler {
TVM_REGISTER_NODE_TYPE(HardwareParamsNode);
TVM_REGISTER_NODE_TYPE(SearchTaskNode);

HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes) {
HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes,
int max_shared_memory_per_block, int max_registers_per_block,
int max_threads_per_block, int max_vthread_extent, int warp_size) {
auto node = make_object<HardwareParamsNode>();
node->num_cores = num_cores;
node->vector_unit_bytes = vector_unit_bytes;
node->cache_line_bytes = cache_line_bytes;
node->max_shared_memory_per_block = max_shared_memory_per_block;
node->max_registers_per_block = max_registers_per_block;
node->max_threads_per_block = max_threads_per_block;
node->max_vthread_extent = max_vthread_extent;
node->warp_size = warp_size;
data_ = std::move(node);
}

HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
const Target& target_host) {
if (target->kind->device_type == kDLCPU) {
return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64);
return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0);
} else if (target->kind->device_type == kDLGPU) {
auto hardware_params = HardwareParams(-1, 16, 64);
auto* p_hardware_params = hardware_params.CopyOnWrite();

auto ctx = TVMContext{kDLGPU, 0};
auto func = tvm::runtime::Registry::Get("device_api.gpu");
ICHECK(func != nullptr) << "Cannot find GPU device_api in registry";
auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());

tvm::runtime::TVMRetValue ret;
device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
p_hardware_params->max_shared_memory_per_block = ret;
int max_shared_memory_per_block = ret;

device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret);
p_hardware_params->max_registers_per_block = ret;
int max_registers_per_block = ret;

device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
p_hardware_params->max_threads_per_block = ret;
int max_threads_per_block = ret;

device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
p_hardware_params->warp_size = ret;

p_hardware_params->max_vthread_extent = p_hardware_params->warp_size / 4;
int warp_size = ret;

return hardware_params;
int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_registers_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
} else if (target->kind->device_type == kDLMetal) {
// Reference: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
// This setting looks working for Metal GPUs later than A10
auto hardware_params = HardwareParams(-1, 16, 64);
auto* p_hardware_params = hardware_params.CopyOnWrite();
p_hardware_params->max_shared_memory_per_block = 32 * 1024;
p_hardware_params->max_registers_per_block = 4 * 1024;
p_hardware_params->max_threads_per_block = 1024;
p_hardware_params->warp_size = 8;
p_hardware_params->max_vthread_extent = p_hardware_params->warp_size / 4;
return hardware_params;
int max_shared_memory_per_block = 32 * 1024;
int max_registers_per_block = 4 * 1024;
int max_threads_per_block = 1024;
int warp_size = 8;
int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_registers_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
} else {
LOG(FATAL) << "No default hardware parameters for target: " << target;
}
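The attributes queried here are also reachable from Python; a hedged sketch using TVMContext properties (illustrative, not part of this patch; the register count is omitted since it is queried through DeviceAPI in the C++ code):

    import tvm

    ctx = tvm.gpu(0)
    if ctx.exist:
        warp_size = ctx.warp_size
        max_shared_memory_per_block = ctx.max_shared_memory_per_block
        max_threads_per_block = ctx.max_threads_per_block
        max_vthread_extent = warp_size // 4  # same heuristic as this diff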
@@ -106,8 +109,12 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target,
}

TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")
.set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes) {
return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes);
.set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes,
int max_shared_memory_per_block, int max_registers_per_block,
int max_threads_per_block, int max_vthread_extent, int warp_size) {
return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes,
max_shared_memory_per_block, max_registers_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
});

TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
2 changes: 1 addition & 1 deletion tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -124,7 +124,7 @@ def test_stage_order():
dag,
json.dumps(("test-key",)),
tvm.target.Target("llvm"),
hardware_params=auto_scheduler.HardwareParams(100000, 16, 64),
hardware_params=auto_scheduler.HardwareParams(100000, 16, 64, 0, 0, 0, 0, 0),
)

task2 = pickle.loads(pickle.dumps(task))
4 changes: 3 additions & 1 deletion tests/python/unittest/test_auto_scheduler_feature.py
@@ -153,7 +153,9 @@ def test_gpu_feature():
inp.task.workload_key,
inp.task.target,
None,
auto_scheduler.HardwareParams(100000, 16, 64),
auto_scheduler.HardwareParams(
100000, 16, 64, 1 << 30, 1 << 30, 1 << 30, 1 << 30, 1 << 30
),
)

state = dag.infer_bound_from_state(inputs[0].state)
