[Hexagon] Generalize builtin for Nd memory alloc with storage scope and add lowering for VTCM / Hexagon #10558

Merged · 44 commits · Mar 14, 2022

Commits
6bbe6cc  repurpose texture flatten for vtcm; TIR lowering correct (adstraw, Mar 2, 2022)
3cb0121  clean up remaining code in texture flatten pass (adstraw, Mar 2, 2022)
5365fab  add Alloc and FreeTexture, but failing to run over rpc (adstraw, Mar 2, 2022)
8044dc0  test passing with malloc in the device api (adstraw, Mar 3, 2022)
82f61be  cleanup (adstraw, Mar 3, 2022)
c1843ed  fails in very reliable way with memory corruption (adstraw, Mar 4, 2022)
4cef769  working with non-HexagonBuffer vtcm alloc (adstraw, Mar 4, 2022)
f13cd4c  cleanup (adstraw, Mar 4, 2022)
e9ef946  do not pass scope through mem_copy api (adstraw, Mar 7, 2022)
daac188  [Hexagon] Resolve breakage in test_hexagon/test_cache_read_write (Lunderberg, Mar 7, 2022)
5ca8970  use HexagonBuffer in Alloc and Free packed funcs (adstraw, Mar 7, 2022)
8cea1e1  Added comment indicating need for StackSizeChecker::MakeMemCopy. (Lunderberg, Mar 7, 2022)
05423ea  Merge branch 'hexagon_mem_copy_lower' into hexagon-alloc-nd-lower-vtcm (Lunderberg, Mar 7, 2022)
a96c062  add AllocVtcmWorkspace and FreeVtcmWorkspace (adstraw, Mar 7, 2022)
0ecd017  cleanup (adstraw, Mar 8, 2022)
2c1ee84  Updated unittests to run all contrib/test_hexagon at CI. (Lunderberg, Mar 8, 2022)
351b0af  create separate vtcm alloc lowering pass and transform (adstraw, Mar 8, 2022)
0f37782  reset texture_flatten.cc (adstraw, Mar 8, 2022)
6678e14  comments (adstraw, Mar 8, 2022)
1088c66  CI bump (Lunderberg, Mar 8, 2022)
7de3ae0  Fix lint formatting error. (Lunderberg, Mar 9, 2022)
794dbbf  Updated fix to remove StackSizeChecker entirely. (Lunderberg, Mar 9, 2022)
c21b254  Merge remote-tracking branch 'lunderberg/hexagon_mem_copy_lower' into… (adstraw, Mar 9, 2022)
7b06e7c  pass device and type to device api (adstraw, Mar 9, 2022)
bc372da  Bugfix, verify the precheck's allocations, not own. (Lunderberg, Mar 9, 2022)
4ff6471  Bugfix, pass context information to the precheck. (Lunderberg, Mar 9, 2022)
1c23651  pass order and shape to device api (adstraw, Mar 9, 2022)
7e43cd8  working (adstraw, Mar 9, 2022)
5132fd6  fix up types and arg passing (adstraw, Mar 10, 2022)
b28ff9c  pass scope to device api (adstraw, Mar 10, 2022)
3d28c59  common builtin for texture / vtcm (adstraw, Mar 10, 2022)
47268c5  add scope to freend api (adstraw, Mar 10, 2022)
40b0dd5  Merge remote-tracking branch 'lunderberg/hexagon_mem_copy_lower' into… (adstraw, Mar 10, 2022)
10ee3a5  Merge branch 'main' into hexagon-alloc-nd-lower-vtcm (adstraw, Mar 10, 2022)
1645238  format and lint (adstraw, Mar 10, 2022)
caed9f1  fixed missed format error (adstraw, Mar 10, 2022)
0f59317  restart ci (adstraw, Mar 11, 2022)
520f517  Merge branch 'main' into hexagon-alloc-nd-lower-vtcm (adstraw, Mar 11, 2022)
b7d8dd0  fix test random value issue + code review feedback (adstraw, Mar 11, 2022)
c3a3b30  fix test hang (adstraw, Mar 11, 2022)
53ce1ee  restructure lower vtcm pass per code review feedback (option a) (adstraw, Mar 14, 2022)
412857c  format error (adstraw, Mar 14, 2022)
38cd975  global.vtcm + tvm_stack_make_shape (adstraw, Mar 14, 2022)
484cc7b  Merge branch 'main' into hexagon-alloc-nd-lower-vtcm (adstraw, Mar 14, 2022)
Files changed
4 changes: 2 additions & 2 deletions include/tvm/tir/builtin.h
@@ -644,9 +644,9 @@ TVM_DLL const Op& vectorcombine();
*/
TVM_DLL const Op& atomic_add();
/*!
* \brief Create a texture 2d memory allocation
* \brief Create an Nd memory allocation with storage scope
*/
TVM_DLL const Op& texture2d_alloca();
TVM_DLL const Op& nd_mem_alloc_with_scope();

/*!
* \brief Store to texture 2d memory
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
@@ -459,6 +459,13 @@ TVM_DLL Pass FlattenBuffer();
*/
TVM_DLL Pass TextureFlatten();

/*!
* \brief Lower VTCM allocations
*
* \return The Pass
*/
TVM_DLL Pass LowerVtcmAlloc();

/*!
* \brief Implements a Common Subexpression Elimination (CSE) for TIR
* which introduces let-in bindings for duplicated sub-expressions.
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
@@ -242,6 +242,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::TextureFlatten());
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
pass_list.push_back(tir::transform::LowerVtcmAlloc());
pass_list.push_back(tir::transform::LowerCrossThreadReduction());
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
2 changes: 1 addition & 1 deletion src/runtime/hexagon/hexagon/hexagon_common.cc
@@ -91,7 +91,7 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
DLTensor* tensor = static_cast<DLTensor*>(arg_values[i].v_handle);
buffer_args.emplace_back(i, static_cast<HexagonBuffer*>(tensor->data));
// Assumes a single contiguous allocation
// TODO(Straw): Enable discontiguous allocation after RFC 39 lands
// TODO(Straw): Enable discontiguous allocation
tensor->data = buffer_args.back().second->GetPointer()[0];
}
}
72 changes: 70 additions & 2 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc
@@ -62,7 +62,7 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* sh
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;

// Forcing contiguous allocation, for now
// TODO(Straw): Enable discontiguous allocation after RFC 39 lands
// TODO(Straw): Enable discontiguous allocation
size_t nallocs = 1;
size_t nbytes = 1;
for (int i = 0; i < ndim; ++i) {
@@ -107,7 +107,7 @@ void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType typ
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->AllocWorkspace(dev, size));

// Assumes a single contiguous allocation
// TODO(Straw): Enable discontiguous allocation after RFC 39 lands
// TODO(Straw): Enable discontiguous allocation
void* ptr = hexbuf->GetPointer()[0];
workspace_allocations_.insert({ptr, hexbuf});
return ptr;
@@ -122,6 +122,20 @@ void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) {
workspace_allocations_.erase(it);
}

void* HexagonDeviceAPIv2::AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape,
DLDataType dtype, Optional<String> mem_scope) {
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
// Forcing contiguous allocation, for now
// TODO(Straw): Enable discontiguous allocation
CHECK_EQ(ndim, 1);
return AllocDataSpace(dev, ndim, shape, dtype, mem_scope);
}

void HexagonDeviceAPIv2::FreeVtcmWorkspace(Device dev, void* ptr) {
CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type;
FreeDataSpace(dev, ptr);
}

void HexagonDeviceAPIv2::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
CHECK_EQ(from->byte_offset, 0);
CHECK_EQ(to->byte_offset, 0);
@@ -166,6 +180,60 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM
*rv = static_cast<int32_t>(0);
});

std::map<void*, HexagonBuffer*> vtcmallocs;

TVM_REGISTER_GLOBAL("device_api.hexagon.AllocNd").set_body([](TVMArgs args, TVMRetValue* rv) {
Review comment (Contributor):
Instead of using a lambda for the body, can we move this to a separate function? It tends to make debugging easier later on, and becomes a template for a generalized function we can propose adding to c_runtime_api.cc.

Reply (@adstraw, Contributor and Author, Mar 11, 2022):
This PR creates a separate function called AllocVtcmWorkspace in the Hexagon Device API that is called from the lambda and which mirrors AllocTextureWorkspace in the OpenCL device API. These two APIs are the template for a more generalized Device API to "allocate a workspace with storage scope", but this should occur in a follow-up PR, in my opinion.

Note that the lambda must construct Device and DLDataType arguments for AllocVtcmWorkspace from primitive types passed through the builtin before making the call to AllocVtcmWorkspace.

int32_t device_type = args[0];
int32_t device_id = args[1];
int32_t dtype_code_hint = args[2];
int32_t dtype_bits_hint = args[3];
std::string scope = args[4];
CHECK(scope.find("global.vtcm") != std::string::npos);
int64_t ndim = args[5];
// Forcing contiguous allocation, for now
// TODO(Straw): Enable discontiguous allocation
CHECK_EQ(ndim, 1);
int64_t* shape = static_cast<int64_t*>(static_cast<void*>(args[6]));

Device dev;
dev.device_type = static_cast<DLDeviceType>(device_type);
dev.device_id = device_id;

DLDataType type_hint;
type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1;

HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global();
HexagonBuffer* hexbuf = reinterpret_cast<HexagonBuffer*>(
hexapi->AllocVtcmWorkspace(dev, ndim, shape, type_hint, String(scope)));

// Assumes a single contiguous allocation
// TODO(Straw): Enable discontiguous allocation
void* ptr = hexbuf->GetPointer()[0];
vtcmallocs[ptr] = hexbuf;
*rv = ptr;
});
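
Editor's note, for illustration only (not part of this PR's diff): a minimal sketch of the refactor suggested in the review comment above, hoisting the lambda body into a named function in the same file and registering that instead. The name HexagonAllocNd is hypothetical; the body simply mirrors the lambda registered for "device_api.hexagon.AllocNd".

// Hypothetical sketch, not in the PR: same behavior as the lambda above,
// but as a named function that is easier to set a breakpoint on.
static void HexagonAllocNd(TVMArgs args, TVMRetValue* rv) {
  int32_t device_type = args[0];
  int32_t device_id = args[1];
  int32_t dtype_code_hint = args[2];
  int32_t dtype_bits_hint = args[3];
  std::string scope = args[4];
  CHECK(scope.find("global.vtcm") != std::string::npos);
  int64_t ndim = args[5];
  CHECK_EQ(ndim, 1);  // contiguous allocation only, for now
  int64_t* shape = static_cast<int64_t*>(static_cast<void*>(args[6]));

  Device dev;
  dev.device_type = static_cast<DLDeviceType>(device_type);
  dev.device_id = device_id;

  DLDataType type_hint;
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

  HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global();
  HexagonBuffer* hexbuf = reinterpret_cast<HexagonBuffer*>(
      hexapi->AllocVtcmWorkspace(dev, ndim, shape, type_hint, String(scope)));

  // Assumes a single contiguous allocation, as in the lambda above.
  void* ptr = hexbuf->GetPointer()[0];
  vtcmallocs[ptr] = hexbuf;
  *rv = ptr;
}

TVM_REGISTER_GLOBAL("device_api.hexagon.AllocNd").set_body(PackedFunc(HexagonAllocNd));

Only the registration line changes; the argument unpacking and the HexagonBuffer bookkeeping stay identical to the lambda form.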

TVM_REGISTER_GLOBAL("device_api.hexagon.FreeNd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
std::string scope = args[2];
CHECK(scope.find("vtcm") != std::string::npos);
void* ptr = args[3];
CHECK(vtcmallocs.find(ptr) != vtcmallocs.end());

HexagonBuffer* hexbuf = vtcmallocs[ptr];

Device dev;
dev.device_type = static_cast<DLDeviceType>(device_type);
dev.device_id = device_id;

HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global();
hexapi->FreeVtcmWorkspace(dev, hexbuf);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.v2").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = HexagonDeviceAPIv2::Global();
*rv = static_cast<void*>(ptr);
17 changes: 17 additions & 0 deletions src/runtime/hexagon/hexagon/hexagon_device_api_v2.h
@@ -22,7 +22,10 @@

#include <tvm/runtime/device_api.h>

#include <map>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace runtime {
@@ -82,6 +85,20 @@ class HexagonDeviceAPIv2 final : public DeviceAPI {
void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
Optional<String> mem_scope) final;

/*!
* \brief Allocate an Nd VTCM workspace.
* \param dev The device to perform the operation.
* \param ndim The number of dimensions of allocated tensor.
* \param shape The shape of allocated tensor.
* \param dtype The element type.
* \return The allocated HexagonBuffer pointer.
*/
void* AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
Optional<String> mem_scope);

//! \brief Free the allocated Nd VTCM workspace.
void FreeVtcmWorkspace(Device dev, void* ptr);

/*!
* \brief Copy data from one storage to another.
* \note This API is designed to support special memory with shape dependent layout.
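
Editor's note, a usage illustration only (not part of the PR): a hedged sketch of how runtime code might call the AllocVtcmWorkspace / FreeVtcmWorkspace declarations above. The device id, element type, and size are made up, and the HexagonBuffer handling mirrors the "device_api.hexagon.AllocNd" registration shown earlier.

// Hypothetical sketch: allocate and release a 1-d VTCM workspace.
Device dev;
dev.device_type = static_cast<DLDeviceType>(kDLHexagon);
dev.device_id = 0;                       // made-up device id

DLDataType dtype;
dtype.code = kDLInt;                     // made-up element type: int8
dtype.bits = 8;
dtype.lanes = 1;

int64_t shape[1] = {2 * 1024 * 1024};    // one contiguous 2 MB region (ndim == 1, for now)

HexagonDeviceAPIv2* api = HexagonDeviceAPIv2::Global();
HexagonBuffer* hexbuf = reinterpret_cast<HexagonBuffer*>(
    api->AllocVtcmWorkspace(dev, /*ndim=*/1, shape, dtype, String("global.vtcm")));
void* data = hexbuf->GetPointer()[0];    // raw VTCM pointer usable by kernels
// ... use data ...
api->FreeVtcmWorkspace(dev, hexbuf);

Note that the handle returned by AllocVtcmWorkspace is the HexagonBuffer itself; the data pointer comes from GetPointer()[0], and the same handle is what FreeVtcmWorkspace expects.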
30 changes: 19 additions & 11 deletions src/runtime/opencl/opencl_device_api.cc
@@ -438,13 +438,19 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
initialized_ = true;
}

TVM_REGISTER_GLOBAL("device_api.opencl.AllocTexture").set_body([](TVMArgs args, TVMRetValue* rv) {
int device_type = args[0];
int device_id = args[1];
int width = args[2];
int height = args[3];
int dtype_code_hint = args[4];
int dtype_bits_hint = args[5];
TVM_REGISTER_GLOBAL("device_api.opencl.AllocNd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
int32_t dtype_code_hint = args[2];
int32_t dtype_bits_hint = args[3];
std::string scope = args[4];
CHECK(scope.find("texture") != std::string::npos);
int64_t ndim = args[5];
CHECK_EQ(ndim, 2);
int64_t* shape = static_cast<int64_t*>(static_cast<void*>(args[6]));
int64_t width = shape[0];
int64_t height = shape[1];

Device dev;
dev.device_type = static_cast<DLDeviceType>(device_type);
dev.device_id = device_id;
@@ -459,10 +465,12 @@ TVM_REGISTER_GLOBAL("device_api.opencl.AllocTexture").set_body([](TVMArgs args,
type_hint);
});

TVM_REGISTER_GLOBAL("device_api.opencl.FreeTexture").set_body([](TVMArgs args, TVMRetValue* rv) {
int device_type = args[0];
int device_id = args[1];
void* data = args[2];
TVM_REGISTER_GLOBAL("device_api.opencl.FreeNd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
std::string scope = args[2];
CHECK(scope.find("texture") != std::string::npos);
void* data = args[3];
OpenCLWorkspace* ptr = OpenCLWorkspace::Global();
Device dev;
dev.device_type = static_cast<DLDeviceType>(device_type);
2 changes: 1 addition & 1 deletion src/tir/op/builtin.cc
@@ -256,7 +256,7 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine)
TIR_DEFINE_BUILTIN_FUNC(atomic_add)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(texture2d_alloca)
TIR_DEFINE_BUILTIN_FUNC(nd_mem_alloc_with_scope)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(texture2d_store)
34 changes: 21 additions & 13 deletions src/tir/transforms/lower_tvm_builtin.cc
@@ -158,8 +158,8 @@ class BuiltinLower : public StmtExprMutator {

Stmt VisitStmt_(const LetStmtNode* op) final {
if (const CallNode* call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::texture2d_alloca())) {
return StmtExprMutator::VisitStmt(MakeTextureAlloc(op, call));
if (call->op.same_as(builtin::nd_mem_alloc_with_scope())) {
return StmtExprMutator::VisitStmt(MakeNdMemAllocWithScope(op, call));
}
}
return StmtExprMutator::VisitStmt_(op);
@@ -459,32 +459,40 @@ class BuiltinLower : public StmtExprMutator {
return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args);
}

Stmt MakeTextureAlloc(const LetStmtNode* let, const CallNode* call) {
Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) {
ICHECK(device_type_.defined()) << "Unknown device type in current IR";
ICHECK(device_id_.defined()) << "Unknown device id in current IR";
Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));

Stmt body = SeqStmt(
{IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error),
let->body});

DataType dtype =
let->var->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()->dtype;

std::string fdevapi_prefix = "device_api.";
fdevapi_prefix += runtime::DeviceName(device_type_.as<IntImmNode>()->value);
Call call_packed =
Call(let->var.dtype(), builtin::tvm_call_packed(),
{StringImm(fdevapi_prefix + ".AllocTexture"), cast(DataType::Int(32), device_type_),
cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), call->args[0]),
cast(DataType::UInt(64), call->args[1]), IntImm(DataType::Int(32), dtype.code()),
IntImm(DataType::Int(32), dtype.bits())});

Array<PrimExpr> args = {
StringImm(fdevapi_prefix + ".AllocNd"),
device_type_,
device_id_,
IntImm(DataType::Int(32), dtype.code()),
IntImm(DataType::Int(32), dtype.bits()),
};

for (size_t i = 0; i < call->args.size(); ++i) {
args.push_back(call->args[i]);
}

Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args);
Stmt alloca = LetStmt(let->var, call_packed, body);

Call free_op =
Call(DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(fdevapi_prefix + ".FreeTexture"), cast(DataType::Int(32), device_type_),
cast(DataType::Int(32), device_id_), let->var});
PrimExpr storage_scope = call->args[0];
Call free_op = Call(
DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(fdevapi_prefix + ".FreeNd"), device_type_, device_id_, storage_scope, let->var});

Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
body = SeqStmt({alloca, free_stmt});
80 changes: 80 additions & 0 deletions src/tir/transforms/lower_vtcm_alloc.cc
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/transform.h>

#include "../../arith/ir_visitor_with_analyzer.h"

namespace tvm {
namespace tir {

inline bool IsVtcmStorage(std::string scope) {
return scope.find("global.vtcm") != std::string::npos;
}

class VtcmAllocator : public StmtExprMutator {
public:
using StmtExprMutator::VisitStmt_;
VtcmAllocator() {}

Stmt VisitStmt_(const AllocateNode* op) final {
std::string storage_scope = GetStorageScope(op->buffer_var);
if (IsVtcmStorage(storage_scope)) {
Stmt body = this->VisitStmt(op->body);
Array<PrimExpr> args;
args.push_back(StringImm(storage_scope));
args.push_back(IntImm(DataType::Int(64), op->extents.size()));
args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), op->extents));
return LetStmt(op->buffer_var,
Call(op->buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args), body);
}
return StmtExprMutator::VisitStmt_(op);
}

protected:
std::string GetStorageScope(const Var& var) {
auto* ptr = var->type_annotation.as<PointerTypeNode>();
ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType";
return ptr->storage_scope;
}
};

PrimFunc LowerVtcmAlloc(PrimFunc func) {
auto fptr = func.CopyOnWrite();
fptr->body = VtcmAllocator()(std::move(fptr->body));
return func;
}

namespace transform {

Pass LowerVtcmAlloc() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return LowerVtcmAlloc(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerVtcmAlloc", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerVtcmAlloc").set_body_typed(LowerVtcmAlloc);

} // namespace transform

} // namespace tir
} // namespace tvm
9 changes: 7 additions & 2 deletions src/tir/transforms/texture_flatten.cc
@@ -115,8 +115,13 @@ class TextureFlattener : public TextureLoweringBase {
size_t axis = DefaultTextureLayoutSeparator(op->bounds.size(), storage_scope);
auto texture =
ApplyTexture2DFlattening<PrimExpr>(ShapeFromRange{op->bounds}, op->bounds.size(), axis);
Array<PrimExpr> args = {texture.width, texture.height};
stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::texture2d_alloca(), args), body);
Array<PrimExpr> args;
args.push_back(StringImm(storage_scope));
args.push_back(IntImm(DataType::Int(64), 2)); // 2d
args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(),
{texture.width, texture.height}));
stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args),
body);
}

return stmt;