Skip to content

Commit

Permalink
[USMP] Add performance characteristics to PoolInfo (apache#10005)
Browse files Browse the repository at this point in the history
* [USMP] Add performance characteristics to PoolInfo

Scheduling algorithms that wish to optimize around
memory pools require further information about the
perfomance characteristics of those pools. This
commit adds clock frequency, bandwidth, latency and
burst length as optional fields to PoolInfo.

Change-Id: I4cf3f35324d093fb38e874f0f2e587cb84d4ba1e

* Remove unused import

Change-Id: I1e2ef885425f4361b80c2bab9261ec129e61a756
  • Loading branch information
mbaret authored and ylc committed Feb 16, 2022
1 parent 35464cc commit 1c8a407
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 39 deletions.
88 changes: 63 additions & 25 deletions include/tvm/tir/usmp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,70 +44,108 @@ constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
namespace tir {
namespace usmp {

/*!
* \brief The string parameter to indicate read and write access to a pool
* This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
* python/tvm/tir/usmp/utils.py
*/
static constexpr const char* kTargetPoolReadWriteAccess = "rw";
/*!
* \brief The string parameter to indicate read only access to a pool
* This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
* python/tvm/tir/usmp/utils.py
*/
static constexpr const char* kTargetPoolReadOnlyAccess = "ro";

/*!
* \brief Describes a pool of memory accessible by one or more targets.
*/
struct PoolInfoNode : public Object {
/*! \brief The name of the memory pool */
String pool_name;
/*! \brief The expected size hint to be used by the allocator.
* The size_hint_bytes is defaulted to kUnrestrictedPoolSizeHint
* The size_hint_bytes is set to kUnrestrictedPoolSizeHint
* to indicate the pool is not size restricted.
*/
Integer size_hint_bytes;
/*! \brief The accessibility from each Target*/
/*! \brief The accessibility from each Target */
Map<Target, String> target_access; // 'rw' or 'ro'
/*! \brief The clock frequency of the memory in Hz */
Integer clock_frequency_hz;
/*! \brief The read bandwidth in bytes/cycle */
Integer read_bandwidth_bytes_per_cycle;
/*! \brief The write bandwidth in bytes/cycle */
Integer write_bandwidth_bytes_per_cycle;
/*! \brief The read latency in cycles */
Integer read_latency_cycles;
/*! \brief The write latency in cycles */
Integer write_latency_cycles;
/*! \brief The burst length in bytes for each Target */
Map<Target, Integer> target_burst_bytes;
/*! \brief Whether pool is internally generated.
* The internal pools will be generated as part of
* the entry point code generation of the executor*/
* the entry point code generation of the executor
*/
bool is_internal = false;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pool_name", &pool_name);
v->Visit("size_hint_bytes", &size_hint_bytes);
v->Visit("target_access", &target_access);
v->Visit("clock_frequency_hz", &clock_frequency_hz);
v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle);
v->Visit("read_latency_cycles", &read_latency_cycles);
v->Visit("write_latency_cycles", &write_latency_cycles);
v->Visit("target_burst_bytes", &target_burst_bytes);
v->Visit("is_internal", &is_internal);
}

bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const {
return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) &&
equal(target_access, other->target_access) && equal(is_internal, other->is_internal);
equal(target_access, other->target_access) &&
equal(target_access, other->target_access) &&
equal(clock_frequency_hz, other->clock_frequency_hz) &&
equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) &&
equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) &&
equal(read_latency_cycles, other->read_latency_cycles) &&
equal(write_latency_cycles, other->write_latency_cycles) &&
equal(target_burst_bytes, other->target_burst_bytes) &&
equal(is_internal, other->is_internal);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(pool_name);
hash_reduce(size_hint_bytes);
hash_reduce(target_access);
hash_reduce(clock_frequency_hz);
hash_reduce(read_bandwidth_bytes_per_cycle);
hash_reduce(write_bandwidth_bytes_per_cycle);
hash_reduce(read_latency_cycles);
hash_reduce(write_latency_cycles);
hash_reduce(target_burst_bytes);
hash_reduce(is_internal);
}

static constexpr const char* _type_key = "tir.usmp.PoolInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
};

