Skip to content

Commit

Permalink
[AutoScheduler] Remove max_registers_per_block in HardwareParams (a…
Browse files Browse the repository at this point in the history
…pache#7040)

* [AutoScheduler] Fix hardware params

* address comments
  • Loading branch information
merrymercy authored Dec 5, 2020
1 parent 3d9ae3e commit 878a0a9
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 21 deletions.
10 changes: 5 additions & 5 deletions include/tvm/auto_scheduler/search_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class HardwareParamsNode : public Object {
// GPU related parameters got from device query API
/*! \brief The max shared memory per block in bytes. */
int max_shared_memory_per_block;
/*! \brief The max number of register per block. */
int max_registers_per_block;
/*! \brief The max local memory per block in bytes. */
int max_local_memory_per_block;
/*! \brief The max number of threads per block. */
int max_threads_per_block;
/*! \brief The max vthread extent. */
Expand All @@ -60,7 +60,7 @@ class HardwareParamsNode : public Object {
v->Visit("vector_unit_bytes", &vector_unit_bytes);
v->Visit("cache_line_bytes", &cache_line_bytes);
v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block);
v->Visit("max_registers_per_block", &max_registers_per_block);
v->Visit("max_local_memory_per_block", &max_local_memory_per_block);
v->Visit("max_threads_per_block", &max_threads_per_block);
v->Visit("max_vthread_extent", &max_vthread_extent);
v->Visit("warp_size", &warp_size);
Expand Down Expand Up @@ -90,13 +90,13 @@ class HardwareParams : public ObjectRef {
* \param vector_unit_bytes The width of vector units in bytes.
* \param cache_line_bytes The size of cache line in bytes.
* \param max_shared_memory_per_block The max amount of shared memory per block for GPU.
* \param max_registers_per_block The max number of registers per block for GPU.
* \param max_local_memory_per_block The max amount of local memory per block for GPU.
* \param max_threads_per_block The max number of threads per block for GPU.
* \param max_vthread_extent The max extent of vthread for GPU.
* \param warp_size The warp size for GPU
*/
HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes,
int max_shared_memory_per_block, int max_registers_per_block,
int max_shared_memory_per_block, int max_local_memory_per_block,
int max_threads_per_block, int max_vthread_extent, int warp_size);

TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode);
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class HardwareParams(Object):
The size of cache line in bytes.
max_shared_memory_per_block : int
The max shared memory per block in bytes.
max_registers_per_block : int
The max number of register per block.
max_local_memory_per_block : int
The max local memory per block in bytes.
max_threads_per_block : int
The max number of threads per block.
max_vthread_extent : int
Expand All @@ -65,7 +65,7 @@ def __init__(
vector_unit_bytes,
cache_line_bytes,
max_shared_memory_per_block,
max_registers_per_block,
max_local_memory_per_block,
max_threads_per_block,
max_vthread_extent,
warp_size,
Expand All @@ -76,7 +76,7 @@ def __init__(
vector_unit_bytes,
cache_line_bytes,
max_shared_memory_per_block,
max_registers_per_block,
max_local_memory_per_block,
max_threads_per_block,
max_vthread_extent,
warp_size,
Expand Down
2 changes: 1 addition & 1 deletion src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
pass_list.push_back(tir::transform::Simplify());
tvm::Map<String, tvm::PrimExpr> gpu_params{
{"max_shared_memory_per_block", task->hardware_params->max_shared_memory_per_block},
{"max_local_memory_per_block", task->hardware_params->max_registers_per_block},
{"max_local_memory_per_block", task->hardware_params->max_local_memory_per_block},
{"max_threads_per_block", task->hardware_params->max_threads_per_block},
{"max_vector_bytes", task->hardware_params->vector_unit_bytes},
{"max_vthread", task->hardware_params->max_vthread_extent},
Expand Down
4 changes: 2 additions & 2 deletions src/auto_scheduler/measure_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ struct Handler<::tvm::auto_scheduler::HardwareParamsNode> {
writer->WriteArrayItem(data.vector_unit_bytes);
writer->WriteArrayItem(data.cache_line_bytes);
writer->WriteArrayItem(data.max_shared_memory_per_block);
writer->WriteArrayItem(data.max_registers_per_block);
writer->WriteArrayItem(data.max_local_memory_per_block);
writer->WriteArrayItem(data.max_threads_per_block);
writer->WriteArrayItem(data.max_vthread_extent);
writer->WriteArrayItem(data.warp_size);
Expand All @@ -140,7 +140,7 @@ struct Handler<::tvm::auto_scheduler::HardwareParamsNode> {
reader->Read(&data->max_shared_memory_per_block);
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&data->max_registers_per_block);
reader->Read(&data->max_local_memory_per_block);
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&data->max_threads_per_block);
Expand Down
19 changes: 10 additions & 9 deletions src/auto_scheduler/search_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ TVM_REGISTER_NODE_TYPE(HardwareParamsNode);
TVM_REGISTER_NODE_TYPE(SearchTaskNode);

HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes,
int max_shared_memory_per_block, int max_registers_per_block,
int max_shared_memory_per_block, int max_local_memory_per_block,
int max_threads_per_block, int max_vthread_extent, int warp_size) {
auto node = make_object<HardwareParamsNode>();
node->num_cores = num_cores;
node->vector_unit_bytes = vector_unit_bytes;
node->cache_line_bytes = cache_line_bytes;
node->max_shared_memory_per_block = max_shared_memory_per_block;
node->max_registers_per_block = max_registers_per_block;
node->max_local_memory_per_block = max_local_memory_per_block;
node->max_threads_per_block = max_threads_per_block;
node->max_vthread_extent = max_vthread_extent;
node->warp_size = warp_size;
Expand All @@ -64,8 +64,9 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
int max_shared_memory_per_block = ret;

device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret);
int max_registers_per_block = ret;
// There is no explicit local memory limition in CUDA runtime,
// so we can use INT32_MAX to disalbe the check on local_memory.
int max_local_memory_per_block = INT32_MAX;

device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
int max_threads_per_block = ret;
Expand All @@ -74,17 +75,17 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target
int warp_size = ret;

int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_registers_per_block,
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
} else if (target->kind->device_type == kDLMetal) {
// Reference: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
// This setting looks working for Metal GPUs later than A10
int max_shared_memory_per_block = 32 * 1024;
int max_registers_per_block = 4 * 1024;
int max_local_memory_per_block = INT32_MAX; // skip the check on local memory
int max_threads_per_block = 1024;
int warp_size = 8;
int max_vthread_extent = warp_size / 4;
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_registers_per_block,
return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
} else {
LOG(FATAL) << "No default hardware parameters for target: " << target;
Expand All @@ -110,10 +111,10 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe

TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")
.set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes,
int max_shared_memory_per_block, int max_registers_per_block,
int max_shared_memory_per_block, int max_local_memory_per_block,
int max_threads_per_block, int max_vthread_extent, int warp_size) {
return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes,
max_shared_memory_per_block, max_registers_per_block,
max_shared_memory_per_block, max_local_memory_per_block,
max_threads_per_block, max_vthread_extent, warp_size);
});

Expand Down

0 comments on commit 878a0a9

Please sign in to comment.