Skip to content

Commit

Permalink
[Hexagon] Add concept of DMA groups (#14254)
Browse files Browse the repository at this point in the history
* [Hexagon] Add concept of groups to QueuedRingBuffer

* fix failing RingBuffer unit test

* add StartGroup and EndGroup builtins and lowering

* use max queue number to increase dma copy perf

* increase max dma queues to fix test failures

* use DMA groups in LowerAsyncDMA pass
Co-authored-by: Noah Verke <[email protected]>
Co-authored-by: Eric Lunderberg <[email protected]>

* elide merge_async_commit_queue_scope

* LowerAsyncDMA bug fix + disallow non-contig copy; comments, tests pass

* format and lint

* add comments to Hex User DMA header

* use unsigned queue ID; fix test fails

* format and lint

* use dma_copy_dltensor in TIR tenor intrin; fix test fails

* address feedback: comments, types, names
  • Loading branch information
adstraw authored Mar 20, 2023
1 parent fc2a9e5 commit c6c89c3
Show file tree
Hide file tree
Showing 19 changed files with 488 additions and 199 deletions.
3 changes: 1 addition & 2 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,9 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Postproc DisallowDynamicLoop();
/*!
* \brief Create a postprocessor that checks if all async mem copies are not strided.
* \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope.
* \return The postprocessor created
*/
TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true);
TVM_DLL static Postproc DisallowAsyncStridedMemCopy();
/*!
* \brief Create a postprocessor that rewrites the cooperative fetch annotation to
* actual vectorized cooperative fetching in loop bindings.
Expand Down
37 changes: 36 additions & 1 deletion include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,14 +727,49 @@ TVM_DLL const Op& texture2d_load();

/*!
* \brief Initiate a non-blocking DMA copy from source to destination
*
* The copy is launched immediately.
*
* If a `dma_start_group()` call is active, the copy will be added
* to the current group for tracking of in-flight group counts.
*
* If no `dma_start_group()` call is active, the copy will be tracked
* individually i.e. as a group with size 1.
*/
TVM_DLL const Op& dma_copy();

/*!
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
* \brief Wait until the number of DMA groups in flight is less than
* or equal to some maximum
*
* Calling `dma_wait()` while a group is active is unsupported.
*/
TVM_DLL const Op& dma_wait();

/*!
* \brief Start a group of DMA copies
*
* Any call to `dma_copy()` that occurs after `dma_start_group()` will
* be added to the current group for tracking of in-flight group counts.
*
* Only one DMA group may be active at a given time. Calling
* `dma_start_group()` while a group is active is unsupported.
*/
TVM_DLL const Op& dma_start_group();

/*!
* \brief End a group of DMA copies
*
* Track all calls to `dma_copy()` that occurred since the preceding
* `dma_start_group()` as a single group in-flight.
*
* Calling `dma_end_group()` without an active group is unsupported.
*
* Note: A group of DMA calls may be empty, and will still contribute
* to the count of in-flight groups used by `dma_wait()`.
*/
TVM_DLL const Op& dma_end_group();

