Skip to content

Commit

Permalink
[Hexagon] Generalize builtin for Nd memory alloc with storage scope and add lowering for VTCM / Hexagon (apache#10558)
Browse files Browse the repository at this point in the history

* repurpose texture flatten for vtcm; TIR lowering correct

* clean up remaining code in texture flatten pass

* add Alloc and FreeTexture, but failing to run over rpc

* test passing with malloc in the device api

* cleanup

* fails in very reliable way with memory corruption

* working with non-HexagonBuffer vtcm alloc

* cleanup

* do not pass scope through mem_copy api

* [Hexagon] Resolve breakage in test_hexagon/test_cache_read_write

Breakage was caused by apache#9727, which
didn't account for the new `builtin::mem_copy()` when computing the
stack size in `StackSizeChecker`.

* use HexagonBuffer in Alloc and Free packed funcs

* Added comment indicating need for StackSizeChecker::MakeMemCopy.

* add AllocVtcmWorkspace and FreeVtcmWorkspace

* cleanup

* Updated unittests to run all contrib/test_hexagon at CI.

* create separate vtcm alloc lowering pass and transform

* reset texture_flatten.cc

* comments

* CI bump

* Fix lint formatting error.

* Updated fix to remove StackSizeChecker entirely.

* pass device and type to device api

* Bugfix, verify the precheck's allocations, not its own.

* Bugfix, pass context information to the precheck.

* pass order and shape to device api

* working

* fix up types and arg passing

* pass scope to device api

* common builtin for texture / vtcm

* add scope to FreeNd api

* format and lint

* fixed missed format error

* restart ci

* fix test random value issue + code review feedback

* fix test hang

* restructure lower vtcm pass per code review feedback (option a)

* format error

* global.vtcm + tvm_stack_make_shape

Co-authored-by: Eric Lunderberg <[email protected]>
  • Loading branch information
2 people authored and pfk-beta committed Apr 11, 2022
1 parent f2c5ec3 commit ee516a8
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 35 deletions.
4 changes: 2 additions & 2 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -644,9 +644,9 @@ TVM_DLL const Op& vectorcombine();
*/
TVM_DLL const Op& atomic_add();
/*!
* \brief Create a texture 2d memory allocation
* \brief Create an Nd memory allocation with storage scope
*/
TVM_DLL const Op& texture2d_alloca();
TVM_DLL const Op& nd_mem_alloc_with_scope();

/*!
* \brief Store to texture 2d memory
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,13 @@ TVM_DLL Pass FlattenBuffer();
*/
TVM_DLL Pass TextureFlatten();

/*
* \brief Lower VTCM allocations
*
* \return The Pass
*/
TVM_DLL Pass LowerVtcmAlloc();

/*!
* \brief Implements a Common Subexpression Elimination (CSE) for TIR
* which introduces let-in bindings for duplicated sub-expressions.
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::TextureFlatten());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
pass_list.push_back(tir::transform::LowerVtcmAlloc());
pass_list.push_back(tir::transform::LowerCrossThreadReduction());
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/hexagon/hexagon/hexagon_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
DLTensor* tensor = static_cast<DLTensor*>(arg_values[i].v_handle);
buffer_args.emplace_back(i, static_cast<HexagonBuffer*>(tensor->data));
// Assumes a single contiguous allocation
// TODO(Straw): Enable discontiguous allocation after RFC 39 lands
// TODO(Straw): Enable discontiguous allocation
tensor->data = buffer_args.back().second->GetPointer()[0];
}
}
Expand Down
72 changes: 70 additions & 2 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* sh
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;

// Forcing contiguous allocation, for now
// TODO(Straw): Enable discontiguous allocation after RFC 39 lands
// TODO(Straw): Enable discontiguous allocation
size_t nallocs = 1;
size_t nbytes = 1;
for (int i = 0; i < ndim; ++i) {
Expand Down Expand Up @@ -107,7 +107,7 @@ void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType typ
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->AllocWorkspace(dev, size));

// Assumes a single contiguous allocation
// TODO(Straw): Enable discontiguous allocation after RFC 39 lands
// TODO(Straw): Enable discontiguous allocation
void* ptr = hexbuf->GetPointer()[0];
workspace_allocations_.insert({ptr, hexbuf});
return ptr;
Expand All @@ -122,6 +122,20 @@ void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) {
workspace_allocations_.erase(it);
}

// Allocate an Nd workspace in Hexagon VTCM memory.
// Thin wrapper over AllocDataSpace; mem_scope (e.g. "global.vtcm") selects the
// memory pool. Returns the same opaque pointer AllocDataSpace produces —
// presumably a HexagonBuffer* (see the AllocNd packed func below); confirm.
void* HexagonDeviceAPIv2::AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape,
                                             DLDataType dtype, Optional<String> mem_scope) {
  CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
  // Forcing contiguous allocation, for now
  // TODO(Straw): Enable discontiguous allocation
  CHECK_EQ(ndim, 1);  // only rank-1 (single contiguous region) supported today
  return AllocDataSpace(dev, ndim, shape, dtype, mem_scope);
}

// Free a VTCM workspace previously returned by AllocVtcmWorkspace.
// `ptr` must be exactly the pointer AllocVtcmWorkspace returned; it is handed
// straight to FreeDataSpace.
void HexagonDeviceAPIv2::FreeVtcmWorkspace(Device dev, void* ptr) {
  CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
  FreeDataSpace(dev, ptr);
}

void HexagonDeviceAPIv2::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
CHECK_EQ(from->byte_offset, 0);
CHECK_EQ(to->byte_offset, 0);
Expand Down Expand Up @@ -166,6 +180,60 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM
*rv = static_cast<int32_t>(0);
});

// Live VTCM allocations: raw data pointer -> owning HexagonBuffer.
// NOTE(review): file-scope mutable map with no locking — assumes alloc/free
// happen on a single thread; confirm before calling concurrently.
std::map<void*, HexagonBuffer*> vtcmallocs;

// Packed-func endpoint invoked by lowered TIR for Nd VTCM allocation.
// Argument layout: (device_type, device_id, dtype_code, dtype_bits, scope,
// ndim, shape handle). Returns the raw data pointer of the allocation.
TVM_REGISTER_GLOBAL("device_api.hexagon.AllocNd").set_body([](TVMArgs args, TVMRetValue* rv) {
  int32_t device_type = args[0];
  int32_t device_id = args[1];
  int32_t dtype_code_hint = args[2];
  int32_t dtype_bits_hint = args[3];
  std::string scope = args[4];
  // Only "global.vtcm"-scoped allocations may be routed here.
  CHECK(scope.find("global.vtcm") != std::string::npos);
  int64_t ndim = args[5];
  // Forcing contiguous allocation, for now
  // TODO(Straw): Enable discontiguous allocation
  CHECK_EQ(ndim, 1);
  // args[6] is the tvm_stack_make_shape buffer: an int64_t extent per dim.
  int64_t* shape = static_cast<int64_t*>(static_cast<void*>(args[6]));

  Device dev;
  dev.device_type = static_cast<DLDeviceType>(device_type);
  dev.device_id = device_id;

  DLDataType type_hint;
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

  HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global();
  // NOTE(review): assumes AllocVtcmWorkspace's void* is really a
  // HexagonBuffer* (consistent with WrapPackedFunc's cast of tensor->data);
  // verify against AllocDataSpace's implementation.
  HexagonBuffer* hexbuf = reinterpret_cast<HexagonBuffer*>(
      hexapi->AllocVtcmWorkspace(dev, ndim, shape, type_hint, String(scope)));

  // Assumes a single contiguous allocation
  // TODO(Straw): Enable discontiguous allocation
  void* ptr = hexbuf->GetPointer()[0];
  vtcmallocs[ptr] = hexbuf;  // remember owner so FreeNd can recover it
  *rv = ptr;
});

// Packed-func endpoint invoked by lowered TIR to free an Nd VTCM allocation.
// Argument layout: (device_type, device_id, scope, data pointer). Returns 0.
TVM_REGISTER_GLOBAL("device_api.hexagon.FreeNd").set_body([](TVMArgs args, TVMRetValue* rv) {
  int32_t device_type = args[0];
  int32_t device_id = args[1];
  std::string scope = args[2];
  // Only VTCM-scoped frees may be routed here.
  CHECK(scope.find("vtcm") != std::string::npos);
  void* ptr = args[3];
  auto it = vtcmallocs.find(ptr);
  CHECK(it != vtcmallocs.end());

  HexagonBuffer* hexbuf = it->second;
  // Remove the tracking entry before freeing: leaving it behind would keep a
  // dangling HexagonBuffer* in the map, letting a later FreeNd on the same
  // (recycled) pointer pass the CHECK and double-free.
  vtcmallocs.erase(it);

  Device dev;
  dev.device_type = static_cast<DLDeviceType>(device_type);
  dev.device_id = device_id;

  HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global();
  hexapi->FreeVtcmWorkspace(dev, hexbuf);
  *rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.v2").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = HexagonDeviceAPIv2::Global();
*rv = static_cast<void*>(ptr);
Expand Down
17 changes: 17 additions & 0 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@

#include <tvm/runtime/device_api.h>

#include <map>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -82,6 +85,20 @@ class HexagonDeviceAPIv2 final : public DeviceAPI {
void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
Optional<String> mem_scope) final;

/*!
* \brief Allocate an Nd VTCM workspace.
* \param dev The device to perform the operation.
* \param ndim The number of dimensions of allocated tensor.
* \param shape The shape of allocated tensor.
* \param dtype The element type.
* \return The allocated HexagonBuffer pointer.
*/
void* AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
Optional<String> mem_scope);

