From 44ad353597973c482fdb4d7894ac7c208f862eca Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 25 Nov 2022 07:53:34 +0900 Subject: [PATCH] simplify MultiLevelTilingHexagon --- .../schedule_rule/multi_level_tiling.h | 18 + .../multi_level_tiling_hexagon.cc | 354 ++---------------- .../multi_level_tiling_with_intrin.cc | 17 - 3 files changed, 58 insertions(+), 331 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 98b4634af106c..f7a3f8d7c6120 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -21,6 +21,7 @@ #include #include +#include "../../tir/schedule/transform.h" #include #include @@ -135,6 +136,23 @@ std::vector<State> SubRule(std::vector<State> states, FLambda sub_rule) { return results; } +/*! + * \brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate + * the tiled block for tensorization by postproc rewrite. + */ +inline Optional<tir::BlockRV> TileForIntrin(tir::Schedule sch, tir::BlockRV block, + const std::string& intrin_name) { + Optional<tir::LoopRV> tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); + if (!tiled_loop_rv) { + return NullOpt; + } + ICHECK(tiled_loop_rv.defined()); + tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); + sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name)); + return outer_block; +} + + /*! * \brief The mega rule: multi-level tiling with data reuse */ diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc index d37b3c1d1c01e..ddae455daff19 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc @@ -16,14 +16,11 @@ * specific language governing permissions and limitations * under the License.
*/ -#include - -#include -#include -#include +#include "../../tir/schedule/analysis.h" +#include "../../tir/schedule/transform.h" #include "../utils.h" -#include "./multi_level_tiling.h" +#include "multi_level_tiling.h" namespace tvm { namespace meta_schedule { @@ -32,86 +29,10 @@ using tir::BlockRV; using tir::LoopRV; using tir::Schedule; -struct HexagonIntrinGroup { - String compute_intrin; - - /*! \brief Create HexagonIntrinGroup from config in a map. The map should contains the - * following keys: - * - compute - * The values of the keys should be the names of the corresponding intrinsics and should be - * registered via TensorIntrin.Register beforehand. - */ - static HexagonIntrinGroup FromConfig(const Map& config); -}; - -HexagonIntrinGroup HexagonIntrinGroup::FromConfig(const Map& config) { - auto f_initialize_intrin = [&config](String key_name, String* intrin_name) { - CHECK(config.count(key_name)) << "ValueError: " << key_name << " is not set."; - *intrin_name = config.at(key_name); - // Check the existence of the intrin - tir::TensorIntrin::Get(*intrin_name); - }; - HexagonIntrinGroup intrin_group; - f_initialize_intrin("compute", &intrin_group.compute_intrin); - return intrin_group; -} - -class HexagonStateNode : public StateNode { - public: - /*! \brief The hexagon intrinsic group. */ - HexagonIntrinGroup intrin_group; - /*! \brief The auto tensorization maping info. */ - tir::AutoTensorizeMappingInfo mapping_info{nullptr}; - /*! \brief The hexagon reindex block A for hexagon computation */ - tir::BlockRV hexagon_reindex_A; - /*! \brief The hexagon reindex block B for hexagon computation */ - tir::BlockRV hexagon_reindex_B; - /*! 
\brief The hexagon reindex store block for hexagon computation */ - tir::BlockRV hexagon_reindex_store; - - State Copy() const final; - - static constexpr const char* _type_key = "meta_schedule.TensorCoreState"; - TVM_DECLARE_FINAL_OBJECT_INFO(HexagonStateNode, StateNode); -}; - -class HexagonState : public State { - public: - explicit HexagonState(HexagonIntrinGroup intrin_group, tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, BlockRV block_rv, Array> tiles = {}); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(HexagonState, State, HexagonStateNode); -}; - -TVM_REGISTER_OBJECT_TYPE(HexagonStateNode); - -HexagonState::HexagonState(HexagonIntrinGroup intrin_group, - tir::AutoTensorizeMappingInfo mapping_info, Schedule sch, - BlockRV block_rv, Array> tiles) { - ObjectPtr node = make_object(); - node->intrin_group = intrin_group; - node->mapping_info = mapping_info; - node->sch = std::move(sch); - node->block_rv = std::move(block_rv); - node->tiles = std::move(tiles); - data_ = std::move(node); -} - -State HexagonStateNode::Copy() const { - ObjectPtr node = make_object(*this); - node->sch = sch->Copy(); - return State(node); -} - -/*! - * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of hexagon - * intrinsics. 
- */ class MultiLevelTilingHexagonNode : public MultiLevelTilingNode { private: - // SubRule: Add tensorization-related transformations - inline std::vector<State> TransformForTensorization(HexagonState state) const; // Subrule: Add software pipeline - inline std::vector<State> AddSoftwarePipeline(HexagonState state) const; + inline std::vector<State> AddSoftwarePipeline(State state) const; // Override ApplySubRules to apply tensorization-specific sub-rules std::vector<State> ApplySubRules(std::vector<State> states) final; @@ -121,31 +42,11 @@ class MultiLevelTilingHexagonNode : public MultiLevelTilingNode { // Inherited from ScheduleRuleNode ScheduleRule Clone() const final { - ObjectPtr<MultiLevelTilingHexagonNode> n = - make_object<MultiLevelTilingHexagonNode>(*this); + ObjectPtr<MultiLevelTilingHexagonNode> n = make_object<MultiLevelTilingHexagonNode>(*this); return ScheduleRule(n); } - /*! - * \brief Transform and tensorize with the given tensor intrin - * \param state The state of the meta schedule rule - * \param intrin_name The name of the tensor intrin - * \return The loop to be tensorized. NullOpt if the workload can't be tensorized. - */ - Optional<LoopRV> TransformWithTensorIntrin(HexagonStateNode* state, - const String& intrin_name) const; - - /*! - * \brief Tile, blockize and annotate for tensorization with the given intrin - * \param block_rv The block to be tensorized - * \param intrin_name The name of the tensor intrin - */ - void TileAndAnnotateTensorize(Schedule* sch, const BlockRV& block_rv, - const String& intrin_name) const; - public: - /*! \brief The candidate hexagon intrin groups to apply */ - std::vector<HexagonIntrinGroup> intrin_groups; /*!
\brief Whether to use software pipeline */ bool use_software_pipeline = false; static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingHexagon"; @@ -154,76 +55,40 @@ class MultiLevelTilingHexagonNode : public MultiLevelTilingNode { private: }; -// Entry of the mega rule; Inherited from ScheduleRuleNode -Array<Schedule> MultiLevelTilingHexagonNode::Apply(const Schedule& sch, const BlockRV& block_rv) { - if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { +Array<tir::Schedule> MultiLevelTilingHexagonNode::Apply(const tir::Schedule& sch, + const tir::BlockRV& block_rv) { + auto intrin_name = "dot_32x4_u8u8i32_vtcm_vrmpy"; + auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc; + if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) { + TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; return {sch}; } - std::unordered_map<int, tir::AutoTensorizeMappingInfo> intrin_group_to_mapping_info; - for (int i = 0, n = intrin_groups.size(); i < n; ++i) { - HexagonIntrinGroup intrin_group = intrin_groups[i]; - Optional<tir::AutoTensorizeMappingInfo> mapping_info = tir::GetAutoTensorizeMappingInfo( - sch->state(), sch->GetSRef(block_rv), - tir::TensorIntrin::Get(intrin_groups[i].compute_intrin).value()->desc); - if (mapping_info.defined()) { - intrin_group_to_mapping_info.emplace(i, mapping_info.value()); - } - } + auto res = MultiLevelTilingNode::Apply(sch->Copy(), block_rv); - if (intrin_group_to_mapping_info.empty()) { - // No tensor intrinsics can be applied. + if (res.empty()) { + TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; return {sch}; } - - // Save the original schedule so that we can roll back transformations if tensorization - // fails.
- Schedule original_sch = sch; - - std::vector<State> initial_states; - for (const auto& kv : intrin_group_to_mapping_info) { - const HexagonIntrinGroup& intrin_group = intrin_groups[kv.first]; - const tir::AutoTensorizeMappingInfo& mapping_info = kv.second; - Schedule new_sch = sch->Copy(); - new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); - initial_states.push_back(HexagonState(intrin_group, mapping_info, new_sch, block_rv)); - } - Array<Schedule> results; - for (auto&& state : ApplySubRules(initial_states)) { - TVM_PY_LOG(INFO, logger) << "Sketch " << results.size() << ": tensorizing with " - << state.as<HexagonStateNode>()->intrin_group.compute_intrin; - results.push_back(std::move(state->sch)); - } - if (results.empty()) { - return {original_sch}; - } - return results; + TVM_PY_LOG(INFO, logger) << "Tensorizing with " << intrin_name; + return res; } std::vector<State> MultiLevelTilingHexagonNode::ApplySubRules(std::vector<State> states) { + auto intrin_name = "dot_32x4_u8u8i32_vtcm_vrmpy"; states = SubRule(std::move(states), [&](State state) { - return TransformForTensorization(Downcast<HexagonState>(state)); - }); - states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); - states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); - states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); }); - states = SubRule(std::move(states), [&](State state) { - return AddSoftwarePipeline(Downcast<HexagonState>(state)); + if (auto block_rv = TileForIntrin(state->sch, state->block_rv, intrin_name)) { + state->block_rv = block_rv.value(); + return std::vector<State>(1, state); + } + return std::vector<State>(); }); - return states; -} -void MultiLevelTilingHexagonNode::TileAndAnnotateTensorize(Schedule* sch, - const BlockRV& block_rv, - const String& intrin_name) const { - Optional<LoopRV> loop = TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); - ICHECK(loop.defined()); - BlockRV blockized_outer = (*sch)->Blockize(loop.value()); - 
(*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); + states = MultiLevelTilingNode::ApplySubRules(states); + return SubRule(std::move(states), [&](State state) { return AddSoftwarePipeline((state)); }); } -std::vector<State> MultiLevelTilingHexagonNode::AddSoftwarePipeline( - HexagonState state) const { +std::vector<State> MultiLevelTilingHexagonNode::AddSoftwarePipeline(State state) const { if (!use_software_pipeline) { return {state}; } @@ -251,15 +116,15 @@ std::vector<State> MultiLevelTilingHexagonNode::AddSoftwarePipeline( size_t cache_read_count = state->read_reuse.size(); if (cache_read_count > 2 || cache_read_count == 0) { return {state}; - } - - // Add annotations for software pipelining at the loop right above the cache read stages. + } + + // Add annotations for software pipelining at the loop right above the cache read stages. tir::BlockRV cache_read_block = state->read_reuse.begin()->second; Array<LoopRV> cache_read_loops = sch->GetLoops(cache_read_block); - Array<Integer> software_pipeline_stage; - Array<Integer> software_pipeline_order; - Array<Integer> software_pipeline_async_stages; - if(cache_read_count == 2) { + Array<Integer> software_pipeline_stage; + Array<Integer> software_pipeline_order; + Array<Integer> software_pipeline_async_stages; + if (cache_read_count == 2) { software_pipeline_stage = Array<Integer>{0, 0, 1}; software_pipeline_order = Array<Integer>{0, 1, 2}; software_pipeline_async_stages = Array<Integer>{0}; @@ -268,153 +133,18 @@ software_pipeline_order = Array<Integer>{0, 1}; software_pipeline_async_stages = Array<Integer>{0}; } - sch->Annotate(cache_read_loops[cache_read_loops.size() - 2], tir::attr::software_pipeline_stage, software_pipeline_stage); - sch->Annotate(cache_read_loops[cache_read_loops.size() - 2], tir::attr::software_pipeline_order, software_pipeline_order); - sch->Annotate(cache_read_loops[cache_read_loops.size() - 2], tir::attr::software_pipeline_async_stages, software_pipeline_async_stages); - + 
sch->Annotate(cache_read_loops[cache_read_loops.size() - 2], tir::attr::software_pipeline_stage, + software_pipeline_stage); + sch->Annotate(cache_read_loops[cache_read_loops.size() - 2], tir::attr::software_pipeline_order, + software_pipeline_order); + sch->Annotate(cache_read_loops[cache_read_loops.size() - 2], + tir::attr::software_pipeline_async_stages, software_pipeline_async_stages); + // TODO: Add support for nested async pipelines. - // TODO: Add support for async cache writes. + // TODO: Add support for async cache writes. return {state}; } -Optional MultiLevelTilingHexagonNode::TransformWithTensorIntrin( - HexagonStateNode* state, const String& intrin_name) const { - BlockRV block_rv = state->block_rv; - const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info; - tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); - - // Add reindex stages - const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - // Hold the reference of the block before reindex - const tir::Block block_before_reindex = GetRef(block); - if (block->reads.size() != 2 || block->writes.size() != 1) { - // only matmul-like computation is allowed - return NullOpt; - } - state->hexagon_reindex_store = - state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kWrite); - state->hexagon_reindex_A = - state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kRead); - state->hexagon_reindex_B = - state->sch->ReIndex(state->block_rv, 1, tir::BufferIndexType::kRead); - - // Transform the layout of reindex buffers accordingly. - // The index map defines the mapping for the computation block. We need to extract the sub index - // map to transform the load and store block. - ICHECK_EQ(mapping_info->mappings.size(), 1U); // assume only one mapping is present - const tir::IndexMap& index_map = mapping_info->mappings[0]; - - // Find the correspondence between block iters and the iters in the index map. 
- std::unordered_map lhs_to_index_map_src; - std::unordered_map rhs_to_index_map_tgt; - std::unordered_set unmapped_index_map_src; - ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size()); - for (int i = 0; i < static_cast(mapping_info->lhs_iters.size()); ++i) { - lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i]; - } - // The number of result iters in the index map is equal or more than the number of rhs (the - // tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from - // the lhs. They will be skipped during pattern matching for tensorization. An example of such - // case is batch matmul, the batch dimension is kept after layout transformations and it will be - // kept as a outer loop after tensorization. - int offset = static_cast(index_map->final_indices.size()) - - static_cast(mapping_info->rhs_iters.size()); - ICHECK_GE(offset, 0); - for (int i = 0; i < offset; ++i) { - const tir::VarNode* var_ptr = index_map->final_indices[i].as(); - ICHECK(var_ptr != nullptr); - unmapped_index_map_src.insert(GetRef(var_ptr)); - } - for (int i = offset; i < static_cast(index_map->final_indices.size()); ++i) { - rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i]; - } - - auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) { - std::vector sub_index_map_src; - std::vector sub_index_map_tgt; - const tir::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer]; - for (const Range& range : lhs_region) { - ICHECK(tir::is_one(range->extent)); - const tir::VarNode* var_ptr = range->min.as(); - ICHECK(var_ptr != nullptr); - const tir::Var& lhs_representer = lhs_to_index_map_src[GetRef(var_ptr)]; - sub_index_map_src.push_back(lhs_representer); - if (unmapped_index_map_src.count(lhs_representer)) { - sub_index_map_tgt.push_back(lhs_representer); - } - } - for (size_t i = 0; i < 
mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) { - const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as(); - ICHECK(var != nullptr); - sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef(var)]); - } - return tir::IndexMap(sub_index_map_src, sub_index_map_tgt); - }; - - std::unordered_set visited_buffers; - - Map buffer_sub_index_map; // cache of the sub index map associated - // with each buffer - - auto f_transform_buffer_layout = [&](tir::BufferIndexType index_type, int buffer_index) { - const tir::Buffer& lhs_buffer = tir::GetNthAccessBuffer( - state->sch->state(), block_before_reindex, buffer_index, index_type); - if (visited_buffers.count(lhs_buffer)) { - return; - } - visited_buffers.insert(lhs_buffer); - // Refresh block pointer (block sref is not invalidated) - block = TVM_SREF_TO_BLOCK(block_sref); - const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion( - state->sch->state(), GetRef(block), buffer_index, index_type); - auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); - buffer_sub_index_map.Set(lhs_buffer, sub_index_map); - state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt); - }; - - for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) { - f_transform_buffer_layout(tir::BufferIndexType::kRead, i); - } - for (int i = 0, n = block_before_reindex->writes.size(); i < n; ++i) { - f_transform_buffer_layout(tir::BufferIndexType::kWrite, i); - } - - // Transform the layout of current block and reindex blocks - auto f_transform_reindex_block_layout = [&](const BlockRV& block_rv, - tir::BufferIndexType buffer_type) { - tir::Buffer buffer = - tir::GetNthAccessBuffer(state->sch->state(), state->sch->Get(block_rv), 0, buffer_type); - const auto& sub_index_map = buffer_sub_index_map.at(buffer); - state->sch->TransformBlockLayout(block_rv, sub_index_map); - }; - 
f_transform_reindex_block_layout(state->hexagon_reindex_store, tir::BufferIndexType::kWrite); - f_transform_reindex_block_layout(state->hexagon_reindex_A, tir::BufferIndexType::kRead); - f_transform_reindex_block_layout(state->hexagon_reindex_B, tir::BufferIndexType::kRead); - state->sch->TransformBlockLayout(state->block_rv, index_map); - return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name, - /*allow_padding=*/true); -} - -inline std::vector<State> MultiLevelTilingHexagonNode::TransformForTensorization( - HexagonState state) const { - // Do reindex and layout transformations. - Optional<LoopRV> transformed_loop_rv = - TransformWithTensorIntrin(state.operator->(), state->intrin_group.compute_intrin); - if (!transformed_loop_rv.defined()) { - // The workload can't be tensorized. - return {}; - } - - // Do blockize - state->block_rv = state->sch->Blockize(transformed_loop_rv.value()); - - // Add annotations for post processors. - state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize, - state->intrin_group.compute_intrin); - // state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Integer(1)); - return {std::move(state)}; -} - ScheduleRule ScheduleRule::MultiLevelTilingHexagon( Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens, @@ -424,10 +154,6 @@ ScheduleRule ScheduleRule::MultiLevelTilingHexagon( auto node = MultiLevelTilingInitCommon<MultiLevelTilingHexagonNode>( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); - node->intrin_groups.reserve(intrin_groups.size()); - for (const auto& intrin_group_config : intrin_groups) { - node->intrin_groups.emplace_back(HexagonIntrinGroup::FromConfig(intrin_group_config)); - } node->use_software_pipeline = use_software_pipeline; return ScheduleRule(node); } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index 
428a1206a4ca1..2a29d5313b984 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -18,29 +18,12 @@ */ #include "../../tir/schedule/analysis.h" -#include "../../tir/schedule/transform.h" #include "../utils.h" #include "multi_level_tiling.h" namespace tvm { namespace meta_schedule { -/*! - * \brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate - * the tiled block for tensorization by postproc rewrite. - */ -Optional<tir::BlockRV> TileForIntrin(tir::Schedule sch, tir::BlockRV block, - const std::string& intrin_name) { - Optional<tir::LoopRV> tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); - if (!tiled_loop_rv) { - return NullOpt; - } - ICHECK(tiled_loop_rv.defined()); - tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); - sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name)); - return outer_block; -} - /*! * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic. */