Skip to content

Commit

Permalink
[Hexagon] Refactor to keep HexagonBuffer private to the device api (a…
Browse files Browse the repository at this point in the history
…pache#10910)

* No longer return HexagonBuffer from device api.

* fixup! No longer return HexagonBuffer from device api.
  • Loading branch information
csullivan authored and mehrdadh committed Apr 11, 2022
1 parent 30bfbd3 commit 3b46108
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 126 deletions.
71 changes: 1 addition & 70 deletions src/runtime/hexagon/hexagon/hexagon_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,72 +47,6 @@ namespace tvm {
namespace runtime {
namespace hexagon {

void HexagonLookupLinkedParam(TVMArgs args, TVMRetValue* rv) {
Module mod = args[0];
int64_t storage_id = args[1];
DLTensor* template_tensor = args[2];
Device dev = args[3];
auto lookup_linked_param = mod.GetFunction(::tvm::runtime::symbol::tvm_lookup_linked_param, true);
if (lookup_linked_param == nullptr) {
*rv = nullptr;
return;
}

TVMRetValue opaque_handle = lookup_linked_param(storage_id);
if (opaque_handle.type_code() == kTVMNullptr) {
*rv = nullptr;
return;
}

std::vector<int64_t> shape_vec{template_tensor->shape,
template_tensor->shape + template_tensor->ndim};

Optional<String> scope("global");
auto* param_buffer =
new HexagonBuffer(static_cast<void*>(opaque_handle), GetDataSize(*template_tensor), scope);
auto* container = new NDArray::Container(static_cast<void*>(param_buffer), shape_vec,
template_tensor->dtype, dev);
container->SetDeleter([](Object* container) {
// The NDArray::Container needs to be deleted
// along with the HexagonBuffer wrapper. However the
// buffer's data points to global const memory and
// so should not be deleted.
auto* ptr = static_cast<NDArray::Container*>(container);
delete static_cast<HexagonBuffer*>(ptr->dl_tensor.data);
delete ptr;
});
*rv = NDArray(GetObjectPtr<Object>(container));
}

PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
TVMValue ret_value;
int ret_type_code = kTVMNullptr;

TVMValue* arg_values = const_cast<TVMValue*>(args.values);
std::vector<std::pair<size_t, HexagonBuffer*>> buffer_args;
for (int i = 0; i < args.num_args; i++) {
if (args.type_codes[i] == kTVMDLTensorHandle) {
DLTensor* tensor = static_cast<DLTensor*>(arg_values[i].v_handle);
buffer_args.emplace_back(i, static_cast<HexagonBuffer*>(tensor->data));
HexagonBuffer* hexbuf = buffer_args.back().second;
tensor->data = hexbuf->GetPointer();
}
}
int ret = (*faddr)(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
args.num_args, &ret_value, &ret_type_code, nullptr);
ICHECK_EQ(ret, 0) << TVMGetLastError();

for (auto& arg : buffer_args) {
DLTensor* tensor = static_cast<DLTensor*>(arg_values[arg.first].v_handle);
tensor->data = arg.second;
}

if (ret_type_code != kTVMNullptr) {
*rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code);
}
});
}

#if defined(__hexagon__)
class HexagonTimerNode : public TimerNode {
Expand Down Expand Up @@ -165,12 +99,9 @@ void LogMessageImpl(const std::string& file, int lineno, const std::string& mess
}
} // namespace detail

TVM_REGISTER_GLOBAL("tvm.runtime.hexagon.lookup_linked_params")
.set_body(hexagon::HexagonLookupLinkedParam);

TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectPtr<Library> n = CreateDSOLibraryObject(args[0]);
*rv = CreateModuleFromLibrary(n, hexagon::WrapPackedFunc);
*rv = CreateModuleFromLibrary(n);
});
} // namespace runtime
} // namespace tvm
14 changes: 0 additions & 14 deletions src/runtime/hexagon/hexagon/hexagon_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,6 @@
} \
} while (0)

namespace tvm {
namespace runtime {
namespace hexagon {

/*! \brief Unpack HexagonBuffers in packed functions
* prior to invoking.
* \param faddr The function address.
* \param mptr The module pointer node.
* \return A packed function wrapping the requested function.
*/
PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& mptr);
} // namespace hexagon
} // namespace runtime
} // namespace tvm
inline bool IsHexagonDevice(DLDevice dev) {
return TVMDeviceExtType(dev.device_type) == kDLHexagon;
}
Expand Down
63 changes: 28 additions & 35 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* sh