/*!
* \brief The PoolSize is unrestricted for the memory planner
*/
static const int kUnrestrictedPoolSizeHint = -1;

class PoolInfo : public ObjectRef {
public:
TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access,
Integer size_hint_bytes = kUnrestrictedPoolSizeHint,
Bool is_internal = Bool(false));
/*!
* \brief The string parameter to indicate read and write access to a pool
* This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
* python/tvm/tir/usmp/utils.py
*/
static constexpr const char* kTargetPoolReadWriteAccess = "rw";
/*!
* \brief The string parameter to indicate read only access to a pool
* This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
* python/tvm/tir/usmp/utils.py
*/
static constexpr const char* kTargetPoolReadOnlyAccess = "ro";
/*! \brief The PoolSize is unrestricted for the memory planner */
static const int kUnrestrictedPoolSizeHint = -1;
/*! \brief The clock frequency is not known */
static const int kUnknownClockFrequency = -1;
/*! \brief The read bandwidth is not known */
static const int kUnknownReadBandwidth = -1;
/*! \brief The write bandwidth is not known */
static const int kUnknownWriteBandwidth = -1;

TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes,
Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle,
Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles,
Integer write_latency_cycles, Map<Target, Integer> target_burst_bytes,
Bool is_internal);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode);
};

Expand Down
47 changes: 45 additions & 2 deletions python/tvm/tir/usmp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""USMP Utilities and Data Structures"""
# pylint: disable=invalid-name

from typing import Dict, Optional, List
from typing import Dict, Optional, List, Union

from tvm._ffi import register_object
from tvm.runtime import Object
Expand Down Expand Up @@ -52,6 +52,34 @@ class PoolInfo(Object):
The default value would be -1 which means the pool
is not size restricted.
clock_frequency_hz : Optional[int]
The clock frequency that the memory pool runs at in Hz.
If not specified/known, this will default to -1 indicating
it hasn't been defined.
read_bandwidth_bytes_per_cycle : Optional[int]
The read bandwidth of the memory pool in bytes/cycle.
If not specified/known, this will default to -1 indicating
it hasn't been defined.
write_bandwidth_bytes_per_cycle : Optional[int]
The write bandwidth of the memory pool in bytes/cycle.
If not specified/known, this will default to -1 indicating
it hasn't been defined.
read_latency_cycles : Optional[int]
The read latency of the memory pool in cycles.
If not specified/known, this will default to 0.
write_latency_cycles : Optional[int]
The write latency of the memory pool in cycles.
If not specified/known, this will default to 0.
target_burst_bytes : Optional[Union[Dict[Target, int], None]]
The burst length of the memory pool in bytes per target.
If not specified/known for a given target, a burst length
of 1 byte will be assumed.
"""

# The string parameter to indicate read and write access to a pool
Expand All @@ -67,13 +95,28 @@ def __init__(
self,
pool_name: str,
target_access: Dict[Target, str],
size_hint_bytes: Optional[int] = None,
size_hint_bytes: Optional[int] = -1,
clock_frequency_hz: Optional[int] = -1,
read_bandwidth_bytes_per_cycle: Optional[int] = -1,
write_bandwidth_bytes_per_cycle: Optional[int] = -1,
read_latency_cycles: Optional[int] = 0,
write_latency_cycles: Optional[int] = 0,
target_burst_bytes: Optional[Union[Dict[Target, int], None]] = None,
):
if not target_burst_bytes:
target_burst_bytes = dict()

self.__init_handle_by_constructor__(
_ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member
pool_name,
target_access,
size_hint_bytes,
clock_frequency_hz,
read_bandwidth_bytes_per_cycle,
write_bandwidth_bytes_per_cycle,
read_latency_cycles,
write_latency_cycles,
target_burst_bytes,
)