/*!
* \brief Provide a true statement that can be used for simplifications
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,9 @@

@register_object("meta_schedule.DisallowAsyncStridedMemCopy")
class DisallowAsyncStridedMemCopy(Postproc):
"""A postprocessor that disallows schedules that use async strided mem copies.
"""A postprocessor that disallows schedules that use async strided mem copies."""

Parameters
----------
merge_async_commit_queue_scope : bool
Whether or not to merge the async commit queue scope.
"""

def __init__(self, merge_async_commit_queue_scope=True) -> None:
def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocDisallowAsyncStridedMemCopy, # type: ignore # pylint: disable=no-member
merge_async_commit_queue_scope,
)
35 changes: 21 additions & 14 deletions python/tvm/tir/tensor_intrin/hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,27 @@ def sync_dma_load_impl(a: T.handle, c: T.handle) -> None:
T.writes(C[0:size])
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_copy",
-1, # Use QueueId of -1 to not interfere with async copies.
T.address_of(C[0], dtype="handle"),
T.address_of(A[0], dtype="handle"),
size,
0, # Do not use experimental bypass mode.
dtype="int32",
)
)
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_wait",
-1,
0, # Wait for the sync queue (-1) to have 0 messages.
"device_api.hexagon.dma_copy_dltensor",
T.tvm_stack_make_array(
T.address_of(C[0], dtype="handle"),
T.tvm_stack_make_shape(size, dtype="handle"),
0,
1,
C.dtype,
0,
dtype="handle",
),
T.tvm_stack_make_array(
T.address_of(A[0], dtype="handle"),
T.tvm_stack_make_shape(size, dtype="handle"),
0,
1,
A.dtype,
0,
dtype="handle",
),
T.cast(size, dtype="int"),
False, # Do not use experimental bypass mode.
dtype="int32",
)
)
Expand Down
1 change: 0 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::VectorizeLoop(true));
pass_list.push_back(tir::transform::StorageRewrite());
transform::PassContext pass_ctx = transform::PassContext::Current();
pass_ctx->config.Set("tir.merge_async_commit_queue_scope",
Bool(merge_async_commit_queue_scope));
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
Expand All @@ -169,15 +166,12 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
return Postproc(n);
}

bool merge_async_commit_queue_scope = true;

static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
};

Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) {
Postproc Postproc::DisallowAsyncStridedMemCopy() {
ObjectPtr<DisallowAsyncStridedMemCopyNode> n = make_object<DisallowAsyncStridedMemCopyNode>();
n->merge_async_commit_queue_scope = merge_async_commit_queue_scope;
return Postproc(n);
}

Expand Down
19 changes: 16 additions & 3 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,10 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor")
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
int queue_id = args[0];
uint32_t queue_id = static_cast<int>(args[0]);
void* dst = args[1];
void* src = args[2];
int size = args[3];
uint32_t size = static_cast<int>(args[3]);
ICHECK(size > 0);
bool bypass_cache = args[4];

Expand All @@ -226,13 +226,26 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVM
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) {
int queue_id = args[0];
uint32_t queue_id = static_cast<int>(args[0]);
int inflight = args[1];
ICHECK(inflight >= 0);
HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group")
.set_body([](TVMArgs args, TVMRetValue* rv) {
uint32_t queue_id = static_cast<int>(args[0]);
HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group").set_body([](TVMArgs args, TVMRetValue* rv) {
uint32_t queue_id = static_cast<int>(args[0]);
HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
Expand Down
14 changes: 8 additions & 6 deletions src/runtime/hexagon/hexagon_user_dma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ unsigned int HexagonUserDMA::Init() {
return status;
}

int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache) {
int HexagonUserDMA::Copy(uint32_t queue_id, void* dst, void* src, uint32_t length,
bool bypass_cache) {
// length limited to 24 bits
if (length > DESC_LENGTH_MASK) {
return DMA_FAILURE;
Expand Down Expand Up @@ -103,15 +104,15 @@ int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bo
return DMA_SUCCESS;
}

void HexagonUserDMA::Wait(int queue_id, uint32_t max_dmas_in_flight) {
void HexagonUserDMA::Wait(uint32_t queue_id, uint32_t max_dmas_in_flight) {
// wait (forever) until max DMAs in flight <= actual DMAs in flight
while (DMAsInFlight(queue_id) > max_dmas_in_flight) {
while (DMAGroupsInFlight(queue_id) > max_dmas_in_flight) {
}
}

uint32_t HexagonUserDMA::Poll(int queue_id) { return DMAsInFlight(queue_id); }
uint32_t HexagonUserDMA::Poll(uint32_t queue_id) { return DMAGroupsInFlight(queue_id); }

uint32_t HexagonUserDMA::DMAsInFlight(int queue_id) {
uint32_t HexagonUserDMA::DMAGroupsInFlight(uint32_t queue_id) {
dmpoll(); // update DMA engine status
return descriptors_->InFlight(queue_id);
}
Expand All @@ -125,7 +126,8 @@ HexagonUserDMA::HexagonUserDMA() {
unsigned int done = dma_desc_get_done(dma_desc);
return (done != DESC_DONE_COMPLETE);
};
descriptors_ = new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_DESCRIPTORS, desc_in_flight);
descriptors_ =
new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_QUEUES, MAX_DMA_DESCRIPTORS, desc_in_flight);
}

HexagonUserDMA::~HexagonUserDMA() {
Expand Down
31 changes: 25 additions & 6 deletions src/runtime/hexagon/hexagon_user_dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ namespace hexagon {
#define DMA_FAILURE -1
#define DMA_RETRY 1
#define MAX_DMA_DESCRIPTORS 100
#define SYNC_DMA_QUEUE -1
#define MAX_DMA_QUEUES 10
#define SYNC_DMA_QUEUE MAX_DMA_QUEUES - 1

class HexagonUserDMA {
public:
Expand All @@ -47,32 +48,50 @@ class HexagonUserDMA {

/*!
* \brief Initiate DMA to copy memory from source to destination address
* \param queue_id The virtual DMA queue
* \param dst Destination address
* \param src Source address
* \param length Length in bytes to copy
* \returns Status: DMA_SUCCESS or DMA_FAILURE
*/
int Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);
int Copy(uint32_t queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);

/*!
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
* \param queue_id The virtual DMA queue
* \param max_dmas_in_flight Maximum number of DMAs allowed to be in flight
* to satisfy the `Wait` e.g. use `Wait(0)` to wait on "all" outstanding DMAs to complete
*/
void Wait(int queue_id, uint32_t max_dmas_in_flight);
void Wait(uint32_t queue_id, uint32_t max_dmas_in_flight);

/*!
* \brief Poll the number of DMAs in flight
* \param queue_id The virtual DMA queue
* \returns Number of DMAs in flight
*/
uint32_t Poll(int queue_id);
uint32_t Poll(uint32_t queue_id);

/*!
* \brief Start a group of DMA copies
* \param queue_id The virtual DMA queue
*/
void StartGroup(uint32_t queue_id) { descriptors_->StartGroup(queue_id); }

/*!
* \brief End a group of DMA copies
* \param queue_id The virtual DMA queue
*/
void EndGroup(uint32_t queue_id) { descriptors_->EndGroup(queue_id); }

private:
//! \brief Initializes the Hexagon User DMA engine
unsigned int Init();

//! \brief Calculates and returns the number of DMAs in flight
uint32_t DMAsInFlight(int queue_id);
/*!
* \brief Calculates and returns the number of DMAs in flight
* \param queue_id The virtual DMA queue
*/
uint32_t DMAGroupsInFlight(uint32_t queue_id);

//! \brief Tracks whether the very first DMA has been executed
bool first_dma_ = true;
Expand Down
Loading

0 comments on commit c6c89c3

Please sign in to comment.