if (ndim == 1) {
size_t nbytes = shape[0] * typesize;
return new HexagonBuffer(nbytes, alignment, mem_scope);
return AllocateHexagonBuffer(nbytes, alignment, mem_scope);
} else if (ndim == 2) {
size_t nallocs = shape[0];
size_t nbytes = shape[1] * typesize;
return new HexagonBuffer(nallocs, nbytes, alignment, mem_scope);
return AllocateHexagonBuffer(nallocs, nbytes, alignment, mem_scope);
} else {
LOG(FATAL) << "Hexagon Device API supports only 1d and 2d allocations, but received ndim = "
<< ndim;
Expand All @@ -94,16 +94,14 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, size_t nbytes, size_t align
if (alignment < kHexagonAllocAlignment) {
alignment = kHexagonAllocAlignment;
}
return new HexagonBuffer(nbytes, alignment, String("global"));
return AllocateHexagonBuffer(nbytes, alignment, String("global"));
}

void HexagonDeviceAPIv2::FreeDataSpace(Device dev, void* ptr) {
bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) ||
(DLDeviceType(dev.device_type) == kDLCPU);
CHECK(is_valid_device) << "dev.device_type: " << dev.device_type;
auto* hexbuf = static_cast<HexagonBuffer*>(ptr);
CHECK(hexbuf != nullptr);
delete hexbuf;
FreeHexagonBuffer(ptr);
}

// WorkSpace: runtime allocations for Hexagon
Expand All @@ -114,21 +112,14 @@ struct HexagonWorkspacePool : public WorkspacePool {

void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
auto* hexbuf = static_cast<HexagonBuffer*>(
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->AllocWorkspace(dev, size));

void* ptr = hexbuf->GetPointer();
workspace_allocations_.insert({ptr, hexbuf});
return ptr;
return dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->AllocWorkspace(dev, size);
}

void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) {
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
auto it = workspace_allocations_.find(data);
CHECK(it != workspace_allocations_.end())
CHECK(hexagon_buffer_map_.count(data) != 0)
<< "Attempt made to free unknown or already freed workspace allocation";
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->FreeWorkspace(dev, it->second);
workspace_allocations_.erase(it);
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->FreeWorkspace(dev, data);
}

void* HexagonDeviceAPIv2::AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape,
Expand All @@ -148,21 +139,26 @@ void HexagonDeviceAPIv2::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamH
CHECK_EQ(to->byte_offset, 0);
CHECK_EQ(GetDataSize(*from), GetDataSize(*to));

HexagonBuffer* hex_from_buf = static_cast<HexagonBuffer*>(from->data);
HexagonBuffer* hex_to_buf = static_cast<HexagonBuffer*>(to->data);
auto lookup_hexagon_buffer = [this](void* ptr) -> HexagonBuffer* {
auto it = this->hexagon_buffer_map_.find(ptr);
CHECK(it != this->hexagon_buffer_map_.end())
<< "Lookup failed for non-HexagonBuffer allocation, CopyDataFromTo can only copy data "
"from, to or between HexagonBuffers";
return it->second.get();
};

if (TVMDeviceExtType(from->device.device_type) == kDLHexagon &&
TVMDeviceExtType(to->device.device_type) == kDLHexagon) {
CHECK(hex_from_buf != nullptr);
CHECK(hex_to_buf != nullptr);
HexagonBuffer* hex_from_buf = lookup_hexagon_buffer(from->data);
HexagonBuffer* hex_to_buf = lookup_hexagon_buffer(to->data);
hex_to_buf->CopyFrom(*hex_from_buf, GetDataSize(*from));
} else if (from->device.device_type == kDLCPU &&
TVMDeviceExtType(to->device.device_type) == kDLHexagon) {
CHECK(hex_to_buf != nullptr);
HexagonBuffer* hex_to_buf = lookup_hexagon_buffer(to->data);
hex_to_buf->CopyFrom(from->data, GetDataSize(*from));
} else if (TVMDeviceExtType(from->device.device_type) == kDLHexagon &&
to->device.device_type == kDLCPU) {
CHECK(hex_from_buf != nullptr);
HexagonBuffer* hex_from_buf = lookup_hexagon_buffer(from->data);
hex_from_buf->CopyTo(to->data, GetDataSize(*to));
} else {
CHECK(false)
Expand All @@ -177,6 +173,14 @@ void HexagonDeviceAPIv2::CopyDataFromTo(const void* from, size_t from_offset, vo
memcpy(static_cast<char*>(to) + to_offset, static_cast<const char*>(from) + from_offset, size);
}

void HexagonDeviceAPIv2::FreeHexagonBuffer(void* ptr) {
auto it = hexagon_buffer_map_.find(ptr);
CHECK(it != hexagon_buffer_map_.end())
<< "Attempt made to free unknown or already freed dataspace allocation";
CHECK(it->second != nullptr);
hexagon_buffer_map_.erase(it);
}

TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
void* dst = args[0];
void* src = args[1];
Expand All @@ -187,8 +191,6 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM
*rv = static_cast<int32_t>(0);
});

