diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 24d139983132b..c42d44fd97274 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -644,9 +644,9 @@ TVM_DLL const Op& vectorcombine(); */ TVM_DLL const Op& atomic_add(); /*! - * \brief Create a texture 2d memory allocation + * \brief Create an Nd memory allocation with storage scope */ -TVM_DLL const Op& texture2d_alloca(); +TVM_DLL const Op& nd_mem_alloc_with_scope(); /*! * \brief Store to texture 2d memory diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 4330c4f7c64a8..24c3cfa78f721 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -459,6 +459,13 @@ TVM_DLL Pass FlattenBuffer(); */ TVM_DLL Pass TextureFlatten(); +/*! + * \brief Lower VTCM allocations + * + * \return The Pass + */ +TVM_DLL Pass LowerVtcmAlloc(); + /*! + * \brief Implements a Common Subexpression Elimination (CSE) for TIR + * which introduces let-in bindings for duplicated sub-expressions. 
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 54126aaa5119e..2d9cdd8912ff6 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -242,6 +242,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InjectPrefetch()); pass_list.push_back(tir::transform::TextureFlatten()); pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + pass_list.push_back(tir::transform::LowerVtcmAlloc()); pass_list.push_back(tir::transform::LowerCrossThreadReduction()); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); diff --git a/src/runtime/hexagon/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon/hexagon_common.cc index 7a94e8c4f9f8c..414def9dee18b 100644 --- a/src/runtime/hexagon/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon/hexagon_common.cc @@ -91,7 +91,7 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& DLTensor* tensor = static_cast(arg_values[i].v_handle); buffer_args.emplace_back(i, static_cast(tensor->data)); // Assumes a single contiguous allocation - // TODO(Straw): Enable discontiguous allocation after RFC 39 lands + // TODO(Straw): Enable discontiguous allocation tensor->data = buffer_args.back().second->GetPointer()[0]; } } diff --git a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc index b6686807ef395..27619eac12dce 100644 --- a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc +++ b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc @@ -62,7 +62,7 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* sh CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; // Forcing contiguous allocation, for now - // TODO(Straw): Enable discontiguous allocation after RFC 39 lands + // TODO(Straw): Enable discontiguous 
allocation size_t nallocs = 1; size_t nbytes = 1; for (int i = 0; i < ndim; ++i) { @@ -107,7 +107,7 @@ void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType typ dmlc::ThreadLocalStore::Get()->AllocWorkspace(dev, size)); // Assumes a single contiguous allocation - // TODO(Straw): Enable discontiguous allocation after RFC 39 lands + // TODO(Straw): Enable discontiguous allocation void* ptr = hexbuf->GetPointer()[0]; workspace_allocations_.insert({ptr, hexbuf}); return ptr; @@ -122,6 +122,20 @@ void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) { workspace_allocations_.erase(it); } +void* HexagonDeviceAPIv2::AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape, + DLDataType dtype, Optional mem_scope) { + CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; + // Forcing contiguous allocation, for now + // TODO(Straw): Enable discontiguous allocation + CHECK_EQ(ndim, 1); + return AllocDataSpace(dev, ndim, shape, dtype, mem_scope); +} + +void HexagonDeviceAPIv2::FreeVtcmWorkspace(Device dev, void* ptr) { + CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; + FreeDataSpace(dev, ptr); +} + void HexagonDeviceAPIv2::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { CHECK_EQ(from->byte_offset, 0); CHECK_EQ(to->byte_offset, 0); @@ -166,6 +180,60 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM *rv = static_cast(0); }); +std::map vtcmallocs; + +TVM_REGISTER_GLOBAL("device_api.hexagon.AllocNd").set_body([](TVMArgs args, TVMRetValue* rv) { + int32_t device_type = args[0]; + int32_t device_id = args[1]; + int32_t dtype_code_hint = args[2]; + int32_t dtype_bits_hint = args[3]; + std::string scope = args[4]; + CHECK(scope.find("global.vtcm") != std::string::npos); + int64_t ndim = args[5]; + // Forcing contiguous allocation, for now + // TODO(Straw): Enable discontiguous allocation + 
CHECK_EQ(ndim, 1); + int64_t* shape = static_cast(static_cast(args[6])); + + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + + DLDataType type_hint; + type_hint.code = static_cast(dtype_code_hint); + type_hint.bits = static_cast(dtype_bits_hint); + type_hint.lanes = 1; + + HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global(); + HexagonBuffer* hexbuf = reinterpret_cast( + hexapi->AllocVtcmWorkspace(dev, ndim, shape, type_hint, String(scope))); + + // Assumes a single contiguous allocation + // TODO(Straw): Enable discontiguous allocation + void* ptr = hexbuf->GetPointer()[0]; + vtcmallocs[ptr] = hexbuf; + *rv = ptr; +}); + +TVM_REGISTER_GLOBAL("device_api.hexagon.FreeNd").set_body([](TVMArgs args, TVMRetValue* rv) { + int32_t device_type = args[0]; + int32_t device_id = args[1]; + std::string scope = args[2]; + CHECK(scope.find("vtcm") != std::string::npos); + void* ptr = args[3]; + CHECK(vtcmallocs.find(ptr) != vtcmallocs.end()); + + HexagonBuffer* hexbuf = vtcmallocs[ptr]; + + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + + HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global(); + hexapi->FreeVtcmWorkspace(dev, hexbuf); + *rv = static_cast(0); +}); + TVM_REGISTER_GLOBAL("device_api.hexagon.v2").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = HexagonDeviceAPIv2::Global(); *rv = static_cast(ptr); diff --git a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h index 3d866307f17c1..9e39fc0b0f977 100644 --- a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h +++ b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h @@ -22,7 +22,10 @@ #include +#include +#include #include +#include namespace tvm { namespace runtime { @@ -82,6 +85,20 @@ class HexagonDeviceAPIv2 final : public DeviceAPI { void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, Optional mem_scope) final; + /*! 
+ * \brief Allocate an Nd VTCM workspace. + * \param dev The device to perform the operation. + * \param ndim The number of dimensions of allocated tensor. + * \param shape The shape of allocated tensor. + * \param dtype The element type. + * \return The allocated HexagonBuffer pointer. + */ + void* AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, + Optional mem_scope); + + //! \brief Free the allocated Nd VTCM workspace. + void FreeVtcmWorkspace(Device dev, void* ptr); + /*! * \brief Copy data from one storage to another. * \note This API is designed to support special memory with shape dependent layout. diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 66561dcdf279f..36bb156c8e9f8 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -438,13 +438,19 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic initialized_ = true; } -TVM_REGISTER_GLOBAL("device_api.opencl.AllocTexture").set_body([](TVMArgs args, TVMRetValue* rv) { - int device_type = args[0]; - int device_id = args[1]; - int width = args[2]; - int height = args[3]; - int dtype_code_hint = args[4]; - int dtype_bits_hint = args[5]; +TVM_REGISTER_GLOBAL("device_api.opencl.AllocNd").set_body([](TVMArgs args, TVMRetValue* rv) { + int32_t device_type = args[0]; + int32_t device_id = args[1]; + int32_t dtype_code_hint = args[2]; + int32_t dtype_bits_hint = args[3]; + std::string scope = args[4]; + CHECK(scope.find("texture") != std::string::npos); + int64_t ndim = args[5]; + CHECK_EQ(ndim, 2); + int64_t* shape = static_cast(static_cast(args[6])); + int64_t width = shape[0]; + int64_t height = shape[1]; + Device dev; dev.device_type = static_cast(device_type); dev.device_id = device_id; @@ -459,10 +465,12 @@ TVM_REGISTER_GLOBAL("device_api.opencl.AllocTexture").set_body([](TVMArgs args, type_hint); }); 
-TVM_REGISTER_GLOBAL("device_api.opencl.FreeTexture").set_body([](TVMArgs args, TVMRetValue* rv) { - int device_type = args[0]; - int device_id = args[1]; - void* data = args[2]; +TVM_REGISTER_GLOBAL("device_api.opencl.FreeNd").set_body([](TVMArgs args, TVMRetValue* rv) { + int32_t device_type = args[0]; + int32_t device_id = args[1]; + std::string scope = args[2]; + CHECK(scope.find("texture") != std::string::npos); + void* data = args[3]; OpenCLWorkspace* ptr = OpenCLWorkspace::Global(); Device dev; dev.device_type = static_cast(device_type); diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 7d8a997f52e9c..465428e1e8801 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -256,7 +256,7 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine) TIR_DEFINE_BUILTIN_FUNC(atomic_add) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(texture2d_alloca) +TIR_DEFINE_BUILTIN_FUNC(nd_mem_alloc_with_scope) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_BUILTIN_FUNC(texture2d_store) diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index e5c45a5a5f04b..8b37a116beeac 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -158,8 +158,8 @@ class BuiltinLower : public StmtExprMutator { Stmt VisitStmt_(const LetStmtNode* op) final { if (const CallNode* call = op->value.as()) { - if (call->op.same_as(builtin::texture2d_alloca())) { - return StmtExprMutator::VisitStmt(MakeTextureAlloc(op, call)); + if (call->op.same_as(builtin::nd_mem_alloc_with_scope())) { + return StmtExprMutator::VisitStmt(MakeNdMemAllocWithScope(op, call)); } } return StmtExprMutator::VisitStmt_(op); @@ -459,7 +459,7 @@ class BuiltinLower : public StmtExprMutator { return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args); } - Stmt MakeTextureAlloc(const LetStmtNode* let, const CallNode* call) { + Stmt 
MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) { ICHECK(device_type_.defined()) << "Unknown device type in current IR"; ICHECK(device_id_.defined()) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); @@ -467,24 +467,32 @@ class BuiltinLower : public StmtExprMutator { Stmt body = SeqStmt( {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error), let->body}); + DataType dtype = let->var->type_annotation.as()->element_type.as()->dtype; std::string fdevapi_prefix = "device_api."; fdevapi_prefix += runtime::DeviceName(device_type_.as()->value); - Call call_packed = - Call(let->var.dtype(), builtin::tvm_call_packed(), - {StringImm(fdevapi_prefix + ".AllocTexture"), cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), call->args[0]), - cast(DataType::UInt(64), call->args[1]), IntImm(DataType::Int(32), dtype.code()), - IntImm(DataType::Int(32), dtype.bits())}); + Array args = { + StringImm(fdevapi_prefix + ".AllocNd"), + device_type_, + device_id_, + IntImm(DataType::Int(32), dtype.code()), + IntImm(DataType::Int(32), dtype.bits()), + }; + + for (size_t i = 0; i < call->args.size(); ++i) { + args.push_back(call->args[i]); + } + + Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args); Stmt alloca = LetStmt(let->var, call_packed, body); - Call free_op = - Call(DataType::Int(32), builtin::tvm_call_packed(), - {StringImm(fdevapi_prefix + ".FreeTexture"), cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), let->var}); + PrimExpr storage_scope = call->args[0]; + Call free_op = Call( + DataType::Int(32), builtin::tvm_call_packed(), + {StringImm(fdevapi_prefix + ".FreeNd"), device_type_, device_id_, storage_scope, let->var}); Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); body = SeqStmt({alloca, free_stmt}); 
diff --git a/src/tir/transforms/lower_vtcm_alloc.cc b/src/tir/transforms/lower_vtcm_alloc.cc new file mode 100644 index 0000000000000..0b5f7bf1554d6 --- /dev/null +++ b/src/tir/transforms/lower_vtcm_alloc.cc @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include + +#include "../../arith/ir_visitor_with_analyzer.h" + +namespace tvm { +namespace tir { + +inline bool IsVtcmStorage(std::string scope) { + return scope.find("global.vtcm") != std::string::npos; +} + +class VtcmAllocator : public StmtExprMutator { + public: + using StmtExprMutator::VisitStmt_; + VtcmAllocator() {} + + Stmt VisitStmt_(const AllocateNode* op) final { + std::string storage_scope = GetStorageScope(op->buffer_var); + if (IsVtcmStorage(storage_scope)) { + Stmt body = this->VisitStmt(op->body); + Array args; + args.push_back(StringImm(storage_scope)); + args.push_back(IntImm(DataType::Int(64), op->extents.size())); + args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), op->extents)); + return LetStmt(op->buffer_var, + Call(op->buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args), body); + } + return StmtExprMutator::VisitStmt_(op); + } + + protected: + std::string GetStorageScope(const Var& var) { + auto* ptr = var->type_annotation.as(); + ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + return ptr->storage_scope; + } +}; + +PrimFunc LowerVtcmAlloc(PrimFunc func) { + auto fptr = func.CopyOnWrite(); + fptr->body = VtcmAllocator()(std::move(fptr->body)); + return func; +} + +namespace transform { + +Pass LowerVtcmAlloc() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return LowerVtcmAlloc(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerVtcmAlloc", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerVtcmAlloc").set_body_typed(LowerVtcmAlloc); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/texture_flatten.cc b/src/tir/transforms/texture_flatten.cc index 7dc8007379444..a607e5914b39e 100644 --- a/src/tir/transforms/texture_flatten.cc +++ b/src/tir/transforms/texture_flatten.cc @@ -115,8 +115,13 @@ class TextureFlattener : public TextureLoweringBase { size_t axis = 
DefaultTextureLayoutSeparator(op->bounds.size(), storage_scope); auto texture = ApplyTexture2DFlattening(ShapeFromRange{op->bounds}, op->bounds.size(), axis); - Array args = {texture.width, texture.height}; - stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::texture2d_alloca(), args), body); + Array args; + args.push_back(StringImm(storage_scope)); + args.push_back(IntImm(DataType::Int(64), 2)); // 2d + args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), + {texture.width, texture.height})); + stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args), + body); } return stmt; diff --git a/tests/python/contrib/test_hexagon/test_cache_read_write.py b/tests/python/contrib/test_hexagon/test_cache_read_write.py index fb9b352476bdd..a638d733b0d2f 100644 --- a/tests/python/contrib/test_hexagon/test_cache_read_write.py +++ b/tests/python/contrib/test_hexagon/test_cache_read_write.py @@ -125,9 +125,15 @@ def test_cache_read_write( with launcher.start_session() as sess: mod = launcher.load_module(dso_binary, sess) - xt = tvm.nd.array(np.random.uniform(size=size).astype(x.dtype), device=sess.device) - yt = tvm.nd.array(np.random.uniform(size=size).astype(y.dtype), device=sess.device) - zt = tvm.nd.array(np.random.uniform(size=size).astype(z.dtype), device=sess.device) + xt = tvm.nd.array( + np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device + ) + yt = tvm.nd.array( + np.random.randint(-128, high=127, size=size, dtype=y.dtype), device=sess.device + ) + zt = tvm.nd.array( + np.random.randint(-128, high=127, size=size, dtype=z.dtype), device=sess.device + ) mod["dmacpy"](xt, yt, zt) launcher.stop_server()