diff --git a/apps/microtvm/zephyr_cmsisnn/src/main.c b/apps/microtvm/zephyr_cmsisnn/src/main.c
index bb38a7791fb4..274bd63d3ea5 100644
--- a/apps/microtvm/zephyr_cmsisnn/src/main.c
+++ b/apps/microtvm/zephyr_cmsisnn/src/main.c
@@ -34,7 +34,7 @@ extern float output_storage[12];
 
 extern const size_t output_len;
 
-static uint8_t g_crt_workspace[TVMGEN_DEFAULT_WORKSPACE_SIZE + 256];
+static uint8_t g_crt_workspace[TVMGEN_DEFAULT_WORKSPACE_SIZE + 512];
 tvm_workspace_t app_workspace;
 
 void TVMLogf(const char* msg, ...) {
diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index aa4f27de6d8b..d9f212f489f0 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -185,24 +185,24 @@ struct AllocatedPoolInfoNode : public Object {
   PoolInfo pool_info;
   /*! \brief The allocated size into this pool */
   Integer allocated_size;
-  /*! \brief An optional associated pool Var */
-  Optional<Var> pool_var;
+  /*! \brief An optional index of the associated pool Var in the PrimFunc params */
+  Optional<Integer> pool_var_idx;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("pool_info", &pool_info);
     v->Visit("allocated_size", &allocated_size);
-    v->Visit("pool_var", &pool_var);
+    v->Visit("pool_var_idx", &pool_var_idx);
   }
 
   bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const {
     return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) &&
-           equal(pool_var, other->pool_var);
+           equal(pool_var_idx, other->pool_var_idx);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce(pool_info);
     hash_reduce(allocated_size);
-    hash_reduce(pool_var);
+    hash_reduce(pool_var_idx);
   }
 
   static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo";
@@ -211,7 +211,8 @@ struct AllocatedPoolInfoNode : public Object {
 
 class AllocatedPoolInfo : public ObjectRef {
  public:
-  TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var = Var());
+  TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size,
+                            Integer pool_var_idx = Integer());
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode);
 };
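# A minimal sketch (not part of the patch) of what the pool_var -> pool_var_idx
# switch above implies for consumers: instead of carrying a tir.Var whose object
# identity must survive module transformations, AllocatedPoolInfo now records the
# index of the pool parameter, and the Var is recovered from the PrimFunc itself.
# "main_func" is an assumed tir.PrimFunc carrying the "pool_args" attribute.
def pool_vars_from_indices(main_func):
    pool_vars = {}
    if "pool_args" in main_func.attrs.keys():
        for allocated_pool_info in main_func.attrs["pool_args"]:
            idx = int(allocated_pool_info.pool_var_idx)
            # The param index is stable across serialization, the Var is not.
            pool_vars[idx] = main_func.params[idx]
    return pool_vars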
diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py
index b876058b6b0c..d53c4ed49939 100644
--- a/python/tvm/micro/model_library_format.py
+++ b/python/tvm/micro/model_library_format.py
@@ -181,42 +181,26 @@ def _build_function_memory_map(function_metadata):
     """
     device_max_workspace = dict()
     main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR]
-    num_targets = len(main_func_metadata.workspace_sizes.items())
-    from tvm.driver import tvmc  # pylint: disable=import-outside-toplevel
-
-    external_codegens = tvmc.composite_target.get_codegen_names()
     func_entries = []
     target_local_entries = dict()
-    for i in range(num_targets):
-        main_target = main_func_metadata.workspace_sizes.items()[i][0]
-        device_max_workspace[main_target] = 0
-        for func_name, finfo in function_metadata.items():
-            if func_name == MAIN_FUNC_NAME_STR:
-                continue
-            target_local_entries[func_name] = list()
-        for func_name, finfo in function_metadata.items():
-            # Skip a few unsupported cases:
-            # 1. The main function metadata is exported elsewhere.
-            # 2. BYOC operator implementations do not currently export useful FunctionInfo.
-            if func_name == MAIN_FUNC_NAME_STR or not finfo.tir_primfuncs:
-                continue
-            assert (
-                len(finfo.constant_sizes.items()) == num_targets
-            ), f"{func_name}: found {finfo.constant_sizes!r} vs {num_targets}"
-            assert len(finfo.io_sizes.items()) == num_targets
-            target = finfo.workspace_sizes.items()[i][0]
-            workspace_size = finfo.workspace_sizes.items()[i][1]
+    for func_name, finfo in function_metadata.items():
+        # Skip a few unsupported cases:
+        # 1. The main function metadata is exported elsewhere.
+        # 2. BYOC operator implementations do not currently export useful FunctionInfo.
+        if func_name == MAIN_FUNC_NAME_STR or not finfo.tir_primfuncs:
+            continue
+        if func_name not in target_local_entries.keys():
+            target_local_entries[func_name] = list()
+        for target in dict(finfo.workspace_sizes).keys():
+            workspace_size = finfo.workspace_sizes[target]
             target_entry = {
                 "device": int(target.kind.device_type),
                 "workspace_size_bytes": int(workspace_size),
             }
             target_local_entries[func_name].append(target_entry)
-            if workspace_size > device_max_workspace.get(target, 0):
-                device_max_workspace[target] = workspace_size
-            # TODO(Mousius) - Remove this massive hack when Targets are unified
-            if target.kind.name in external_codegens:
-                device_max_workspace[main_target] += int(workspace_size)
+            if workspace_size >= device_max_workspace.get(int(target.kind.device_type), 0):
+                device_max_workspace[int(target.kind.device_type)] = workspace_size
 
     for func_name, target_entries_ in target_local_entries.items():
         func_entry = {
@@ -225,25 +209,46 @@ def _build_function_memory_map(function_metadata):
         }
         func_entries.append(func_entry)
 
-    target_main_entries = list()
-    for i in range(num_targets):
-        target = main_func_metadata.workspace_sizes.items()[i][0]
-        main_func_local_workspace = main_func_metadata.workspace_sizes.items()[i][1]
-        main_func_constants = main_func_metadata.constant_sizes.items()[i][1]
-        main_func_io = main_func_metadata.io_sizes.items()[i][1]
-        target_main_entries.append(
-            {
-                "device": int(target.kind.device_type),
-                "workspace_size_bytes": int(device_max_workspace[target])
-                + int(main_func_local_workspace),
-                "constants_size_bytes": int(main_func_constants),
-                "io_size_bytes": int(main_func_io),
-            }
+    target_main_entries = dict()
+
+    def _create_empty_entry(target_device_type):
+        return {
+            "device": int(target_device_type),
+            "workspace_size_bytes": 0,
+            "constants_size_bytes": 0,
+            "io_size_bytes": 0,
+        }
+
+    for target in dict(main_func_metadata.workspace_sizes).keys():
+        main_func_local_workspace = main_func_metadata.workspace_sizes[target]
+        target_main_entries[int(target.kind.device_type)] = _create_empty_entry(
+            int(target.kind.device_type)
+        )
+        target_main_entries[int(target.kind.device_type)]["workspace_size_bytes"] = int(
+            device_max_workspace.get(int(target.kind.device_type), 0)
+        ) + int(main_func_local_workspace)
+
+    for target in dict(main_func_metadata.constant_sizes).keys():
+        if int(target.kind.device_type) not in target_main_entries.keys():
+            target_main_entries[int(target.kind.device_type)] = _create_empty_entry(
+                int(target.kind.device_type)
+            )
+        target_main_entries[int(target.kind.device_type)]["constants_size_bytes"] = int(
+            main_func_metadata.constant_sizes[target]
+        )
+
+    for target in dict(main_func_metadata.io_sizes).keys():
+        if int(target.kind.device_type) not in target_main_entries.keys():
+            target_main_entries[int(target.kind.device_type)] = _create_empty_entry(
+                int(target.kind.device_type)
+            )
+        target_main_entries[int(target.kind.device_type)]["io_size_bytes"] = int(
+            main_func_metadata.io_sizes[target]
         )
 
     ret = {
         "operator_functions": func_entries,
-        "main": target_main_entries,
+        "main": list(target_main_entries.values()),
    }
    return ret
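# A minimal sketch of the aggregation the rewritten _build_function_memory_map
# performs: per-device "main" entries are accumulated in a dict keyed by
# device_type and emitted as a list. The (device_type -> size) dicts below are
# illustrative stand-ins for the target-keyed metadata maps; the real code also
# folds the per-device operator workspace maximum into main's workspace figure.
def aggregate_main_entries(workspace_sizes, constant_sizes, io_sizes):
    entries = dict()

    def entry(device_type):
        return entries.setdefault(
            device_type,
            {
                "device": device_type,
                "workspace_size_bytes": 0,
                "constants_size_bytes": 0,
                "io_size_bytes": 0,
            },
        )

    for device_type, size in workspace_sizes.items():
        entry(device_type)["workspace_size_bytes"] = int(size)
    for device_type, size in constant_sizes.items():
        entry(device_type)["constants_size_bytes"] = int(size)
    for device_type, size in io_sizes.items():
        entry(device_type)["io_size_bytes"] = int(size)
    return list(entries.values())


# e.g. aggregate_main_entries({1: 1024}, {1: 512}, {1: 300}) ->
# [{"device": 1, "workspace_size_bytes": 1024,
#   "constants_size_bytes": 512, "io_size_bytes": 300}]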
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index c55a6310ffa5..c53c32801db2 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -42,15 +42,6 @@ class BufferType(Enum):
     shram = auto()
 
 
-_REGION_MAP = {
-    BufferType.constant: 0,
-    BufferType.scratch: 1,
-    BufferType.input: 3,
-    BufferType.output: 4,
-    BufferType.shram: int((1 << 8) | (3 << 0)),
-}
-
-
 class BufferInfo(NamedTuple):
     """A data structure to hold metadata of the buffer."""
 
@@ -81,6 +72,111 @@ def get_accelerator_arch_config(accel_type):
     return accel_config_str_map[accel_type]
 
 
+class RegionOffset(NamedTuple):
+    """A data structure to hold the region and address offset corresponding to a tensor"""
+
+    region: int
+    offset: int
+
+
+def analyze_scratch_memory_acesses(mod: tvm.IRModule, candidate_regions_for_scratch: List[int]):
+    """
+    This function analyzes the IRModule for intermediary tensors that can result
+    from an offset into pool variables (via Let nodes) and/or from allocate nodes.
+    The allocate nodes will be folded into a single TVMBackendAllocWorkspace call
+    with offsets. Ultimately this produces a mapping from each such node to a
+    RegionOffset named tuple that holds the region and the obtained offset, as
+    mentioned above.
+
+    Parameters
+    ----------
+    mod: tvm.IRModule
+        The TIR module containing ethosu extern calls
+    candidate_regions_for_scratch: List[int]
+        A list of region integers that could be used for scratch regions
+
+    Returns
+    -------
+    scratch_region_map : Dict[tvm.tir.Var, RegionOffset]
+        A map from buffer vars to the scratch regions they are assigned
+    tvm_backend_alloc_workspace_size : int
+        The size of the TVMBackendAllocWorkspace call required to service
+        the remaining allocate nodes, if any
+    tvm_backend_alloc_workspace_region : int
+        The region associated with the TVMBackendAllocWorkspace
+    """
+    scratch_region_map = dict()
+    pool_var_region_map = dict()
+    # There should only be a single function
+    assert len(mod.functions.items()) == 1
+    primfunc = mod.functions.items()[0][1]
+    if "pool_args" in primfunc.attrs.keys():
+        pool_args = primfunc.attrs["pool_args"]
+        for pool_arg in pool_args:
+            pool_param = primfunc.params[int(pool_arg.pool_var_idx)]
+            pool_var_region_map[pool_param] = candidate_regions_for_scratch.pop()
+            scratch_region_map[pool_param] = RegionOffset(
+                region=pool_var_region_map[pool_param], offset=None
+            )
+
+    def analyze_pool_access(stmt):
+        if isinstance(stmt, tvm.tir.stmt.LetStmt):
+            call_address_of = stmt.value
+            load = call_address_of.args[0]
+            pool_var = load.buffer_var
+            scratch_region_map[stmt.var] = RegionOffset(
+                region=pool_var_region_map[pool_var], offset=int(load.index)
+            )
+
+    tvm.tir.stmt_functor.post_order_visit(primfunc.body, analyze_pool_access)
+
+    tvmbaw_region = None
+    if len(candidate_regions_for_scratch) > 0:
+        tvmbaw_region = candidate_regions_for_scratch.pop()
+    tvmbaw_size = 0
+
+    # If there are tir.Allocate nodes remaining by now, they need to be serviced via
+    # TVMBAW (TVMBackendAllocWorkspace) calls.
+    def analyze_remaining_allocates(stmt):
+        nonlocal tvmbaw_size
+        if isinstance(stmt, tvm.tir.stmt.Allocate):
+            allocate = stmt
+            pointer_type = allocate.buffer_var.type_annotation
+            storage_scope = pointer_type.storage_scope
+            if storage_scope == "global":
+                dtype_bytes = np.iinfo(np.dtype(allocate.dtype)).bits // 8
+                size_in_bytes = int(dtype_bytes * np.prod(list(allocate.extents)))
+                # Every memory address the NPU accesses has to be 16-byte aligned
+                size_in_bytes = util.round_up(size_in_bytes, 16)
+                address = tvmbaw_size
+                tvmbaw_size += size_in_bytes
+                scratch_region_map[allocate.buffer_var] = RegionOffset(
+                    region=tvmbaw_region, offset=address
+                )
+
+    tvm.tir.stmt_functor.post_order_visit(primfunc.body, analyze_remaining_allocates)
+
+    return (
+        scratch_region_map,
+        tvmbaw_size,
+        tvmbaw_region,
+    )
+
+
+def _get_region(buffer_type, var=None, scratch_region_map=None):
+    """A helper to obtain regions for buffer_types and buffer vars"""
+    static_regions = {
+        BufferType.constant: 0,
+        BufferType.input: 3,
+        BufferType.output: 4,
+        BufferType.shram: int((1 << 8) | (3 << 0)),
+    }
+    if buffer_type in static_regions.keys():
+        return static_regions[buffer_type]
+    assert buffer_type == BufferType.scratch
+    assert var in scratch_region_map.keys(), f"{var} is not analyzed for scratch regions"
+    return scratch_region_map[var].region
+
+
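# A toy run (not from the patch) of the candidate-region policy implemented
# above: list.pop() takes from the tail, so pool parameters are assigned
# regions 1, then 2, then 5, and whatever candidate remains can back the
# TVMBAW (runtime-allocated) region.
candidate_regions = [5, 2, 1]
pool_regions = [candidate_regions.pop() for _ in range(2)]  # two pools -> [1, 2]
tvmbaw_region = candidate_regions.pop() if candidate_regions else None  # -> 5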
 def translate(tir_module, params):
     """This will take a TIR module for the NPU and compile it to a command stream
 
@@ -106,21 +202,31 @@ def translate(tir_module, params):
         base addresses to be used by the driver
     """
 
+    # The NPU has 6 usable regions, ranging from 0 to 5.
+    # Regions 0, 3 and 4 are already used for constants,
+    # input and output, respectively (see _get_region()).
+    # Thus, for scratch we are left with 5, 2 and 1.
+    candidate_regions_for_scratch = [5, 2, 1]
+    (
+        scratch_region_map,
+        tvmbaw_workspace_size,
+        tvmbaw_region,
+    ) = analyze_scratch_memory_acesses(tir_module, candidate_regions_for_scratch)
     buffer_info = extract_buffer_info(tir_module, params)
     call_extern_list = extract_call_extern_list(tir_module)
     _npu_ops = list()
     for call_extern in call_extern_list:
         _npu_ops.append(translate_ethosu_tir_call_extern(call_extern))
-    _npu_ops, constant_data, scratch_size = assign_addresses(buffer_info, _npu_ops)
-    base_addresses = extract_param_base_addresses(tir_module, buffer_info)
-    if scratch_size > 0:
+    _npu_ops, constant_data = assign_addresses(buffer_info, _npu_ops, scratch_region_map)
+    base_addresses = extract_param_base_addresses(tir_module, buffer_info, scratch_region_map)
+    if tvmbaw_workspace_size:
         base_addresses.append(
             util.BaseAddress(
-                "scratch",
-                None,
-                _REGION_MAP[BufferType.scratch],
-                scratch_size,
-                True,
+                name="tvmbaw",
+                primfunc_param_idx=None,
+                region=tvmbaw_region,
+                size=tvmbaw_workspace_size,
+                is_runtime_allocation=True,
             )
         )
     target_accel_config = vela_api.get_accelerator_config()
@@ -129,7 +235,7 @@ def translate(tir_module, params):
     return payload.hex(), constant_data, base_addresses
 
 
-def extract_param_base_addresses(mod, buffer_info) -> List[util.BaseAddress]:
+def extract_param_base_addresses(mod, buffer_info, scratch_region_map) -> List[util.BaseAddress]:
     """This function extracts base addresses to be used by the driver
 
     Parameters
     ----------
@@ -161,7 +267,12 @@ def extract_param_base_addresses(mod, buffer_info) -> List[util.BaseAddress]:
         element_size_bytes = np.iinfo(dtype).bits // 8
         size_bytes = element_size_bytes * np.prod(list(buffer.shape))
         base_addresses.append(
-            util.BaseAddress(param.name, idx, _REGION_MAP[buffer_info[param].btype], size_bytes)
+            util.BaseAddress(
+                param.name,
+                idx,
+                _get_region(buffer_info[param].btype, param, scratch_region_map),
+                size_bytes,
+            )
         )
         idx += 1
 
@@ -227,39 +338,42 @@ def extract_buffer_info(
             const_data, const_data.shape, const_data.dtype, BufferType.constant
         )
 
-    for param in primfunc.params:
+    pool_param_indices = list()
+    if "pool_args" in primfunc.attrs.keys():
+        pool_args = primfunc.attrs["pool_args"]
+        pool_param_indices = [allocated_pool_info.pool_var_idx for allocated_pool_info in pool_args]
+
+    for idx, param in enumerate(primfunc.params):
         if param not in buffer_info.keys():
+            if idx in pool_param_indices:
+                btype = BufferType.scratch
+            else:
+                btype = BufferType.input_or_output
             buffer_info[param] = BufferInfo(
                 None,
                 None,
                 None,
-                BufferType.input_or_output,
+                btype,
             )
 
     def populate_allocate_buffer_info(stmt):
         if isinstance(stmt, tvm.tir.stmt.Allocate):
             allocate = stmt
-            if "placeholder" in allocate.buffer_var.name:
-                storage_scope = allocate.buffer_var.name.split(".")[-1]
-            else:
-                storage_scope = "global"
-
+            pointer_type = allocate.buffer_var.type_annotation
+            storage_scope = pointer_type.storage_scope
             if storage_scope == "local":
-                buffer_type = BufferType.shram
-            else:
-                buffer_type = BufferType.scratch
-            buffer_info[allocate.buffer_var] = BufferInfo(
-                None,
-                allocate.extents,
-                allocate.dtype,
-                buffer_type,
-            )
+                buffer_info[allocate.buffer_var] = BufferInfo(
+                    None,
+                    allocate.extents,
+                    allocate.dtype,
+                    BufferType.shram,
+                )
 
     tvm.tir.stmt_functor.post_order_visit(primfunc.body, populate_allocate_buffer_info)
    return buffer_info
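# A minimal sketch of the scope-driven classification used above in place of
# the old name-matching on "placeholder": the storage scope is read off the
# pointer type annotation of each allocated buffer var. "primfunc" is an
# assumed tir.PrimFunc input.
import tvm


def allocate_storage_scopes(primfunc):
    scopes = dict()

    def _visit(stmt):
        if isinstance(stmt, tvm.tir.stmt.Allocate):
            scopes[stmt.buffer_var] = stmt.buffer_var.type_annotation.storage_scope

    tvm.tir.stmt_functor.post_order_visit(primfunc.body, _visit)
    return scopes  # "local" -> NPU SHRAM; "global" -> planner/TVMBAW scratch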
 
 
-def assign_addresses(buffer_info, npu_ops):
+def assign_addresses(buffer_info, npu_ops, scratch_region_map):
     """This function will assign addresses to tensors
     within two buffers: scratch and constants.
     The scratch is the buffer created to hold all intermediary data
@@ -272,14 +386,14 @@ def assign_addresses(buffer_info, npu_ops):
         A map from buffer names to BufferInfo
     npu_ops : list
         A list of Vela NpuOps with tir.Loads for addresses
+    scratch_region_map : Dict[tvm.tir.Var, RegionOffset]
+        A buffer_var to region and offset map.
 
     Returns
     -------
     npu_ops : list
         A list of Vela NpuOps with addresses within scratch and constant buffers
     constant_tensor : NDArray
         A unified constant data array of uint8 as the constant buffer
-    scratch_size : int
-        The size of the scratch tensor.
     """
 
     def replace_npu_fm_with_address(npu_fm):
@@ -290,21 +404,34 @@ def replace_npu_fm_with_address(npu_fm):
         assert npu_fm.tiles.addresses[1:] == [0, 0, 0]
         npu_fm.tiles.addresses[1:] = [0, 0, 0]
         buffer = npu_fm.tiles.addresses[0].buffer_var
-        assert buffer in buffer_addresses.keys()
-        address, buffer_type = buffer_addresses[buffer]
+
+        if buffer in scratch_region_map.keys():
+            address = scratch_region_map[buffer].offset
+            region = scratch_region_map[buffer].region
+        else:
+            assert buffer in buffer_addresses.keys()
+            address, buffer_type = buffer_addresses[buffer]
+            region = _get_region(buffer_type)
+
         index = npu_fm.tiles.addresses[0].index * (
             np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8
         )
         npu_fm.tiles.addresses[0] = address + int(index)
-        npu_fm.region = _REGION_MAP[buffer_type]
+        npu_fm.region = region
         return npu_fm
 
     def replace_npu_address_range_with_address(npu_addr_range):
         assert isinstance(npu_addr_range.address, tvm.tir.Load)
         buffer = npu_addr_range.address.buffer_var
+        if buffer in scratch_region_map.keys():
+            return vapi.NpuAddressRange(
+                scratch_region_map[buffer].region,
+                scratch_region_map[buffer].offset,
+                npu_addr_range.length,
+            )
         assert buffer in buffer_addresses.keys(), f"searching for buffer: {buffer}, but not found"
         address, buffer_type = buffer_addresses[buffer]
-        return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length)
+        return vapi.NpuAddressRange(_get_region(buffer_type), address, npu_addr_range.length)
 
     def replace_tir_loads(npu_object):
         if isinstance(npu_object, vapi.NpuFeatureMap):
@@ -325,7 +452,6 @@ def classify_io(buffer):
         raise ValueError(f"Unused IO : {buffer} in tir module.")
 
-    scratch_size = 0
     constant_hex_data = []
     total_constant_len = 0
     buffer_addresses = dict()
@@ -345,8 +471,10 @@ def classify_io(buffer):
             constant_hex_data.append(constant_tensor)
             total_constant_len += len(constant_tensor) // 2
         else:
-            if info.btype == BufferType.input_or_output:
-                buffer_type = classify_io(_buffer)
+            if info.btype == BufferType.input_or_output or info.btype == BufferType.input:
+                buffer_type = info.btype
+                if info.btype == BufferType.input_or_output:
+                    buffer_type = classify_io(_buffer)
                 assert buffer_type in (BufferType.input, BufferType.output)
                 address = 0
                 buffer_addresses[_buffer] = (address, buffer_type)
@@ -359,14 +487,8 @@ def classify_io(buffer):
                 address = arch_config.lut_start_address
                 buffer_addresses[_buffer] = (address, info.btype)
             else:
-                dtype_bytes = np.iinfo(np.dtype(info.dtype)).bits // 8
-                size_in_bytes = int(dtype_bytes * np.prod(list(info.shape)))
-                # Every memory address the NPU access have to be 16 byte aligned
-                size_in_bytes = util.round_up(size_in_bytes, 16)
+                # These buffer_vars are already updated in scratch_region_map
                 assert info.btype == BufferType.scratch
-                address = scratch_size
-                scratch_size += size_in_bytes
-                buffer_addresses[_buffer] = (address, info.btype)
 
     for npu_op in npu_ops:
         for attr_name, attr in npu_op.__dict__.items():
@@ -379,11 +501,7 @@ def classify_io(buffer):
                 setattr(npu_op, attr_name, replace_tir_loads(attr))
 
     constant_data = "".join(constant_hex_data)
-    return (
-        npu_ops,
-        constant_data,
-        scratch_size,
-    )
+    return (npu_ops, constant_data)
 
 
 def translate_ethosu_tir_call_extern(tir_call_extern):
@@ -733,17 +851,18 @@ def _create_npu_rounding_mode(
 def _create_npu_dma_op(serial_copy):
     """This is a helper function to capture the list of arguments
     to create a NpuDmaOperation object"""
+    data_type_bytes = np.iinfo(np.dtype(serial_copy.read_address.dtype)).bits // 8
     src = vapi.NpuAddressRange(
         # region will be updated later
         region=0,
         address=serial_copy.read_address,
-        length=int(serial_copy.length.value),
+        length=int(serial_copy.length.value) * data_type_bytes,
     )
     dest = vapi.NpuAddressRange(
         # region will be updated later
         region=0,
         address=serial_copy.write_address,
-        length=int(serial_copy.length.value),
+        length=int(serial_copy.length.value) * data_type_bytes,
     )
    return vapi.NpuDmaOperation(src, dest)
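# Why the DMA length is now scaled (illustrative numbers, not from the patch):
# serial_copy.length counts elements, while Vela's NpuAddressRange.length is in
# bytes. For a copy of 8 elements of a 4-byte dtype:
import numpy as np

length_elems = 8
data_type_bytes = np.iinfo(np.dtype("int32")).bits // 8  # 4
length_bytes = length_elems * data_type_bytes  # 32
# This is consistent with the test expectation below changing from 8 to 32,
# assuming the copied buffer's dtype is 4 bytes wide.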
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 402f7b07a181..3694d6bcef95 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -948,8 +948,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
         tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
     if (allocated_pool_infos) {
       for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) {
-        pool_vars.push_back(allocated_pool_info->pool_var.value());
-        pool_var_info.Set(allocated_pool_info->pool_var.value(), allocated_pool_info);
+        int pool_var_index = allocated_pool_info->pool_var_idx.value()->value;
+        pool_vars.push_back(tir_main_func->params[pool_var_index]);
+        pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info);
       }
     }
     Array<String> devices = ListDevices();
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
index fb4fb52c507e..bb959647c7f0 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -36,6 +36,8 @@
 #include 
 
+#include "../../../runtime/thread_storage_scope.h"
+
 namespace tvm {
 namespace tir {
 namespace usmp {
@@ -257,18 +259,20 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) {
 void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
   ScopeInfo& current_scope_info = scope_stack_.top();
   const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation);
-  const auto& storage_scope = type->storage_scope;
+  const auto& storage_scope = runtime::StorageScope::Create(type->storage_scope);
 
   // If the allocate is in a for loop, USMP currently only looks at serial for loops.
   // If it's not a serial for loop, the memory planner will omit it in the current memory planning
   // process, leaving it as a tir.allocate node for codegen. Additionally, the USMP can only work
   // with buffers that have global storage_scope
-  if (!current_scope_info.for_loop.defined()) {
-    RecordAllocateNodeInfo(op);
-  } else if (current_scope_info.for_loop.defined() &&
-             current_scope_info.for_loop->kind == ForKind::kSerial && storage_scope == "global") {
-    RecordAllocateNodeInfo(op);
+  if (storage_scope.rank == runtime::StorageRank::kGlobal) {
+    if (!current_scope_info.for_loop.defined()) {
+      RecordAllocateNodeInfo(op);
+    } else if (current_scope_info.for_loop.defined() &&
+               current_scope_info.for_loop->kind == ForKind::kSerial) {
+      RecordAllocateNodeInfo(op);
+    }
   }
   StmtExprVisitor::VisitStmt(op->body);
   current_scope_info.allocate_nodes.erase(GetRef<Allocate>(op));
diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc
index 1c528351566e..9d8e36137c37 100644
--- a/src/tir/usmp/transform/assign_pool_info.cc
+++ b/src/tir/usmp/transform/assign_pool_info.cc
@@ -48,10 +48,7 @@ class PoolInfoAssigner : public StmtExprMutator {
     ICHECK(target_host) << "main function does not have a target attr";
     WorkspaceMemoryPools workspace_pools =
         module->GetAttr<WorkspaceMemoryPools>(tvm::attr::kWorkspaceMemoryPools)
-            .value_or(WorkspaceMemoryPools(
-                {PoolInfo("global_workspace", {{target_host.value(), kTargetPoolReadWriteAccess}},
-                          kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth,
-                          kUnknownWriteBandwidth, 0, 0, {{target_host.value(), 1}}, Bool(true))}));
+            .value_or(WorkspaceMemoryPools({CreateDefaultMemoryPool(module)}));
     Array<PoolInfo> pool_infos = workspace_pools->pools;
     for (const PoolInfo& pool_info : pool_infos) {
       for (const auto& kv : pool_info->target_access) {
@@ -76,8 +73,24 @@ class PoolInfoAssigner : public StmtExprMutator {
   IRModule mod_;
   Map<String, Array<PoolInfo>> target_pool_infos_;
   PrimFunc func_;
+  PoolInfo CreateDefaultMemoryPool(const IRModule& module);
 };
 
+PoolInfo PoolInfoAssigner::CreateDefaultMemoryPool(const tvm::IRModule& module) {
+  Map<Target, String> target_access;
+  tir::PrimFunc tir_main_func =
+      Downcast<tir::PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
+  Target target_host = tir_main_func->GetAttr<Target>(tvm::attr::kTarget).value();
+  for (const auto& kv : module->functions) {
+    BaseFunc func = kv.second;
+    Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
+    target_access.Set(target.value_or(target_host), kTargetPoolReadWriteAccess);
+  }
+  return PoolInfo("global_workspace", target_access, kUnrestrictedPoolSizeHint,
+                  kUnknownClockFrequency, kUnknownReadBandwidth, kUnknownWriteBandwidth, 0, 0, {},
+                  Bool(true));
+}
+
 Stmt PoolInfoAssigner::VisitStmt_(const AllocateNode* op) {
   Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget).value();
   ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_;
diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
index cd797681d474..999ca37d2128 100644
--- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
+++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
@@ -189,7 +189,7 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda
     si.params.push_back(pool_var);
     si.pools_to_params.Set(pool_info, pool_var);
     si.allocated_pool_params.push_back(AllocatedPoolInfo(
-        allocated_pool_info->pool_info, allocated_pool_info->allocated_size, pool_var));
+        allocated_pool_info->pool_info, allocated_pool_info->allocated_size, si.params.size() - 1));
 
     int pool_size = all_pools_sizes_[pool_info];
     String buffer_var_name = pool_ref_name + "_buffer_var";
@@ -258,7 +258,7 @@ Array<PrimExpr> PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs(
         allocate_buf_to_let_var_.find(Downcast<Var>(arg)) != allocate_buf_to_let_var_.end()) {
       ret.push_back(allocate_buf_to_let_var_[Downcast<Var>(arg)]);
     } else {
-      ret.push_back(arg);
+      ret.push_back(VisitExpr(arg));
     }
   }
   return ret;
diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc
index b789adbe81af..5c95f7d7a7be 100644
--- a/src/tir/usmp/utils.cc
+++ b/src/tir/usmp/utils.cc
@@ -114,12 +114,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
          << ")";
     });
 
-AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var) {
+AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size,
+                                     Integer pool_var_idx) {
   auto allocated_poolinfo_node = make_object<AllocatedPoolInfoNode>();
   allocated_poolinfo_node->pool_info = pool_info;
   allocated_poolinfo_node->allocated_size = allocated_size;
-  if (pool_var.defined()) {
-    allocated_poolinfo_node->pool_var = pool_var;
+  if (pool_var_idx.defined()) {
+    allocated_poolinfo_node->pool_var_idx = pool_var_idx;
   }
   data_ = std::move(allocated_poolinfo_node);
 }
diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py
index 7bfdaca6da28..4bdaef7a74ca 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -189,7 +189,7 @@ def deserialize_command_stream(blob):
     return cmms
 
 
-def create_test_runner(accel="ethos-u55-256"):
+def create_test_runner(accel="ethos-u55-256", enable_usmp=True):
     file_dir = os.path.dirname(os.path.abspath(__file__))
     test_root = os.path.join(file_dir, "reference_system")
     _, ethosu_variant, ethosu_macs = accel.split("-")
@@ -215,12 +215,15 @@ def create_test_runner(accel="ethos-u55-256"):
             "relay.ext.ethos-u.options": {
                 "accelerator_config": accel,
             },
+            "tir.usmp.enable": enable_usmp,
         },
     )
 
 
-def build_source(module, inputs, outputs, accel="ethos-u55-256", output_tolerance=0):
-    test_runner = create_test_runner(accel)
+def build_source(
+    module, inputs, outputs, accel="ethos-u55-256", output_tolerance=0, enable_usmp=True
+):
+    test_runner = create_test_runner(accel, enable_usmp)
     return compile_models(
         models=AOTTestModel(
            module=module,
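# The new enable_usmp knob just forwards to the PassContext configuration.
# Outside this test infra, the equivalent would be along these lines (a sketch
# using the standard TVM API; mod, params and target are assumed inputs):
import tvm
from tvm import relay


def build_with_usmp(mod, params, target, enable_usmp=True):
    with tvm.transform.PassContext(
        opt_level=3, config={"tir.usmp.enable": enable_usmp}
    ):
        return relay.build(mod, target=target, params=params)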
diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py
index de263c18f368..e9c6da5be18a 100644
--- a/tests/python/contrib/test_ethosu/test_networks.py
+++ b/tests/python/contrib/test_ethosu/test_networks.py
@@ -34,11 +34,20 @@ from . import infra
 
-ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32", "ethos-u65-256"]
-
-@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
-def test_forward_mobilenet_v1(accel_type):
+@pytest.mark.parametrize(
+    "accel_type, enable_usmp",
+    [
+        ("ethos-u55-256", True),
+        ("ethos-u55-256", False),
+        ("ethos-u65-256", True),
+        ("ethos-u65-256", False),
+        ("ethos-u55-128", True),
+        ("ethos-u55-64", True),
+        ("ethos-u55-32", True),
+    ],
+)
+def test_forward_mobilenet_v1(accel_type, enable_usmp):
     """Test the Mobilenet V1 TF Lite model."""
     np.random.seed(23)
     tflite_model_file = tf_testing.get_workload_official(
@@ -60,7 +69,7 @@ def test_forward_mobilenet_v1(accel_type):
     mod = partition_for_ethosu(relay_mod, params)
 
     compiled_models = infra.build_source(
-        mod, input_data, output_data, accel_type, output_tolerance=10
+        mod, input_data, output_data, accel_type, output_tolerance=10, enable_usmp=enable_usmp
     )
 
     infra.verify_source(compiled_models, accel_type)
diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
index add8021083c6..de214888be6b 100644
--- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
+++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -645,7 +645,7 @@ def populate_ethosu_copy_calls(stmt):
             {
                 "src": "placeholder_5",
                 "dest": "placeholder_d_global",
-                "length": 8,
+                "length": 32,
             },
         ],
     },
@@ -851,24 +851,45 @@ def _check_buffer(address, region, length, buffer_var):
                 length, dtype=buffer_dtype
             )
         elif buffer_type == tir_to_cs_translator.BufferType.scratch:
-            shape = list(buffer_info[buffer_var].shape)
-            assert length == np.prod(shape)
-            assert address < scratch_size
+            assert address < tvmbaw_workspace_size
 
-            size_in_bytes = int(np.prod(shape)) * dtype_bytes
+            size_in_bytes = allocate_node_sizes[buffer_var]
             # Every buffer is adjusted to align to 16 bytes
             size_in_bytes = util.round_up(size_in_bytes, 16)
-            assert address + size_in_bytes <= scratch_size
+            assert address + size_in_bytes <= tvmbaw_workspace_size
             # The scratch area should not be used by any other buffer
-            assert not scratch_mask[address : address + size_in_bytes].any()
+            assert not tvmbaw_workspace_mask[address : address + size_in_bytes].any()
             # The scratch area is marked as used
-            scratch_mask[address : address + size_in_bytes] = np.ones(size_in_bytes, dtype="uint8")
+            tvmbaw_workspace_mask[address : address + size_in_bytes] = np.ones(
+                size_in_bytes, dtype="uint8"
+            )
         elif buffer_type == tir_to_cs_translator.BufferType.input:
             assert address == 0
         else:
             assert buffer_type == tir_to_cs_translator.BufferType.output
             assert address == 0
 
+    def _get_allocate_node_sizes(mod):
+        # There should only be a single function
+        assert len(mod.functions.items()) == 1
+        primfunc = mod.functions.items()[0][1]
+        _allocate_node_sizes = dict()
+
+        def analyze_remaining_allocates(stmt):
+            if isinstance(stmt, tvm.tir.stmt.Allocate):
+                allocate = stmt
+                pointer_type = allocate.buffer_var.type_annotation
+                storage_scope = pointer_type.storage_scope
+                if storage_scope == "global":
+                    dtype_bytes = np.iinfo(np.dtype(allocate.dtype)).bits // 8
+                    size_in_bytes = int(dtype_bytes * np.prod(list(allocate.extents)))
+                    # Every memory address the NPU accesses has to be 16-byte aligned
+                    size_in_bytes = util.round_up(size_in_bytes, 16)
+                    _allocate_node_sizes[allocate.buffer_var] = size_in_bytes
+
+        tvm.tir.stmt_functor.post_order_visit(primfunc.body, analyze_remaining_allocates)
+        return _allocate_node_sizes
+
     def verify(npu_ops):
         """This wrapper verifies that the allocated addresses match the original tir buffers"""
         checked_buffers = set()
@@ -933,22 +954,29 @@ def check_buffer(address, region, length, buffer_var):
     tir_mod = test_case["tir_module"]
     tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u"))
     tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
+    candidate_regions_for_scratch = [5, 2, 1]
+    (
+        scratch_region_map,
+        tvmbaw_workspace_size,
+        _,
+    ) = tir_to_cs_translator.analyze_scratch_memory_acesses(
+        tir_mod, candidate_regions_for_scratch
+    )
+    allocate_node_sizes = _get_allocate_node_sizes(tir_mod)
     buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"])
     extern_calls = extract_call_extern_list(tir_mod)
     _npu_ops = list()
     for extern_call in extern_calls:
         _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call))
     npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops)
-    (
-        _npu_ops,
-        constant_hex_string,
-        scratch_size,
-    ) = tir_to_cs_translator.assign_addresses(buffer_info, _npu_ops)
-    scratch_mask = np.zeros(scratch_size, dtype="uint8")
+    (_npu_ops, constant_hex_string) = tir_to_cs_translator.assign_addresses(
+        buffer_info, _npu_ops, scratch_region_map
+    )
+    tvmbaw_workspace_mask = np.zeros(tvmbaw_workspace_size, dtype="uint8")
     constant_tensor_read_mask = np.zeros(len(constant_hex_string) // 2, dtype="uint8")
     verify(_npu_ops)
     # This will be only 1 if all allocated scratch is used.
-    assert np.prod(scratch_mask) == 1
+    assert np.prod(tvmbaw_workspace_mask) == 1
     # This will be only 1 if all constant tensors are read at least once.
    assert np.prod(constant_tensor_read_mask) == 1
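# The mask-based overlap check used by the test above, in miniature: each
# allocated byte range must be untouched before it is claimed, and a fully-ones
# mask at the end proves the workspace is exactly covered. Sizes are made up.
import numpy as np

workspace_mask = np.zeros(32, dtype="uint8")
for address, size in [(0, 16), (16, 16)]:
    assert not workspace_mask[address : address + size].any()  # no overlap
    workspace_mask[address : address + size] = 1
assert np.prod(workspace_mask) == 1  # every byte of the workspace was used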