//! \brief Free the allocated Nd VTCM workspace.
void FreeVtcmWorkspace(Device dev, void* ptr);

/*!
* \brief Copy data from one storage to another.
* \note This API is designed to support special memory with shape dependent layout.
Expand Down
30 changes: 19 additions & 11 deletions src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,13 +438,19 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
initialized_ = true;
}

TVM_REGISTER_GLOBAL("device_api.opencl.AllocTexture").set_body([](TVMArgs args, TVMRetValue* rv) {
int device_type = args[0];
int device_id = args[1];
int width = args[2];
int height = args[3];
int dtype_code_hint = args[4];
int dtype_bits_hint = args[5];
TVM_REGISTER_GLOBAL("device_api.opencl.AllocNd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
int32_t dtype_code_hint = args[2];
int32_t dtype_bits_hint = args[3];
std::string scope = args[4];
CHECK(scope.find("texture") != std::string::npos);
int64_t ndim = args[5];
CHECK_EQ(ndim, 2);
int64_t* shape = static_cast<int64_t*>(static_cast<void*>(args[6]));
int64_t width = shape[0];
int64_t height = shape[1];

Device dev;
dev.device_type = static_cast<DLDeviceType>(device_type);
dev.device_id = device_id;
Expand All @@ -459,10 +465,12 @@ TVM_REGISTER_GLOBAL("device_api.opencl.AllocTexture").set_body([](TVMArgs args,
type_hint);
});

TVM_REGISTER_GLOBAL("device_api.opencl.FreeTexture").set_body([](TVMArgs args, TVMRetValue* rv) {
int device_type = args[0];
int device_id = args[1];
void* data = args[2];
TVM_REGISTER_GLOBAL("device_api.opencl.FreeNd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
std::string scope = args[2];
CHECK(scope.find("texture") != std::string::npos);
void* data = args[3];
OpenCLWorkspace* ptr = OpenCLWorkspace::Global();
Device dev;
dev.device_type = static_cast<DLDeviceType>(device_type);
Expand Down
2 changes: 1 addition & 1 deletion src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine)
TIR_DEFINE_BUILTIN_FUNC(atomic_add)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(texture2d_alloca)
TIR_DEFINE_BUILTIN_FUNC(nd_mem_alloc_with_scope)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(texture2d_store)
Expand Down
34 changes: 21 additions & 13 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ class BuiltinLower : public StmtExprMutator {

Stmt VisitStmt_(const LetStmtNode* op) final {
if (const CallNode* call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::texture2d_alloca())) {
return StmtExprMutator::VisitStmt(MakeTextureAlloc(op, call));
if (call->op.same_as(builtin::nd_mem_alloc_with_scope())) {
return StmtExprMutator::VisitStmt(MakeNdMemAllocWithScope(op, call));
}
}
return StmtExprMutator::VisitStmt_(op);
Expand Down Expand Up @@ -459,32 +459,40 @@ class BuiltinLower : public StmtExprMutator {
return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args);
}

Stmt MakeTextureAlloc(const LetStmtNode* let, const CallNode* call) {
Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) {
ICHECK(device_type_.defined()) << "Unknown device type in current IR";
ICHECK(device_id_.defined()) << "Unknown device id in current IR";
Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));

Stmt body = SeqStmt(
{IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error),
let->body});

DataType dtype =
let->var->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()->dtype;

std::string fdevapi_prefix = "device_api.";
fdevapi_prefix += runtime::DeviceName(device_type_.as<IntImmNode>()->value);
Call call_packed =
Call(let->var.dtype(), builtin::tvm_call_packed(),
{StringImm(fdevapi_prefix + ".AllocTexture"), cast(DataType::Int(32), device_type_),
cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), call->args[0]),
cast(DataType::UInt(64), call->args[1]), IntImm(DataType::Int(32), dtype.code()),
IntImm(DataType::Int(32), dtype.bits())});

Array<PrimExpr> args = {
StringImm(fdevapi_prefix + ".AllocNd"),
device_type_,
device_id_,
IntImm(DataType::Int(32), dtype.code()),
IntImm(DataType::Int(32), dtype.bits()),
};

for (size_t i = 0; i < call->args.size(); ++i) {
args.push_back(call->args[i]);
}

Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args);
Stmt alloca = LetStmt(let->var, call_packed, body);

Call free_op =
Call(DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(fdevapi_prefix + ".FreeTexture"), cast(DataType::Int(32), device_type_),
cast(DataType::Int(32), device_id_), let->var});
PrimExpr storage_scope = call->args[0];
Call free_op = Call(
DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(fdevapi_prefix + ".FreeNd"), device_type_, device_id_, storage_scope, let->var});

Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
body = SeqStmt({alloca, free_stmt});
Expand Down
80 changes: 80 additions & 0 deletions src/tir/transforms/lower_vtcm_alloc.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/transform.h>

#include "../../arith/ir_visitor_with_analyzer.h"

namespace tvm {
namespace tir {

/*!
 * \brief Check whether a storage scope string denotes Hexagon VTCM memory.
 * \param scope The storage scope (e.g. "global.vtcm").
 * \return true if the scope contains "global.vtcm".
 * \note Takes the scope by const reference; the original copied the string on
 *       every call for a read-only substring search.
 */