Expand Down
2 changes: 1 addition & 1 deletion src/tir/usmp/algo/greedy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ size_t GreedyBase::round_up_to_byte_alignment(const size_t& non_aligned_byte_off
*/
bool GreedyBase::IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
const size_t& size_bytes) {
if (candidate_pool->size_hint_bytes == -1) {
if (candidate_pool->size_hint_bytes == PoolInfo::kUnrestrictedPoolSizeHint) {
// this means pool is not bounded
return true;
}
Expand Down
8 changes: 5 additions & 3 deletions src/tir/usmp/transform/assign_pool_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ class PoolInfoAssigner : public StmtExprMutator {
ICHECK(target_host) << "main function does not have a target attr";
Array<usmp::PoolInfo> pool_infos =
module->GetAttr<Array<usmp::PoolInfo>>(tvm::attr::kPoolInfoIRModuleAttr)
.value_or({usmp::PoolInfo("global_workspace",
{{target_host.value(), usmp::kTargetPoolReadWriteAccess}},
usmp::kUnrestrictedPoolSizeHint, Bool(true))});
.value_or({usmp::PoolInfo(
"global_workspace", {{target_host.value(), PoolInfo::kTargetPoolReadWriteAccess}},
PoolInfo::kUnrestrictedPoolSizeHint, PoolInfo::kUnknownClockFrequency,
PoolInfo::kUnknownReadBandwidth, PoolInfo::kUnknownWriteBandwidth, 0, 0,
{{target_host.value(), 1}}, Bool(true))});
for (const usmp::PoolInfo& pool_info : pool_infos) {
for (const auto& kv : pool_info->target_access) {
Target tgt = kv.first;
Expand Down
32 changes: 24 additions & 8 deletions src/tir/usmp/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,31 +93,47 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

PoolInfo::PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes,
Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle,
Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles,
Integer write_latency_cycles, Map<Target, Integer> target_burst_bytes,
Bool is_internal) {
auto poolinfo_node = make_object<PoolInfoNode>();
poolinfo_node->pool_name = pool_name;
poolinfo_node->size_hint_bytes = size_hint_bytes;
poolinfo_node->target_access = target_access;
poolinfo_node->clock_frequency_hz = clock_frequency_hz;
poolinfo_node->read_bandwidth_bytes_per_cycle = read_bandwidth_bytes_per_cycle;
poolinfo_node->write_bandwidth_bytes_per_cycle = write_bandwidth_bytes_per_cycle;
poolinfo_node->read_latency_cycles = read_latency_cycles;
poolinfo_node->write_latency_cycles = write_latency_cycles;
poolinfo_node->target_burst_bytes = target_burst_bytes;
poolinfo_node->is_internal = is_internal;
data_ = std::move(poolinfo_node);
}

TVM_REGISTER_NODE_TYPE(PoolInfoNode);
TVM_REGISTER_GLOBAL("tir.usmp.PoolInfo")
.set_body_typed([](String pool_name, Map<Target, String> target_access,
Integer size_hint_bytes) {
if (size_hint_bytes.defined()) {
return PoolInfo(pool_name, target_access, size_hint_bytes);
}
return PoolInfo(pool_name, target_access);
.set_body_typed([](String pool_name, Map<Target, String> target_access, Integer size_hint_bytes,
Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle,
Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles,
Integer write_latency_cycles, Map<Target, Integer> target_burst_bytes) {
return PoolInfo(pool_name, target_access, size_hint_bytes, clock_frequency_hz,
read_bandwidth_bytes_per_cycle, write_bandwidth_bytes_per_cycle,
read_latency_cycles, write_latency_cycles, target_burst_bytes, Bool(false));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PoolInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PoolInfoNode*>(ref.get());
p->stream << "PoolInfoNode(\n"
<< "pool_name=" << node->pool_name << ",\n target_access=" << node->target_access
<< ",\n size_hint_bytes=" << node->size_hint_bytes << ")";
<< " pool_name=" << node->pool_name << ",\n target_access=" << node->target_access
<< ",\n size_hint_bytes=" << node->size_hint_bytes
<< ",\n clock_frequency_hz=" << node->clock_frequency_hz
<< ",\n read_bandwidth_bytes_per_cycle=" << node->read_bandwidth_bytes_per_cycle
<< ",\n write_bandwidth_bytes_per_cycle=" << node->write_bandwidth_bytes_per_cycle
<< ",\n read_latency_cycles=" << node->read_latency_cycles
<< ",\n write_latency_cycles=" << node->write_latency_cycles
<< ",\n target_burst_bytes=" << node->target_burst_bytes << ")";
});

PoolAllocation::PoolAllocation(PoolInfo pool_info, Integer byte_offset) {
Expand Down

0 comments on commit 1c8a407

Please sign in to comment.