Skip to content

Commit

Permalink
[Vulkan] Some fixes of Vulkan codegen
Browse files Browse the repository at this point in the history
- add exp2 dispatch
- handle thread invariant op
- disable smem merge for vulkan
  • Loading branch information
spectrometerHBH authored and junrushao committed Jan 17, 2024
1 parent 5556d56 commit 371fddb
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class);
ICHECK(var_map_.count(buffer_node));
return builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index));
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
return MakeValue(op->args[0]);
} else {
LOG(FATAL) << "Unresolved call " << op->op;
}
Expand Down
3 changes: 3 additions & 0 deletions src/target/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ TVM_REGISTER_OP("tir.fabs")
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450Exp>);

TVM_REGISTER_OP("tir.exp2")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Exp2>);

TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450Sin>);

Expand Down
5 changes: 5 additions & 0 deletions src/tir/transforms/merge_shared_memory_allocations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,11 @@ namespace transform {
Pass MergeSharedMemoryAllocations() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
// disable this pass for Vulkan
auto target = Target::Current(true);
if (target.defined() && target->kind->name == "vulkan") {
return f;
}
auto* n = f.CopyOnWrite();
n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
return f;
Expand Down
7 changes: 6 additions & 1 deletion src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1705,8 +1705,13 @@ namespace transform {
Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
// disable merge_static_smem for Vulkan
auto target = Target::Current(true);
if (target.defined() && target->kind->name == "vulkan") {
merge_static_smem = false;
}
auto* n = f.CopyOnWrite();
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, !merge_static_smem);
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, merge_static_smem);
// Parameters may not be rewritten, but internal allocations may.
// Vectorization of AllocateConst is currently disabled, as it has
// indexing issues for types that include padding (e.g. int8x3
Expand Down

0 comments on commit 371fddb

Please sign in to comment.