diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index 582399865d6f..9ebe7f29b1f4 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -44,19 +44,6 @@ constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
 namespace tir {
 namespace usmp {
-/*!
- * \brief The string parameter to indicate read and write access to a pool
- * This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
- * python/tvm/tir/usmp/utils.py
- */
-static constexpr const char* kTargetPoolReadWriteAccess = "rw";
-/*!
- * \brief The string parameter to indicate read only access to a pool
- * This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
- * python/tvm/tir/usmp/utils.py
- */
-static constexpr const char* kTargetPoolReadOnlyAccess = "ro";
-
 /*!
  * \brief Describes a pool of memory accessible by one or more targets.
  */
@@ -64,33 +51,66 @@ struct PoolInfoNode : public Object {
   /*! \brief The name of the memory pool */
   String pool_name;
   /*! \brief The expected size hint to be used by the allocator.
-   * The size_hint_bytes is defaulted to kUnrestrictedPoolSizeHint
+   * The size_hint_bytes is set to kUnrestrictedPoolSizeHint
    * to indicate the pool is not size restricted.
    */
   Integer size_hint_bytes;
-  /*! \brief The accessibility from each Target*/
+  /*! \brief The accessibility from each Target */
   Map<Target, String> target_access;  // 'rw' or 'ro'
+  /*! \brief The clock frequency of the memory in Hz */
+  Integer clock_frequency_hz;
+  /*! \brief The read bandwidth in bytes/cycle */
+  Integer read_bandwidth_bytes_per_cycle;
+  /*! \brief The write bandwidth in bytes/cycle */
+  Integer write_bandwidth_bytes_per_cycle;
+  /*! \brief The read latency in cycles */
+  Integer read_latency_cycles;
+  /*! \brief The write latency in cycles */
+  Integer write_latency_cycles;
+  /*! \brief The burst length in bytes for each Target */
+  Map<Target, Integer> target_burst_bytes;
   /*! \brief Whether pool is internally generated.
    * The internal pools will be generated as part of
-   * the entry point code generation of the executor*/
+   * the entry point code generation of the executor
+   */
   bool is_internal = false;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("pool_name", &pool_name);
     v->Visit("size_hint_bytes", &size_hint_bytes);
     v->Visit("target_access", &target_access);
+    v->Visit("clock_frequency_hz", &clock_frequency_hz);
+    v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
+    v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle);
+    v->Visit("read_latency_cycles", &read_latency_cycles);
+    v->Visit("write_latency_cycles", &write_latency_cycles);
+    v->Visit("target_burst_bytes", &target_burst_bytes);
     v->Visit("is_internal", &is_internal);
   }
 
   bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const {
     return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) &&
-           equal(target_access, other->target_access) && equal(is_internal, other->is_internal);
+           equal(target_access, other->target_access) &&
+           equal(clock_frequency_hz, other->clock_frequency_hz) &&
+           equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) &&
+           equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) &&
+           equal(read_latency_cycles, other->read_latency_cycles) &&
+           equal(write_latency_cycles, other->write_latency_cycles) &&
+           equal(target_burst_bytes, other->target_burst_bytes) &&
+           equal(is_internal, other->is_internal);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce(pool_name);
     hash_reduce(size_hint_bytes);
     hash_reduce(target_access);
+    hash_reduce(clock_frequency_hz);
+    hash_reduce(read_bandwidth_bytes_per_cycle);
+    hash_reduce(write_bandwidth_bytes_per_cycle);
+    hash_reduce(read_latency_cycles);
+    hash_reduce(write_latency_cycles);
+    hash_reduce(target_burst_bytes);
     hash_reduce(is_internal);
   }
@@ -98,16 +118,34 @@ struct PoolInfoNode : public Object {
   TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
 };
 
-/*!
- * \brief The PoolSize is unrestricted for the memory planner
- */
-static const int kUnrestrictedPoolSizeHint = -1;
-
 class PoolInfo : public ObjectRef {
  public:
-  TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access,
-                   Integer size_hint_bytes = kUnrestrictedPoolSizeHint,
-                   Bool is_internal = Bool(false));
+  /*!
+   * \brief The string parameter to indicate read and write access to a pool
+   * This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
+   * python/tvm/tir/usmp/utils.py
+   */
+  static constexpr const char* kTargetPoolReadWriteAccess = "rw";
+  /*!
+   * \brief The string parameter to indicate read only access to a pool
+   * This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
+   * python/tvm/tir/usmp/utils.py
+   */
+  static constexpr const char* kTargetPoolReadOnlyAccess = "ro";
+  /*! \brief The PoolSize is unrestricted for the memory planner */
+  static const int kUnrestrictedPoolSizeHint = -1;
+  /*! \brief The clock frequency is not known */
+  static const int kUnknownClockFrequency = -1;
+  /*! \brief The read bandwidth is not known */
+  static const int kUnknownReadBandwidth = -1;
+  /*! \brief The write bandwidth is not known */
+  static const int kUnknownWriteBandwidth = -1;
+
+  TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes,
+                   Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle,
+                   Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles,
+                   Integer write_latency_cycles, Map<Target, Integer> target_burst_bytes,
+                   Bool is_internal);
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode);
 };
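The header half of this change gives every PoolInfo the performance attributes a memory-aware planner can consume: clock frequency, read/write bandwidth, read/write latency, and a per-Target burst length, plus the kUnknown* sentinels and the access-string constants that must stay in sync with the Python class. The sketch below is illustrative only and shows how a pool might be declared through the mirrored Python constructor updated in the next file; the pool name, the "c" target kind, and every numeric value are assumptions rather than values taken from this patch.

# Hedged sketch (not part of this patch): declaring a pool with the new
# performance attributes via the Python class updated below.  The pool name,
# the "c" target kind and all numeric values are illustrative assumptions.
from tvm.target import Target
from tvm.tir.usmp.utils import PoolInfo

host = Target("c")
fast_sram = PoolInfo(
    pool_name="fast_sram",                   # hypothetical pool name
    target_access={host: PoolInfo.READ_WRITE_ACCESS},
    size_hint_bytes=256 * 1024,              # cap the pool at 256 KiB
    clock_frequency_hz=500_000_000,          # 500 MHz memory clock
    read_bandwidth_bytes_per_cycle=8,
    write_bandwidth_bytes_per_cycle=4,
    read_latency_cycles=1,
    write_latency_cycles=2,
    target_burst_bytes={host: 16},           # 16-byte bursts for this target
)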
+ """ # The string parameter to indicate read and write access to a pool @@ -67,13 +95,28 @@ def __init__( self, pool_name: str, target_access: Dict[Target, str], - size_hint_bytes: Optional[int] = None, + size_hint_bytes: Optional[int] = -1, + clock_frequency_hz: Optional[int] = -1, + read_bandwidth_bytes_per_cycle: Optional[int] = -1, + write_bandwidth_bytes_per_cycle: Optional[int] = -1, + read_latency_cycles: Optional[int] = 0, + write_latency_cycles: Optional[int] = 0, + target_burst_bytes: Optional[Union[Dict[Target, int], None]] = None, ): + if not target_burst_bytes: + target_burst_bytes = dict() + self.__init_handle_by_constructor__( _ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member pool_name, target_access, size_hint_bytes, + clock_frequency_hz, + read_bandwidth_bytes_per_cycle, + write_bandwidth_bytes_per_cycle, + read_latency_cycles, + write_latency_cycles, + target_burst_bytes, ) diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc index 324474c569d4..c4f7cabb99f1 100644 --- a/src/tir/usmp/algo/greedy.cc +++ b/src/tir/usmp/algo/greedy.cc @@ -61,7 +61,7 @@ size_t GreedyBase::round_up_to_byte_alignment(const size_t& non_aligned_byte_off */ bool GreedyBase::IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, const size_t& size_bytes) { - if (candidate_pool->size_hint_bytes == -1) { + if (candidate_pool->size_hint_bytes == PoolInfo::kUnrestrictedPoolSizeHint) { // this means pool is not bounded return true; } diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc index 516ddd1a241b..009083373690 100644 --- a/src/tir/usmp/transform/assign_pool_info.cc +++ b/src/tir/usmp/transform/assign_pool_info.cc @@ -48,9 +48,11 @@ class PoolInfoAssigner : public StmtExprMutator { ICHECK(target_host) << "main function does not have a target attr"; Array pool_infos = module->GetAttr>(tvm::attr::kPoolInfoIRModuleAttr) - .value_or({usmp::PoolInfo("global_workspace", - {{target_host.value(), usmp::kTargetPoolReadWriteAccess}}, - usmp::kUnrestrictedPoolSizeHint, Bool(true))}); + .value_or({usmp::PoolInfo( + "global_workspace", {{target_host.value(), PoolInfo::kTargetPoolReadWriteAccess}}, + PoolInfo::kUnrestrictedPoolSizeHint, PoolInfo::kUnknownClockFrequency, + PoolInfo::kUnknownReadBandwidth, PoolInfo::kUnknownWriteBandwidth, 0, 0, + {{target_host.value(), 1}}, Bool(true))}); for (const usmp::PoolInfo& pool_info : pool_infos) { for (const auto& kv : pool_info->target_access) { Target tgt = kv.first; diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index 1fff70f5892e..b7c1a5f59f24 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -93,31 +93,47 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); PoolInfo::PoolInfo(String pool_name, Map target_access, Integer size_hint_bytes, + Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle, + Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles, + Integer write_latency_cycles, Map target_burst_bytes, Bool is_internal) { auto poolinfo_node = make_object(); poolinfo_node->pool_name = pool_name; poolinfo_node->size_hint_bytes = size_hint_bytes; poolinfo_node->target_access = target_access; + poolinfo_node->clock_frequency_hz = clock_frequency_hz; + poolinfo_node->read_bandwidth_bytes_per_cycle = read_bandwidth_bytes_per_cycle; + poolinfo_node->write_bandwidth_bytes_per_cycle = write_bandwidth_bytes_per_cycle; + poolinfo_node->read_latency_cycles = read_latency_cycles; + poolinfo_node->write_latency_cycles = 
diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc
index 1fff70f5892e..b7c1a5f59f24 100644
--- a/src/tir/usmp/utils.cc
+++ b/src/tir/usmp/utils.cc
@@ -93,31 +93,47 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     });
 
 PoolInfo::PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes,
+                   Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle,
+                   Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles,
+                   Integer write_latency_cycles, Map<Target, Integer> target_burst_bytes,
                    Bool is_internal) {
   auto poolinfo_node = make_object<PoolInfoNode>();
   poolinfo_node->pool_name = pool_name;
   poolinfo_node->size_hint_bytes = size_hint_bytes;
   poolinfo_node->target_access = target_access;
+  poolinfo_node->clock_frequency_hz = clock_frequency_hz;
+  poolinfo_node->read_bandwidth_bytes_per_cycle = read_bandwidth_bytes_per_cycle;
+  poolinfo_node->write_bandwidth_bytes_per_cycle = write_bandwidth_bytes_per_cycle;
+  poolinfo_node->read_latency_cycles = read_latency_cycles;
+  poolinfo_node->write_latency_cycles = write_latency_cycles;
+  poolinfo_node->target_burst_bytes = target_burst_bytes;
   poolinfo_node->is_internal = is_internal;
   data_ = std::move(poolinfo_node);
 }
 
 TVM_REGISTER_NODE_TYPE(PoolInfoNode);
 
 TVM_REGISTER_GLOBAL("tir.usmp.PoolInfo")
-    .set_body_typed([](String pool_name, Map<Target, String> target_access,
-                       Integer size_hint_bytes) {
-      if (size_hint_bytes.defined()) {
-        return PoolInfo(pool_name, target_access, size_hint_bytes);
-      }
-      return PoolInfo(pool_name, target_access);
+    .set_body_typed([](String pool_name, Map<Target, String> target_access, Integer size_hint_bytes,
+                       Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle,
+                       Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles,
+                       Integer write_latency_cycles, Map<Target, Integer> target_burst_bytes) {
+      return PoolInfo(pool_name, target_access, size_hint_bytes, clock_frequency_hz,
+                      read_bandwidth_bytes_per_cycle, write_bandwidth_bytes_per_cycle,
+                      read_latency_cycles, write_latency_cycles, target_burst_bytes, Bool(false));
     });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<PoolInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
       auto* node = static_cast<const PoolInfoNode*>(ref.get());
       p->stream << "PoolInfoNode(\n"
-                << "pool_name=" << node->pool_name << ",\n target_access=" << node->target_access
-                << ",\n size_hint_bytes=" << node->size_hint_bytes << ")";
+                << " pool_name=" << node->pool_name << ",\n target_access=" << node->target_access
+                << ",\n size_hint_bytes=" << node->size_hint_bytes
+                << ",\n clock_frequency_hz=" << node->clock_frequency_hz
+                << ",\n read_bandwidth_bytes_per_cycle=" << node->read_bandwidth_bytes_per_cycle
+                << ",\n write_bandwidth_bytes_per_cycle=" << node->write_bandwidth_bytes_per_cycle
+                << ",\n read_latency_cycles=" << node->read_latency_cycles
+                << ",\n write_latency_cycles=" << node->write_latency_cycles
+                << ",\n target_burst_bytes=" << node->target_burst_bytes << ")";
     });
 
 PoolAllocation::PoolAllocation(PoolInfo pool_info, Integer byte_offset) {
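Because the registered global now takes all nine construction arguments, every caller of "tir.usmp.PoolInfo" has to supply them; this is why PoolInfo.__init__ forwards the new parameters and substitutes an empty dict when target_burst_bytes is None. The sketch below exercises the packed function directly, mainly as a testing aid; the argument values and the "c" target kind are assumptions.

# Hedged sketch (illustrative only): calling the re-registered packed function
# directly, the same way PoolInfo.__init__ does through _ffi_api.  All nine
# arguments are required; the values and the "c" target kind are assumptions.
import tvm
from tvm.target import Target

make_pool_info = tvm.get_global_func("tir.usmp.PoolInfo")
host = Target("c")
pool = make_pool_info(
    "global_workspace",   # pool_name
    {host: "rw"},         # target_access ("rw" matches kTargetPoolReadWriteAccess)
    -1,                   # size_hint_bytes: unrestricted
    -1, -1, -1,           # clock frequency and read/write bandwidth unknown
    0, 0,                 # read/write latency in cycles
    {host: 1},            # target_burst_bytes
)
print(pool)  # the updated ReprPrinter lists every new field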