inline bool IsVtcmStorage(const std::string& scope) {
  return scope.find("global.vtcm") != std::string::npos;
}

/*!
 * \brief Rewrites VTCM-scoped Allocate nodes into nd_mem_alloc_with_scope
 *        builtin calls bound via LetStmt, for later lowering to the device API.
 */
class VtcmAllocator : public StmtExprMutator {
 public:
  using StmtExprMutator::VisitStmt_;
  VtcmAllocator() {}

  Stmt VisitStmt_(const AllocateNode* op) final {
    std::string scope = GetStorageScope(op->buffer_var);
    if (!IsVtcmStorage(scope)) {
      // Not VTCM: leave the allocation alone and recurse normally.
      return StmtExprMutator::VisitStmt_(op);
    }
    Stmt rewritten_body = this->VisitStmt(op->body);
    // Builtin call args: (scope, ndim, shape handle built on the stack).
    Array<PrimExpr> alloc_args;
    alloc_args.push_back(StringImm(scope));
    alloc_args.push_back(IntImm(DataType::Int(64), op->extents.size()));
    alloc_args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), op->extents));
    Call alloc_call(op->buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), alloc_args);
    return LetStmt(op->buffer_var, alloc_call, rewritten_body);
  }

 protected:
  /*! \brief Read the storage scope off the buffer var's pointer type. */
  std::string GetStorageScope(const Var& var) {
    auto* ptr = var->type_annotation.as<PointerTypeNode>();
    ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType";
    return ptr->storage_scope;
  }
};

