From c6c89c3a252b92c379d79b9b8fdbccf2339b131c Mon Sep 17 00:00:00 2001
From: Adam Straw
Date: Mon, 20 Mar 2023 06:03:37 -0700
Subject: [PATCH] [Hexagon] Add concept of DMA groups (#14254)

* [Hexagon] Add concept of groups to QueuedRingBuffer

* fix failing RingBuffer unit test

* add StartGroup and EndGroup builtins and lowering

* use max queue number to increase dma copy perf

* increase max dma queues to fix test failures

* use DMA groups in LowerAsyncDMA pass

Co-authored-by: Noah Verke
Co-authored-by: Eric Lunderberg

* elide merge_async_commit_queue_scope

* LowerAsyncDMA bug fix + disallow non-contig copy; comments, tests pass

* format and lint

* add comments to Hex User DMA header

* use unsigned queue ID; fix test fails

* format and lint

* use dma_copy_dltensor in TIR tensor intrin; fix test fails

* address feedback: comments, types, names
---
 include/tvm/meta_schedule/postproc.h          |   3 +-
 include/tvm/tir/builtin.h                     |  37 +++-
 .../disallow_async_strided_mem_copy.py        |  11 +-
 python/tvm/tir/tensor_intrin/hexagon.py       |  35 +--
 src/driver/driver_api.cc                      |   1 -
 .../disallow_async_strided_mem_copy.cc        |   8 +-
 src/runtime/hexagon/hexagon_device_api.cc     |  19 +-
 src/runtime/hexagon/hexagon_user_dma.cc       |  14 +-
 src/runtime/hexagon/hexagon_user_dma.h        |  31 ++-
 src/runtime/hexagon/ring_buffer.h             |  76 ++++++-
 src/tir/op/builtin.cc                         |   6 +
 .../transforms/inject_software_pipeline.cc    |  35 ++-
 src/tir/transforms/lower_async_dma.cc         | 172 ++++++---------
 src/tir/transforms/lower_tvm_builtin.cc       |  26 +++
 .../hexagon/hexagon_user_dma_tests.cc         |   2 +-
 .../cpp-runtime/hexagon/ring_buffer_tests.cc  | 203 +++++++++++++++++-
 .../metaschedule_e2e/test_resnet50_int8.py    |   1 -
 .../test_hexagon/test_async_dma_pipeline.py   |   6 +-
 .../test_software_pipeline_async.py           |   1 -
 19 files changed, 488 insertions(+), 199 deletions(-)

diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index 85fb9003e87f..f297ca090482 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -111,10 +111,9 @@ class Postproc : public runtime::ObjectRef {
   TVM_DLL static Postproc DisallowDynamicLoop();
   /*!
    * \brief Create a postprocessor that checks if all async mem copies are not strided.
-   * \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope.
    * \return The postprocessor created
    */
-  TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true);
+  TVM_DLL static Postproc DisallowAsyncStridedMemCopy();
   /*!
    * \brief Create a postprocessor that rewrites the cooperative fetch annotation to
    * actual vectorized cooperative fetching in loop bindings.
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 708abde2cd31..e8bcc028fc58 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -727,14 +727,49 @@ TVM_DLL const Op& texture2d_load();

 /*!
  * \brief Initiate a non-blocking DMA copy from source to destination
+ *
+ * The copy is launched immediately.
+ *
+ * If a `dma_start_group()` call is active, the copy will be added
+ * to the current group for tracking of in-flight group counts.
+ *
+ * If no `dma_start_group()` call is active, the copy will be tracked
+ * individually, i.e. as a group of size 1.
  */
 TVM_DLL const Op& dma_copy();

 /*!
- * \brief Wait until the number of DMAs in flight is less than or equal to some maximum
+ * \brief Wait until the number of DMA groups in flight is less than
+ * or equal to some maximum
+ *
+ * Calling `dma_wait()` while a group is active is unsupported.
 */
 TVM_DLL const Op& dma_wait();

+/*!
+ * \brief Start a group of DMA copies
+ *
+ * Any call to `dma_copy()` that occurs after `dma_start_group()` will
+ * be added to the current group for tracking of in-flight group counts.
+ *
+ * Only one DMA group may be active at a given time. Calling
+ * `dma_start_group()` while a group is active is unsupported.
+ */
+TVM_DLL const Op& dma_start_group();

+/*!
+ * \brief End a group of DMA copies
+ *
+ * Track all calls to `dma_copy()` that occurred since the preceding
+ * `dma_start_group()` as a single group in-flight.
+ *
+ * Calling `dma_end_group()` without an active group is unsupported.
+ *
+ * Note: a group must contain at least one `dma_copy()`; closing an
+ * empty group is unsupported. Each closed group contributes one entry
+ * to the count of in-flight groups used by `dma_wait()`.
+ */
+TVM_DLL const Op& dma_end_group();
+
 /*!
  * \brief Provide a true statement that can be used for simplifications
 *
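For orientation, the intrinsic sequence these declarations describe composes as in the
following sketch, built with the public TIR C++ constructors. This is illustrative only
and not part of the patch; `GroupedDmaSequence` and the `dst0`/`src0`/`num_bytes`
arguments are placeholder names, and the `dma_copy()` argument order mirrors the
LowerAsyncDMA lowering later in this patch:

#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

namespace sketch {

using namespace tvm;
using namespace tvm::tir;

// Two copies tracked as one in-flight group, followed by a blocking wait
// for all groups on the queue to drain.
Stmt GroupedDmaSequence(PrimExpr queue_id, PrimExpr dst0, PrimExpr src0, PrimExpr dst1,
                        PrimExpr src1, PrimExpr num_bytes, PrimExpr bypass_cache) {
  auto evaluate = [](const Op& op, Array<PrimExpr> args) {
    return Evaluate(Call(DataType::Int(32), op, args));
  };
  return SeqStmt({evaluate(builtin::dma_start_group(), {queue_id}),
                  evaluate(builtin::dma_copy(), {queue_id, dst0, src0, num_bytes, bypass_cache}),
                  evaluate(builtin::dma_copy(), {queue_id, dst1, src1, num_bytes, bypass_cache}),
                  evaluate(builtin::dma_end_group(), {queue_id}),
                  // wait until at most zero groups remain in flight on this queue
                  evaluate(builtin::dma_wait(), {queue_id, IntImm(DataType::Int(32), 0)})});
}

}  // namespace sketch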
diff --git a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py
index 7e0e00de2949..0dcff9bf45a3 100644
--- a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py
+++ b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py
@@ -23,16 +23,9 @@
 @register_object("meta_schedule.DisallowAsyncStridedMemCopy")
 class DisallowAsyncStridedMemCopy(Postproc):
-    """A postprocessor that disallows schedules that use async strided mem copies.
+    """A postprocessor that disallows schedules that use async strided mem copies."""

-    Parameters
-    ----------
-    merge_async_commit_queue_scope : bool
-        Whether or not to merge the async commit queue scope.
-    """
-
-    def __init__(self, merge_async_commit_queue_scope=True) -> None:
+    def __init__(self) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.PostprocDisallowAsyncStridedMemCopy,  # type: ignore # pylint: disable=no-member
-            merge_async_commit_queue_scope,
         )
diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py
index 7a348f3f1a45..22dd9a977c65 100644
--- a/python/tvm/tir/tensor_intrin/hexagon.py
+++ b/python/tvm/tir/tensor_intrin/hexagon.py
@@ -47,20 +47,27 @@ def sync_dma_load_impl(a: T.handle, c: T.handle) -> None:
         T.writes(C[0:size])
         T.evaluate(
             T.tvm_call_packed(
-                "device_api.hexagon.dma_copy",
-                -1,  # Use QueueId of -1 to not interfere with async copies.
-                T.address_of(C[0], dtype="handle"),
-                T.address_of(A[0], dtype="handle"),
-                size,
-                0,  # Do not use experimental bypass mode.
-                dtype="int32",
-            )
-        )
-        T.evaluate(
-            T.tvm_call_packed(
-                "device_api.hexagon.dma_wait",
-                -1,
-                0,  # Wait for the sync queue (-1) to have 0 messages.
+                "device_api.hexagon.dma_copy_dltensor",
+                T.tvm_stack_make_array(
+                    T.address_of(C[0], dtype="handle"),
+                    T.tvm_stack_make_shape(size, dtype="handle"),
+                    0,
+                    1,
+                    C.dtype,
+                    0,
+                    dtype="handle",
+                ),
+                T.tvm_stack_make_array(
+                    T.address_of(A[0], dtype="handle"),
+                    T.tvm_stack_make_shape(size, dtype="handle"),
+                    0,
+                    1,
+                    A.dtype,
+                    0,
+                    dtype="handle",
+                ),
+                T.cast(size, dtype="int"),
+                False,  # Do not use experimental bypass mode.
                 dtype="int32",
             )
         )
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index da1bbc296a49..7f9bff66bf6b 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -52,7 +52,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
index 952810a47aee..2d1507c8994d 100644
--- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
+++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
@@ -145,9 +145,6 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
     pass_list.push_back(tir::transform::InjectDoubleBuffer());
     pass_list.push_back(tir::transform::VectorizeLoop(true));
     pass_list.push_back(tir::transform::StorageRewrite());
-    transform::PassContext pass_ctx = transform::PassContext::Current();
-    pass_ctx->config.Set("tir.merge_async_commit_queue_scope",
-                         Bool(merge_async_commit_queue_scope));
     tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
                                runtime::String(g_var->name_hint));
     IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
@@ -169,15 +166,12 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
     return Postproc(n);
   }

-  bool merge_async_commit_queue_scope = true;
-
   static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
   TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
 };

-Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) {
+Postproc Postproc::DisallowAsyncStridedMemCopy() {
   ObjectPtr<DisallowAsyncStridedMemCopyNode> n = make_object<DisallowAsyncStridedMemCopyNode>();
-  n->merge_async_commit_queue_scope = merge_async_commit_queue_scope;
   return Postproc(n);
 }
diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc
index ee2a826b02ea..16e67aa9650f 100644
--- a/src/runtime/hexagon/hexagon_device_api.cc
+++ b/src/runtime/hexagon/hexagon_device_api.cc
@@ -210,10 +210,10 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor")
     });

 TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
-  int queue_id = args[0];
+  uint32_t queue_id = static_cast<int>(args[0]);
   void* dst = args[1];
   void* src = args[2];
-  int size = args[3];
+  uint32_t size = static_cast<int>(args[3]);
   ICHECK(size > 0);

   bool bypass_cache = args[4];
@@ -226,13 +226,26 @@
 });

 TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) {
-  int queue_id = args[0];
+  uint32_t queue_id = static_cast<int>(args[0]);
   int inflight = args[1];
   ICHECK(inflight >= 0);
   HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight);
   *rv = static_cast<int>(0);
 });

+TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      uint32_t queue_id = static_cast<int>(args[0]);
+      HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id);
+      *rv = static_cast<int>(0);
+    });
+
+TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group").set_body([](TVMArgs args, TVMRetValue* rv) {
+  uint32_t queue_id = static_cast<int>(args[0]);
+  HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id);
+  *rv = static_cast<int>(0);
+});
+
 TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
   int32_t device_type = args[0];
   int32_t device_id = args[1];
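The new group entry points are plain packed functions, so outside of lowered TIR they
can also be reached through the global registry. A minimal sketch under the assumption
that the Hexagon runtime above is linked in (the helper name is hypothetical):

#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

// Resolve and invoke the registered group-start hook for virtual queue 0.
void StartDmaGroupOnQueueZero() {
  const tvm::runtime::PackedFunc* start =
      tvm::runtime::Registry::Get("device_api.hexagon.dma_start_group");
  ICHECK(start != nullptr) << "Hexagon runtime not loaded";
  (*start)(0);  // queue_id
}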
diff --git a/src/runtime/hexagon/hexagon_user_dma.cc b/src/runtime/hexagon/hexagon_user_dma.cc
index c30fd645bbd0..11214a46e809 100644
--- a/src/runtime/hexagon/hexagon_user_dma.cc
+++ b/src/runtime/hexagon/hexagon_user_dma.cc
@@ -32,7 +32,8 @@ unsigned int HexagonUserDMA::Init() {
   return status;
 }

-int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache) {
+int HexagonUserDMA::Copy(uint32_t queue_id, void* dst, void* src, uint32_t length,
+                         bool bypass_cache) {
   // length limited to 24 bits
   if (length > DESC_LENGTH_MASK) {
     return DMA_FAILURE;
@@ -103,15 +104,15 @@ int HexagonUserDMA::Copy(uint32_t queue_id, void* dst, void* src, uint32_t lengt
   return DMA_SUCCESS;
 }

-void HexagonUserDMA::Wait(int queue_id, uint32_t max_dmas_in_flight) {
-  // wait (forever) until max DMAs in flight <= actual DMAs in flight
-  while (DMAsInFlight(queue_id) > max_dmas_in_flight) {
+void HexagonUserDMA::Wait(uint32_t queue_id, uint32_t max_dmas_in_flight) {
+  // wait (forever) until the number of DMA groups in flight <= the allowed maximum
+  while (DMAGroupsInFlight(queue_id) > max_dmas_in_flight) {
   }
 }

-uint32_t HexagonUserDMA::Poll(int queue_id) { return DMAsInFlight(queue_id); }
+uint32_t HexagonUserDMA::Poll(uint32_t queue_id) { return DMAGroupsInFlight(queue_id); }

-uint32_t HexagonUserDMA::DMAsInFlight(int queue_id) {
+uint32_t HexagonUserDMA::DMAGroupsInFlight(uint32_t queue_id) {
   dmpoll();  // update DMA engine status
   return descriptors_->InFlight(queue_id);
 }
@@ -125,7 +126,8 @@ HexagonUserDMA::HexagonUserDMA() {
     unsigned int done = dma_desc_get_done(dma_desc);
     return (done != DESC_DONE_COMPLETE);
   };
-  descriptors_ = new QueuedRingBuffer(MAX_DMA_DESCRIPTORS, desc_in_flight);
+  descriptors_ =
+      new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_QUEUES, MAX_DMA_DESCRIPTORS, desc_in_flight);
 }

 HexagonUserDMA::~HexagonUserDMA() {
diff --git a/src/runtime/hexagon/hexagon_user_dma.h b/src/runtime/hexagon/hexagon_user_dma.h
index 9397a16e3f03..70590cbe4faa 100644
--- a/src/runtime/hexagon/hexagon_user_dma.h
+++ b/src/runtime/hexagon/hexagon_user_dma.h
@@ -34,7 +34,8 @@ namespace hexagon {
 #define DMA_FAILURE -1
 #define DMA_RETRY 1
 #define MAX_DMA_DESCRIPTORS 100
-#define SYNC_DMA_QUEUE -1
+#define MAX_DMA_QUEUES 10
+#define SYNC_DMA_QUEUE (MAX_DMA_QUEUES - 1)

 class HexagonUserDMA {
  public:
@@ -47,32 +48,50 @@ class HexagonUserDMA {

   /*!
    * \brief Initiate DMA to copy memory from source to destination address
+   * \param queue_id The virtual DMA queue
    * \param dst Destination address
    * \param src Source address
    * \param length Length in bytes to copy
    * \returns Status: DMA_SUCCESS or DMA_FAILURE
    */
-  int Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);
+  int Copy(uint32_t queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);

   /*!
-   * \brief Wait until the number of DMAs in flight is less than or equal to some maximum
-   * \param max_dmas_in_flight Maximum number of DMAs allowed to be in flight
-   * to satisfy the `Wait` e.g. use `Wait(0)` to wait on "all" outstanding DMAs to complete
+   * \brief Wait until the number of DMA groups in flight is less than or equal to some maximum
+   * \param queue_id The virtual DMA queue
+   * \param max_dmas_in_flight Maximum number of DMA groups allowed to be in flight
+   * to satisfy the `Wait`, e.g. use `Wait(0)` to wait on "all" outstanding groups to complete
    */
-  void Wait(int queue_id, uint32_t max_dmas_in_flight);
+  void Wait(uint32_t queue_id, uint32_t max_dmas_in_flight);

   /*!
-   * \brief Poll the number of DMAs in flight
-   * \returns Number of DMAs in flight
+   * \brief Poll the number of DMA groups in flight
+   * \param queue_id The virtual DMA queue
+   * \returns Number of DMA groups in flight
    */
-  uint32_t Poll(int queue_id);
+  uint32_t Poll(uint32_t queue_id);
+
+  /*!
+   * \brief Start a group of DMA copies
+   * \param queue_id The virtual DMA queue
+   */
+  void StartGroup(uint32_t queue_id) { descriptors_->StartGroup(queue_id); }
+
+  /*!
+   * \brief End a group of DMA copies
+   * \param queue_id The virtual DMA queue
+   */
+  void EndGroup(uint32_t queue_id) { descriptors_->EndGroup(queue_id); }

  private:
   //! \brief Initializes the Hexagon User DMA engine
   unsigned int Init();

-  //! \brief Calculates and returns the number of DMAs in flight
-  uint32_t DMAsInFlight(int queue_id);
+  /*!
+   * \brief Calculates and returns the number of DMA groups in flight
+   * \param queue_id The virtual DMA queue
+   */
+  uint32_t DMAGroupsInFlight(uint32_t queue_id);

   //! \brief Tracks whether the very first DMA has been executed
   bool first_dma_ = true;
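Taken together, the runtime surface above composes as in this hypothetical caller — a
sketch only, where `dma`, the buffers, and `bytes` are placeholders and the `Copy`
return codes are not checked:

#include "hexagon_user_dma.h"

using tvm::runtime::hexagon::HexagonUserDMA;

// Issue two copies on virtual queue 0 that retire as a single group.
void CopyPairAsOneGroup(HexagonUserDMA* dma, void* dst0, void* src0, void* dst1, void* src1,
                        uint32_t bytes) {
  uint32_t queue_id = 0;
  dma->StartGroup(queue_id);
  dma->Copy(queue_id, dst0, src0, bytes, /*bypass_cache=*/false);
  dma->Copy(queue_id, dst1, src1, bytes, /*bypass_cache=*/false);
  dma->EndGroup(queue_id);  // the two copies now count as one in-flight group
  dma->Wait(queue_id, 0);   // block until no groups remain in flight on queue 0
}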
diff --git a/src/runtime/hexagon/ring_buffer.h b/src/runtime/hexagon/ring_buffer.h
index 4294ded8f52a..91adad6a65e5 100644
--- a/src/runtime/hexagon/ring_buffer.h
+++ b/src/runtime/hexagon/ring_buffer.h
@@ -21,6 +21,8 @@
 #define TVM_RUNTIME_HEXAGON_RING_BUFFER_H_

 #include <functional>
+#include <queue>
+#include <vector>

 #include "hexagon_common.h"
@@ -94,17 +96,33 @@ class RingBuffer {
 template <typename T>
 class QueuedRingBuffer : RingBuffer<T> {
  public:
-  QueuedRingBuffer(uint32_t ring_buff_size, std::function<bool(T*)> in_flight)
-      : RingBuffer<T>(ring_buff_size, in_flight) {}
+  QueuedRingBuffer(uint32_t max_queues, uint32_t ring_buff_size, std::function<bool(T*)> in_flight)
+      : RingBuffer<T>(ring_buff_size, in_flight), max_queues_(max_queues) {
+    queue_descriptors_.resize(max_queues_);
+  }

   //! \brief Returns pointer to next T; add the queue ID for tracking
-  T* Next(int queue_id) {
+  T* Next(uint32_t queue_id) {
+    CHECK_LT(queue_id, max_queues_);
     queue_ids_.push_back(queue_id);
+    queue_descriptor* d = &queue_descriptors_[queue_id];
+    if (d->group_started) {
+      // if we have a group started just update the pending count
+      d->pending_in_group++;
+    } else {
+      // else create group with size one
+      d->groups.push(1);
+      d->pending_total++;
+    }
     return RingBuffer<T>::Next();
   }

-  //! \brief Returns the number of Ts in flight for a given queue ID
-  uint32_t InFlight(int queue_id) {
+  //! \brief Returns the number of groups of Ts in flight for a given queue ID
+  uint32_t InFlight(uint32_t queue_id) {
+    CHECK_LT(queue_id, max_queues_);
+    queue_descriptor* d = &queue_descriptors_[queue_id];
+    CHECK(!d->group_started);
+
     uint32_t in_flight = 0;
     // look at the queue IDs for the RingBuffer entries in flight
     for (size_t i = queue_ids_.size() - RingBuffer<T>::InFlight(); i < queue_ids_.size(); ++i) {
@@ -113,11 +131,57 @@ class QueuedRingBuffer : RingBuffer<T> {
       if (queue_ids_[i] == queue_id) {
         in_flight++;
       }
     }
-    return in_flight;
+
+    // calculate number of groups in flight
+    while (!d->groups.empty() && d->pending_total - d->groups.front() >= in_flight) {
+      d->pending_total -= d->groups.front();
+      d->groups.pop();
+    }
+
+    // return the number of groups in flight
+    return d->groups.size();
+  }
+
+  //! \brief Start a group of Ts; if not called, the default group size is one
+  void StartGroup(uint32_t queue_id) {
+    CHECK_LT(queue_id, max_queues_);
+    queue_descriptor* d = &queue_descriptors_[queue_id];
+    CHECK(!d->group_started);
+
+    // start group
+    d->group_started = true;
+    d->pending_in_group = 0;
+  }
+
+  //! \brief End a group of Ts
+  void EndGroup(uint32_t queue_id) {
+    CHECK_LT(queue_id, max_queues_);
+    queue_descriptor* d = &queue_descriptors_[queue_id];
+    CHECK(d->group_started);
+    CHECK(d->pending_in_group);
+
+    // create group
+    if (d->pending_in_group) {
+      d->groups.emplace(d->pending_in_group);
+    }
+    d->pending_total += d->pending_in_group;
+
+    // end group
+    d->group_started = false;
+    d->pending_in_group = 0;
   }

  private:
+  struct queue_descriptor {
+    uint32_t pending_total = 0;
+    uint32_t pending_in_group = 0;
+    bool group_started = false;
+    std::queue<uint32_t> groups;
+  };
+
+  const uint32_t max_queues_;
-  std::vector<int> queue_ids_;
+  std::vector<uint32_t> queue_ids_;
+  std::vector<queue_descriptor> queue_descriptors_;
 };

 }  // namespace hexagon
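The group bookkeeping above can be traced with a small walkthrough that mirrors the
unit tests later in this patch. This is a sketch under the assumption that the
`in_flight` predicate has the `std::function<bool(T*)>` shape used by the ring buffer,
with 43 meaning "still running" and 42 meaning "done" as in the tests:

#include <functional>

#include "ring_buffer.h"

using tvm::runtime::hexagon::QueuedRingBuffer;

void GroupBookkeepingWalkthrough() {
  // an entry is "in flight" while it still holds 43
  std::function<bool(int*)> in_flight = [](int* entry) { return *entry == 43; };
  QueuedRingBuffer<int> rb(/*max_queues=*/2, /*ring_buff_size=*/8, in_flight);

  rb.StartGroup(0);
  int* a = rb.Next(0);
  *a = 43;               // first entry of the explicit group
  int* b = rb.Next(0);
  *b = 43;               // second entry of the explicit group
  rb.EndGroup(0);        // groups = {2}, pending_total = 2
  int* c = rb.Next(0);
  *c = 43;               // ungrouped entry => implicit group of one; groups = {2, 1}

  // rb.InFlight(0) == 2: the explicit group plus the implicit group
  *a = 42;  // half of the explicit group done => still 2
  *b = 42;  // explicit group fully done       => 1
  *c = 42;  // implicit group done             => 0
}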
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index e240b7b701ba..c85590428450 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -335,6 +335,12 @@ TIR_DEFINE_BUILTIN_FUNC(dma_copy).set_attr<TCallEffectKind>("TCallEffectKind",
 TIR_DEFINE_BUILTIN_FUNC(dma_wait).set_attr<TCallEffectKind>("TCallEffectKind",
                                                             Integer(CallEffectKind::kOpaque));

+TIR_DEFINE_BUILTIN_FUNC(dma_start_group)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(dma_end_group)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(assume)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
     .set_num_inputs(1);
diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc
index 51523a37399b..7e6739a512cb 100644
--- a/src/tir/transforms/inject_software_pipeline.cc
+++ b/src/tir/transforms/inject_software_pipeline.cc
@@ -309,10 +309,9 @@ class PipelineRewriter : public StmtExprMutator {
                       const Array<Buffer> pipeline_allocs, const For& pipeline_loop,
                       const PipelineInfo& pipeline_info,
                       const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info,
-                      const Map<String, ObjectRef> preserved_annotations,
-                      bool merge_async_commit_queue_scope) {
+                      const Map<String, ObjectRef> preserved_annotations) {
     PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop,
-                              pipeline_info, fragment_info, preserved_annotations,
-                              merge_async_commit_queue_scope);
+                              pipeline_info, fragment_info, preserved_annotations);
     return rewriter.BuildPipeline();
   }
@@ -322,8 +321,7 @@ class PipelineRewriter : public StmtExprMutator {
                    const Array<Buffer>& pipeline_allocs, const For& pipeline_loop,
                    const PipelineInfo& pipeline_info,
                    const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info,
-                   const Map<String, ObjectRef> preserved_annotations,
-                   bool merge_async_commit_queue_scope)
+                   const Map<String, ObjectRef> preserved_annotations)
       : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
         double_buffers_(double_buffers),
         pipeline_allocs_(pipeline_allocs),
         pipeline_loop_(pipeline_loop),
         pipeline_info_(pipeline_info),
         fragment_info_(fragment_info),
-        preserved_annotations_(preserved_annotations),
-        merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {}
+        preserved_annotations_(preserved_annotations) {}

   Stmt BuildPipeline() {
     // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions
@@ -766,7 +763,7 @@ class PipelineRewriter : public StmtExprMutator {
       group_bodies.push_back(new_blocks[i].block->body);
     }

-    if (merge_async_commit_queue_scope_ && group_bodies.size() > 1) {
+    if (group_bodies.size() > 1) {
       auto merged_bodies = SeqStmt(group_bodies);
       group_bodies.clear();
       group_bodies.push_back(merged_bodies);
@@ -853,8 +850,7 @@ class PipelineRewriter : public StmtExprMutator {
       auto& local_state = async_states_local[stage];

       int commit_group_id = -1;
-      if (local_state.commit_groups.empty() || local_state.consumed ||
-          !merge_async_commit_queue_scope_) {
+      if (local_state.commit_groups.empty() || local_state.consumed) {
         // consumed == true means there is already a consumer stage waiting for an
         // earlier async operation of this stage. In such cases, we make multiple commit_queue
         // for this stage.
@@ -954,7 +950,6 @@ class PipelineRewriter : public StmtExprMutator {
   Array<Block> ordered_stmts_;
   std::map<int, AsyncStateGlobal> async_states;
   Map<String, ObjectRef> preserved_annotations_;
-  bool merge_async_commit_queue_scope_ = true;
 };

 /*!
@@ -993,8 +988,8 @@ void BuildDependencyGraph(

 class PipelineInjector : private StmtExprMutator {
  public:
-  static Stmt Inject(const PrimFunc& func, bool merge_async_commit_queue_scope) {
-    PipelineInjector injector(merge_async_commit_queue_scope);
+  static Stmt Inject(const PrimFunc& func) {
+    PipelineInjector injector;
     for (const auto& kv : func->buffer_map) {
       const Buffer& buffer = kv.second;
       injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
@@ -1004,8 +999,7 @@ class PipelineInjector : private StmtExprMutator {
   }

  private:
-  explicit PipelineInjector(bool merge_async_commit_queue_scope)
-      : merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {}
+  PipelineInjector() {}

   /*!
    * \brief Check the pipeline satisfies the following conditions:
@@ -1140,9 +1134,9 @@ class PipelineInjector : private StmtExprMutator {
     ValidatePipelineBody(pipeline_info, original_order);

     // Step 4: Rewrite the pipeline body.
-    Stmt pipeline = PipelineRewriter::Rewrite(
-        buffer_data_to_buffer_, double_buffers, pipeline_allocs, GetRef<For>(op), pipeline_info,
-        fragment_info_, preserved_annotations, merge_async_commit_queue_scope_);
+    Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers,
+                                              pipeline_allocs, GetRef<For>(op), pipeline_info,
+                                              fragment_info_, preserved_annotations);

     if (const auto* realize = op->body.as<BlockRealizeNode>()) {
       const auto& block = realize->block;
@@ -1211,7 +1205,6 @@ class PipelineInjector : private StmtExprMutator {
   Map<Var, Buffer> buffer_data_to_buffer_;
   std::unordered_map<const VarNode*, FragmentInfo> fragment_info_;
   std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> double_buffers;
-  bool merge_async_commit_queue_scope_ = true;
 };

 }  // namespace software_pipeline
@@ -1225,9 +1218,7 @@ namespace transform {
 Pass InjectSoftwarePipeline() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* fptr = f.CopyOnWrite();
-    bool merge_async_commit_queue_scope =
-        ctx->GetConfig<Bool>("tir.merge_async_commit_queue_scope", Bool(true)).value();
-    fptr->body = software_pipeline::PipelineInjector::Inject(f, merge_async_commit_queue_scope);
+    fptr->body = software_pipeline::PipelineInjector::Inject(f);
     fptr->body = ConvertSSA(std::move(fptr->body));
     return f;
   };
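With merging now unconditional, a producer stage's copies land under a single commit
scope. Roughly, as a sketch with the C++ Stmt constructors — `copy0`/`copy1` stand for
the stage's copy loops and `queue` for its queue index; the `attr [0]` node and nesting
are inferred from the scope structure the LowerAsyncDMA pass below documents:

#include <tvm/ir/expr.h>
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

Stmt MergedCommitScopeSketch(Stmt copy0, Stmt copy1, PrimExpr queue) {
  // bodies merged into one SeqStmt, then wrapped in async_scope and
  // async_commit_queue_scope, mirroring the `group_bodies.size() > 1` branch above
  Stmt merged = SeqStmt({copy0, copy1});
  Stmt async = AttrStmt(Integer(0), tir::attr::async_scope, 1, merged);
  return AttrStmt(Integer(0), tir::attr::async_commit_queue_scope, queue, async);
}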
diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc
index 57cff7985f1c..d899b6ec70ab 100644
--- a/src/tir/transforms/lower_async_dma.cc
+++ b/src/tir/transforms/lower_async_dma.cc
@@ -22,26 +22,63 @@
 */

 #include
+#include
 #include
+#include
+#include
+#include
 #include
 #include
+#include
+
+#include "../../arith/ir_mutator_with_analyzer.h"
 #include "ir_utils.h"

 namespace tvm {
 namespace tir {

-class AsyncDMALowerer : public StmtExprMutator {
+class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
  public:
-  explicit AsyncDMALowerer(bool dma_bypass_cache) : dma_bypass_cache_(dma_bypass_cache) {}
+  explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer)
+      : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    // if the for loop is not within an async_commit_queue_scope
+    if (!async_queue_id_.has_value()) {
+      return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
+    }

-  // Create member statement to track a mapping from iter var to iter range
-  Stmt VisitStmt_(const ForNode* op) final {
-    input_iters.Set(op->loop_var, Range(op->min, op->extent));
-    return StmtExprMutator::VisitStmt_(op);
+    // if the for loop is not a memcpy of a contiguous region
+    std::optional<MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
+    if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
+        mem_copy->source->region.size() != 1) {
+      LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access";
+    }
+
+    // now that we are about to perform the `copy` transform,
+    // save the queue ID for inspection in the `wait` transform
+    // and increment the number of DMA copies in the group
+    queue_ids_.insert(async_queue_id_.value());
+    dmas_in_group_++;
+
+    tvm::PrimExpr src_min = mem_copy->source->region[0]->min;
+    tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min;
+    tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent;
+
+    auto src = BufferLoad(mem_copy->source->buffer, {src_min});
+    auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min});
+    return Evaluate(
+        Call(DataType::Int(32), builtin::dma_copy(),
+             {async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}),
+              Call(DataType::Handle(), builtin::address_of(), {src}),
+              dst_extent * src->dtype.bytes(), dma_bypass_cache_}));
   }
   Stmt VisitStmt_(const AttrStmtNode* op) final {
+    // populate analyzer knowledge of loop iterators
+    auto previsit = arith::IRMutatorWithAnalyzer::VisitStmt_(op);
+
     // Convert this, for example:
     // attr [0] "async_wait_queue_scope" = 0;
     // attr [0] "async_wait_inflight_count" = 0;
@@ -63,7 +100,7 @@ class AsyncDMALowerer : public StmtExprMutator {
         DLOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the "
                       "`async_wait_queue_scope` transform has not been previously observed in the "
                       "`async_commit_queue_scope` transform";
-        return StmtExprMutator::VisitStmt_(op);
+        return previsit;
       }

       auto async_wait = op->body.as<AttrStmtNode>();
@@ -71,14 +108,13 @@ class AsyncDMALowerer : public StmtExprMutator {
         DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
                       "`async_wait_queue_scope` does not contain an `AttrStmtNode` with key "
                       "`async_wait_inflight_count`";
-        return StmtExprMutator::VisitStmt_(op);
+        return previsit;
       }
-
       auto call_dma_wait =
           Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value}));

       // concatenate the call with the body and return
-      return SeqStmt({call_dma_wait, StmtExprMutator::VisitStmt(async_wait->body)});
+      return SeqStmt({call_dma_wait, arith::IRMutatorWithAnalyzer::VisitStmt(async_wait->body)});

       // Convert this, for example:
       // attr [0] "async_commit_queue_scope" = 0;
@@ -99,112 +135,27 @@ class AsyncDMALowerer : public StmtExprMutator {
       // get queue ID
       auto queue_id_node = op->value.as<IntImmNode>();
       ICHECK(queue_id_node);
-      int queue_id = queue_id_node->value;
-
-      // walk the graph to verify this is a mem copy ...
-      // 1) async_commit_queue_scope contains async_scope
-      auto async_scope = op->body.as<AttrStmtNode>();
-      if (!async_scope || async_scope->attr_key != tir::attr::async_scope) {
-        DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
-                      "`async_commit_queue_scope` does not contain an `AttrStmtNode` with key "
-                      "`async_scope`";
-        return StmtExprMutator::VisitStmt_(op);
-      }
-
-      // 2) async_scope contains single for loop
-      auto for_loop = async_scope->body.as<ForNode>();
-      if (!for_loop) {
-        DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
-                      "`async_scope` does not contain a single `ForNode`";
-        return StmtExprMutator::VisitStmt_(op);
-      }
-
-      // Add the current loop to the input iters mapping.
-      input_iters.Set(for_loop->loop_var, Range(for_loop->min, for_loop->extent));
-
-      // 3) for loop contains buffer store with single index
-      auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
-      if (!bufferstorenode || bufferstorenode->indices.size() != 1) {
-        DLOG(INFO)
-            << "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a "
-               "single `BufferStoreNode` with a single index variable";
-        return StmtExprMutator::VisitStmt_(op);
-      }
-
-      // 4) buffer store value is a buffer load with single index
-      auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>();
-      if (!bufferloadnode || bufferloadnode->indices.size() != 1) {
-        DLOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a "
-                      "single `BufferLoadNode` with a single index variable";
-        return StmtExprMutator::VisitStmt_(op);
-      }
-
-      // get store buffer; assert it exists and is contiguous given it uses a single index
-      auto bufferstore = bufferstorenode->buffer.as<BufferNode>();
-      ICHECK(bufferstore && bufferstore->strides.empty());
-
-      // get load buffer; assert it exists and is contiguous given it uses a single index
-      auto bufferload = bufferloadnode->buffer.as<BufferNode>();
-      ICHECK(bufferload && bufferload->strides.empty());
-
-      // we will be replacing the entire for loop including its index
-      // with a DMA copy instrinsic that spans the entire index space of the for loop
-      // so we will need to replace the for loop index with value zero in the buffer indices
-      // thus we eliminate the index from the expression so the DMA copy receives the buffer range
-      // base address
-      Map<Var, PrimExpr> loop_var_remap = {{for_loop->loop_var, IntImm(DataType::Int(32), 0)}};
-
-      // map loop variable to zero for the store index & simplify
-      Array<PrimExpr> store_index = bufferstorenode->indices;
-
-      // Use DetectIterMap to detect whether store index is non-contiguous.
-      arith::Analyzer analyzer;
-      auto store_iter_map = DetectIterMap(store_index, input_iters, 1,
-                                          arith::IterMapLevel::Surjective, &analyzer, false);
-      if (!store_iter_map->errors.empty()) {
-        LOG(FATAL)
-            << "Unable to lower async dma for non contiguous memory access with store index: "
-            << store_index;
-      }
-
-      store_index.MutateByApply([&](PrimExpr expr) {
-        arith::Analyzer analyzer;
-        return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
-      });
-
-      // map loop variable to zero for the load index & simplify
-      Array<PrimExpr> load_index = bufferloadnode->indices;
-
-      // Use DetectIterMap to detect whether load index is non-contiguous.
-      auto load_iter_map = DetectIterMap(load_index, input_iters, 1,
-                                         arith::IterMapLevel::Surjective, &analyzer, false);
-      if (!load_iter_map->errors.empty()) {
-        LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with load index: "
-                   << load_index;
+      async_queue_id_ = queue_id_node->value;
+      auto result = arith::IRMutatorWithAnalyzer::VisitStmt_(op);
+      if (dmas_in_group_ > 1) {
+        auto call_dma_start_group = Evaluate(
+            Call(DataType::Int(32), builtin::dma_start_group(), {async_queue_id_.value()}));
+        auto call_dma_end_group =
+            Evaluate(Call(DataType::Int(32), builtin::dma_end_group(), {async_queue_id_.value()}));
+        result = SeqStmt({call_dma_start_group, result, call_dma_end_group});
       }

-      load_index.MutateByApply([&](PrimExpr expr) {
-        arith::Analyzer analyzer;
-        return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
-      });
-
-      // now that we are about to perform the `copy` transform
-      // save queue ID for inspection in `wait` transform
-      queue_ids_.insert(queue_id);
-
-      return Evaluate(Call(DataType::Int(32), builtin::dma_copy(),
-                           {queue_id,
-                            Call(DataType::Handle(), builtin::address_of(),
-                                 {BufferLoad(bufferstorenode->buffer, store_index)}),
-                            Call(DataType::Handle(), builtin::address_of(),
-                                 {BufferLoad(bufferloadnode->buffer, load_index)}),
-                            for_loop->extent * bufferloadnode->dtype.bytes(), dma_bypass_cache_}));
+      async_queue_id_ = std::nullopt;
+      dmas_in_group_ = 0;
+      return result;
     }

-    return StmtExprMutator::VisitStmt_(op);
+    return arith::IRMutatorWithAnalyzer::VisitStmt_(op);
   }

  private:
+  int dmas_in_group_ = 0;
   std::set<int> queue_ids_;
+  std::optional<int> async_queue_id_ = std::nullopt;
   bool dma_bypass_cache_;
   Map<Var, Range> input_iters = Map<Var, Range>();
 };
@@ -214,9 +165,10 @@ namespace transform {

 Pass LowerAsyncDMA() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto fptr = f.CopyOnWrite();
+    arith::Analyzer analyzer;
     bool dma_bypass_cache =
         ctx->GetConfig<Bool>("tir.experimental_dma_bypass_cache", Bool(false)).value();
-    fptr->body = AsyncDMALowerer(dma_bypass_cache)(std::move(fptr->body));
+    fptr->body = AsyncDMALowerer(dma_bypass_cache, &analyzer)(std::move(fptr->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {});
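The grouping decision above reduces to a small wrapping rule: a commit scope that
lowered to a single `dma_copy` is left bare, while two or more copies get bracketed
into a group. A sketch of that rule in isolation (the helper name and parameters are
illustrative, not part of the pass):

#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

Stmt WrapCommitScope(Stmt copies, PrimExpr queue_id, int num_copies_emitted) {
  if (num_copies_emitted <= 1) {
    return copies;  // a lone dma_copy is already tracked as a group of one
  }
  Stmt start = Evaluate(Call(DataType::Int(32), builtin::dma_start_group(), {queue_id}));
  Stmt end = Evaluate(Call(DataType::Int(32), builtin::dma_end_group(), {queue_id}));
  return SeqStmt({start, copies, end});
}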
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 082a54f9c73d..49023a5ad01f 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -319,6 +319,10 @@ class BuiltinLower : public StmtExprMutator {
       return MakeDMACopy(op);
     } else if (op->op.same_as(builtin::dma_wait())) {
       return MakeDMAWait(op);
+    } else if (op->op.same_as(builtin::dma_start_group())) {
+      return MakeDMAStartGroup(op);
+    } else if (op->op.same_as(builtin::dma_end_group())) {
+      return MakeDMAEndGroup(op);
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
@@ -352,6 +356,28 @@ class BuiltinLower : public StmtExprMutator {
     return VisitExpr(call_packed);
   }

+  PrimExpr MakeDMAStartGroup(const CallNode* op) {
+    PrimExpr queue_id = op->args[0];
+
+    std::string fdevapi_prefix =
+        "device_api." + std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value));
+
+    Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
+                            {StringImm(fdevapi_prefix + ".dma_start_group"), queue_id});
+    return VisitExpr(call_packed);
+  }
+
+  PrimExpr MakeDMAEndGroup(const CallNode* op) {
+    PrimExpr queue_id = op->args[0];
+
+    std::string fdevapi_prefix =
+        "device_api." + std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value));
+
+    Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
+                            {StringImm(fdevapi_prefix + ".dma_end_group"), queue_id});
+    return VisitExpr(call_packed);
+  }
+
   // call shape
   PrimExpr MakeShape(const CallNode* op) {
     // if args.size() == 0, it represents a scalar shape ()
diff --git a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc
index e4ffe3a0de9c..9697006bf1aa 100644
--- a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc
+++ b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc
@@ -47,7 +47,7 @@ class HexagonUserDMATest : public ::testing::Test {
  public:
   HexagonUserDMA* user_dma;
   int ret = 0;
-  int queue_id = 0;
+  uint32_t queue_id = 0;
   void* src = nullptr;
   void* dst = nullptr;
   char* src_char = nullptr;
diff --git a/tests/cpp-runtime/hexagon/ring_buffer_tests.cc b/tests/cpp-runtime/hexagon/ring_buffer_tests.cc
index 8cf363bae0b3..a3abf82b863f 100644
--- a/tests/cpp-runtime/hexagon/ring_buffer_tests.cc
+++ b/tests/cpp-runtime/hexagon/ring_buffer_tests.cc
@@ -40,7 +40,7 @@ class RingBufferTest : public ::testing::Test {

   int finished = 42;
   int inflight = 43;
-  uint32_t size = 4;
+  uint32_t size = 8;
   uint32_t half = size / 2;
   RingBuffer<int>* ring_buff = nullptr;
 };
@@ -160,11 +160,11 @@ TEST_F(RingBufferTest, half_in_flight) {

   // mark it inflight and check
   *ptr = inflight;
-  ASSERT_EQ(ring_buff->InFlight(), 3);
+  ASSERT_EQ(ring_buff->InFlight(), half + 1);

   // mark it finished and check also blocked
   *ptr = finished;
-  ASSERT_EQ(ring_buff->InFlight(), 3);
+  ASSERT_EQ(ring_buff->InFlight(), half + 1);
 }

 TEST_F(RingBufferTest, half_in_flight_blocked) {
@@ -190,13 +190,23 @@ TEST_F(RingBufferTest, half_in_flight_blocked) {
 }

 class QueuedRingBufferTest : public RingBufferTest {
-  void SetUp() override { queued_ring_buff = new QueuedRingBuffer<int>(size, in_flight); }
+  void SetUp() override {
+    queued_ring_buff = new QueuedRingBuffer<int>(MAX_QUEUES, size, in_flight);
+  }

   void TearDown() override { delete queued_ring_buff; }

  public:
+  int MAX_QUEUES = 2;
   QueuedRingBuffer<int>* queued_ring_buff = nullptr;
 };

+TEST_F(QueuedRingBufferTest, invalid_queue) {
+  ASSERT_THROW(queued_ring_buff->Next(MAX_QUEUES), InternalError);
+  ASSERT_THROW(queued_ring_buff->InFlight(MAX_QUEUES), InternalError);
+  ASSERT_THROW(queued_ring_buff->StartGroup(MAX_QUEUES), InternalError);
+  ASSERT_THROW(queued_ring_buff->EndGroup(MAX_QUEUES), InternalError);
+}
+
 TEST_F(QueuedRingBufferTest, two_queues) {
   int* q0 = queued_ring_buff->Next(0);
   *q0 = inflight;
   int* q1 = queued_ring_buff->Next(1);
   *q1 = inflight;
   ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
   ASSERT_EQ(queued_ring_buff->InFlight(1), 1);
   *q0 = finished;
   ASSERT_EQ(queued_ring_buff->InFlight(0), 0);
   ASSERT_EQ(queued_ring_buff->InFlight(1), 1);
   *q1 = finished;
   ASSERT_EQ(queued_ring_buff->InFlight(0), 0);
   ASSERT_EQ(queued_ring_buff->InFlight(1), 0);
 }
+
+TEST_F(QueuedRingBufferTest, group_end_before_group_start) {
+  ASSERT_THROW(queued_ring_buff->EndGroup(0), InternalError);
+}
+
+TEST_F(QueuedRingBufferTest, group_restart) {
+  queued_ring_buff->StartGroup(0);
+  ASSERT_THROW(queued_ring_buff->StartGroup(0), InternalError);
+}
+
+TEST_F(QueuedRingBufferTest, zero_size_group) {
+  queued_ring_buff->StartGroup(0);
+  ASSERT_THROW(queued_ring_buff->EndGroup(0), InternalError);
+}
+
+TEST_F(QueuedRingBufferTest, in_flight_before_group_end) {
+  queued_ring_buff->StartGroup(0);
+  ASSERT_THROW(queued_ring_buff->InFlight(0), InternalError);
+}
+
+TEST_F(QueuedRingBufferTest, group_of_one) {
+  queued_ring_buff->StartGroup(0);
+  int* g0_0 = queued_ring_buff->Next(0);
+  *g0_0 = inflight;
+  queued_ring_buff->EndGroup(0);
+
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+  *g0_0 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 0);
+}
+
+TEST_F(QueuedRingBufferTest, group_of_two) {
+  queued_ring_buff->StartGroup(0);
+  int* g0_0 = queued_ring_buff->Next(0);
+  *g0_0 = inflight;
+  int* g0_1 = queued_ring_buff->Next(0);
+  *g0_1 = inflight;
+  queued_ring_buff->EndGroup(0);
+
+  // neither done => group in flight
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+
+  // half done => group in flight
+  *g0_0 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+
+  // both done => group finished
+  *g0_1 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 0);
+}
+
+TEST_F(QueuedRingBufferTest, group_of_three) {
+  queued_ring_buff->StartGroup(0);
+  int* g0_0 = queued_ring_buff->Next(0);
+  *g0_0 = inflight;
+  int* g0_1 = queued_ring_buff->Next(0);
+  *g0_1 = inflight;
+  int* g0_2 = queued_ring_buff->Next(0);
+  *g0_2 = inflight;
+  queued_ring_buff->EndGroup(0);
+
+  // none done => group in flight
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+
+  // 1/3 done => group in flight
+  *g0_0 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+
+  // 2/3 done => group in flight
+  *g0_1 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+
+  // all done => group finished
+  *g0_2 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 0);
+}
+
+TEST_F(QueuedRingBufferTest, two_groups_of_two) {
+  queued_ring_buff->StartGroup(0);
+  int* g0_0 = queued_ring_buff->Next(0);
+  *g0_0 = inflight;
+  int* g0_1 = queued_ring_buff->Next(0);
+  *g0_1 = inflight;
+  queued_ring_buff->EndGroup(0);
+
+  queued_ring_buff->StartGroup(0);
+  int* g1_0 = queued_ring_buff->Next(0);
+  *g1_0 = inflight;
+  int* g1_1 = queued_ring_buff->Next(0);
+  *g1_1 = inflight;
+  queued_ring_buff->EndGroup(0);
+
+  // two groups in flight
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 2);
+
+  // group 0 half done => two groups in flight
+  *g0_0 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 2);
+
+  // group 0 done => one group in flight
+  *g0_1 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+
+  // group 1 half done => one group in flight
+  *g1_0 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+
+  // group 1 done => zero groups in flight
+  *g1_1 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 0);
+}
+
+TEST_F(QueuedRingBufferTest, two_queues_two_groups_of_two) {
+  queued_ring_buff->StartGroup(0);
+  int* q0g0_0 = queued_ring_buff->Next(0);
+  *q0g0_0 = inflight;
+  int* q0g0_1 = queued_ring_buff->Next(0);
+  *q0g0_1 = inflight;
+  queued_ring_buff->EndGroup(0);
+
+  queued_ring_buff->StartGroup(1);
+  int* q1g0_0 = queued_ring_buff->Next(1);
+  *q1g0_0 = inflight;
+  int* q1g0_1 = queued_ring_buff->Next(1);
+  *q1g0_1 = inflight;
+  queued_ring_buff->EndGroup(1);
+
+  queued_ring_buff->StartGroup(0);
+  int* q0g1_0 = queued_ring_buff->Next(0);
+  *q0g1_0 = inflight;
+  int* q0g1_1 = queued_ring_buff->Next(0);
+  *q0g1_1 = inflight;
+  queued_ring_buff->EndGroup(0);
+
+  queued_ring_buff->StartGroup(1);
+  int* q1g1_0 = queued_ring_buff->Next(1);
+  *q1g1_0 = inflight;
+  int* q1g1_1 = queued_ring_buff->Next(1);
+  *q1g1_1 = inflight;
+  queued_ring_buff->EndGroup(1);
+
+  // two queues with two groups in flight each
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 2);
+  ASSERT_EQ(queued_ring_buff->InFlight(1), 2);
+
+  // queue 0 group 0 half done => no change
+  *q0g0_0 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 2);
+  ASSERT_EQ(queued_ring_buff->InFlight(1), 2);
+
+  // queue 0 group 0 done => queue 0 with one group in flight
+  *q0g0_1 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+  ASSERT_EQ(queued_ring_buff->InFlight(1), 2);
+
+  // queue 1 group 0 half done => no change
+  *q1g0_0 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+  ASSERT_EQ(queued_ring_buff->InFlight(1), 2);
+
+  // queue 1 group 0 done => queue 1 with one group in flight
+  *q1g0_1 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+  ASSERT_EQ(queued_ring_buff->InFlight(1), 1);
+
+  // queue 0 group 1 half done => no change
+  *q0g1_0 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 1);
+  ASSERT_EQ(queued_ring_buff->InFlight(1), 1);
+
+  // queue 0 group 1 done => queue 0 with zero groups in flight
+  *q0g1_1 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 0);
+  ASSERT_EQ(queued_ring_buff->InFlight(1), 1);
+
+  // queue 1 group 1 half done => no change
+  *q1g1_0 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 0);
+  ASSERT_EQ(queued_ring_buff->InFlight(1), 1);
+
+  // queue 1 group 1 done => queue 1 with zero groups in flight
+  *q1g1_1 = finished;
+  ASSERT_EQ(queued_ring_buff->InFlight(0), 0);
+  ASSERT_EQ(queued_ring_buff->InFlight(1), 0);
+}
diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
index d985d2120936..030e47ac581d 100644
--- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
+++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
@@ -521,7 +521,6 @@ def test_async_dma_resnet50(hexagon_launcher):

     pass_config = {
         "tir.use_async_copy": 1,
-        "tir.merge_async_commit_queue_scope": False,
         "relay.backend.use_meta_schedule": True,
         "relay.backend.tir_converter": "default",
     }
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index bc8edca7d844..04e7595abf37 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -268,7 +268,6 @@ def evaluate(
     c_data,
     expected_output=None,
     use_async_copy=0,
-    merge_async_commit_queue_scope=False,
 ):
     """Evaluate function."""
     target_hexagon = tvm.target.hexagon("v68", link_params=True)
@@ -276,7 +275,6 @@ def evaluate(
         config={
             "tir.use_async_copy": use_async_copy,
             "tir.experimental_dma_bypass_cache": 1,
-            "tir.merge_async_commit_queue_scope": merge_async_commit_queue_scope,
         }
     ):
         func_tir = tvm.build(
@@ -485,7 +483,6 @@ def test_loading_vtcm_for_vrmpy(
         np.zeros(expected_output.shape, "int32"),
         expected_output,
         use_async_copy=1,
-        merge_async_commit_queue_scope=False,
     )

     sch = get_fake_conv_vtcm_schedule(size_a, size_w)
@@ -886,14 +883,13 @@ def test_non_contiguous():
     """Test non-contiguous memory lowering."""
     sch = tvm.tir.Schedule(conv2d_async_non_contig)
     target_hexagon = tvm.target.hexagon("v68", link_params=True)
-    err_rgx = r"Unable to lower async dma for non contiguous memory access with load index: "
+    err_rgx = r"Unable to lower async dma due to non contiguous memory access"

     # Currently we do not support non-contiguous memory access being lowered to
     # async dma, so we throw an error.
     with pytest.raises(tvm.TVMError, match=err_rgx):
         with tvm.transform.PassContext(
             config={
                 "tir.use_async_copy": 1,
-                "tir.merge_async_commit_queue_scope": 0,
             }
         ):
             tvm.build(
diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
index 7c010f363fe1..498e29e407b4 100644
--- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
+++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
@@ -181,7 +181,6 @@ def test_async_software_pipeline(
         config={
             "tir.use_async_copy": 1,
             "tir.experimental_dma_bypass_cache": 1,
-            "tir.merge_async_commit_queue_scope": False,
         }
     ):
         # tvm.lower(schedule.mod["main"]).show()