std::map<void*, HexagonBuffer*> vtcmallocs;

TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
Expand All @@ -210,12 +212,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVM
type_hint.lanes = 1;

HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global();
HexagonBuffer* hexbuf = reinterpret_cast<HexagonBuffer*>(
hexapi->AllocVtcmWorkspace(dev, ndim, shape, type_hint, String(scope)));

void* ptr = hexbuf->GetPointer();
vtcmallocs[ptr] = hexbuf;
*rv = ptr;
*rv = hexapi->AllocVtcmWorkspace(dev, ndim, shape, type_hint, String(scope));
});

TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
Expand All @@ -224,17 +221,13 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd").set_body([](TVMArgs args, TVMR
std::string scope = args[2];
CHECK(scope.find("global.vtcm") != std::string::npos);
void* ptr = args[3];
CHECK(vtcmallocs.find(ptr) != vtcmallocs.end());

HexagonBuffer* hexbuf = vtcmallocs[ptr];
vtcmallocs.erase(ptr);

Device dev;
dev.device_type = static_cast<DLDeviceType>(device_type);
dev.device_id = device_id;

HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global();
hexapi->FreeVtcmWorkspace(dev, hexbuf);
hexapi->FreeVtcmWorkspace(dev, ptr);
*rv = static_cast<int32_t>(0);
});

Expand Down
27 changes: 22 additions & 5 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@
#include <tvm/runtime/device_api.h>

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "hexagon_buffer.h"

namespace tvm {
namespace runtime {
namespace hexagon {

class HexagonBuffer;

/*!
* \brief Hexagon Device API that is compiled and run on Hexagon.
*/
Expand Down Expand Up @@ -70,7 +72,7 @@ class HexagonDeviceAPIv2 final : public DeviceAPI {
*/
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;

//! Dereference workspace pool and erase from tracked workspace_allocations_.
//! Erase from tracked hexagon_buffer_map and free
void FreeWorkspace(Device dev, void* data) final;

/*!
Expand Down Expand Up @@ -125,8 +127,23 @@ class HexagonDeviceAPIv2 final : public DeviceAPI {
TVMStreamHandle stream) final;

private:
//! Lookup table for the HexagonBuffer managing a workspace allocation.
std::unordered_map<void*, HexagonBuffer*> workspace_allocations_;
/*! \brief Helper to allocate a HexagonBuffer and register the result
* in the owned buffer map.
* \return Raw data storage managed by the hexagon buffer
*/
template <typename... Args>
void* AllocateHexagonBuffer(Args&&... args) {
auto buf = std::make_unique<HexagonBuffer>(std::forward<Args>(args)...);
void* ptr = buf->GetPointer();
hexagon_buffer_map_.insert({ptr, std::move(buf)});
return ptr;
}
/*! \brief Helper to free a HexagonBuffer and unregister the result
* from the owned buffer map.
*/
void FreeHexagonBuffer(void* ptr);
//! Lookup table for the HexagonBuffer managing an allocation.
std::unordered_map<void*, std::unique_ptr<HexagonBuffer>> hexagon_buffer_map_;
};
} // namespace hexagon
} // namespace runtime
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/hexagon/rpc/hexagon/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,5 +288,5 @@ TVM_REGISTER_GLOBAL("tvm.hexagon.load_module")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) {
std::string soname = args[0];
tvm::ObjectPtr<tvm::runtime::Library> n = tvm::runtime::CreateDSOLibraryObject(soname);
*rv = CreateModuleFromLibrary(n, tvm::runtime::hexagon::WrapPackedFunc);
*rv = CreateModuleFromLibrary(n);
});
2 changes: 1 addition & 1 deletion src/runtime/hexagon/rpc/simulator/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,5 +321,5 @@ TVM_REGISTER_GLOBAL("tvm.hexagon.load_module")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) {
std::string soname = args[0];
tvm::ObjectPtr<tvm::runtime::Library> n = tvm::runtime::CreateDSOLibraryObject(soname);
*rv = CreateModuleFromLibrary(n, tvm::runtime::hexagon::WrapPackedFunc);
*rv = CreateModuleFromLibrary(n);
});

0 comments on commit 3b46108

Please sign in to comment.