/*! \brief Apply the VTCM allocation rewrite to a PrimFunc's body in place. */
PrimFunc LowerVtcmAlloc(PrimFunc func) {
  auto* writable = func.CopyOnWrite();
  writable->body = VtcmAllocator()(std::move(writable->body));
  return func;
}

namespace transform {

/*! \brief Create the tir.LowerVtcmAlloc PrimFunc pass. */
Pass LowerVtcmAlloc() {
  auto rewrite = [](PrimFunc f, IRModule m, PassContext ctx) {
    return LowerVtcmAlloc(std::move(f));
  };
  return CreatePrimFuncPass(rewrite, 0, "tir.LowerVtcmAlloc", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerVtcmAlloc").set_body_typed(LowerVtcmAlloc);

} // namespace transform

} // namespace tir
} // namespace tvm
9 changes: 7 additions & 2 deletions src/tir/transforms/texture_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,13 @@ class TextureFlattener : public TextureLoweringBase {
size_t axis = DefaultTextureLayoutSeparator(op->bounds.size(), storage_scope);
auto texture =
ApplyTexture2DFlattening<PrimExpr>(ShapeFromRange{op->bounds}, op->bounds.size(), axis);
Array<PrimExpr> args = {texture.width, texture.height};
stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::texture2d_alloca(), args), body);
Array<PrimExpr> args;
args.push_back(StringImm(storage_scope));
args.push_back(IntImm(DataType::Int(64), 2)); // 2d
args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(),
{texture.width, texture.height}));
stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args),
body);
}

return stmt;
Expand Down
Loading

0 comments on commit ee516a8

Please sign in to comment.