Skip to content

Commit

Permalink
[LLVM] Use llvm::Align with LLVM 11+ to avoid warnings (#5264)
Browse files Browse the repository at this point in the history
LLVM 11 is introducing a separate class to represent alignment.
The functions in IRBuilder that create aligned loads and stores,
and which accept the alignment as an unsigned value have been
deprecated (and now cause warnings to be emitted).
  • Loading branch information
Krzysztof Parzyszek authored Apr 7, 2020
1 parent e11a609 commit 36ce2e2
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
32 changes: 28 additions & 4 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,11 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(

llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) {
CHECK(gv != nullptr);
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, llvm::Align(gv->getAlignment()));
#else
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment());
#endif
faddr->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
Expand Down Expand Up @@ -642,7 +646,11 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
*ctx_, "handle_init", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "handle_init_end", function_);
#if TVM_LLVM_VERSION >= 110
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align));
#else
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
#endif
llvm::Value* handle_not_null = builder_->CreateICmpNE(
handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
builder_->CreateCondBr(
Expand All @@ -652,15 +660,24 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
llvm::Value* out = WithFunctionEntry([&]() {
return builder_->CreateAlloca(t_tvm_func_handle_);
});
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(
gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment()));
#else
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(
gv_mod_ctx_, gv_mod_ctx_->getAlignment());
#endif
ctx->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
llvm::Value* retcode = builder_->CreateCall(
RuntimeTVMGetFuncFromEnv(), {ctx, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode);
#if TVM_LLVM_VERSION >= 110
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align));
#else
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
#endif
// Store the handle
builder_->CreateStore(loaded_handle, hptr);
builder_->CreateBr(end_block);
Expand Down Expand Up @@ -697,10 +714,13 @@ CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue,
RuntimeTVMFuncCall(), {handle, arg_value, arg_tcode, ConstInt32(nargs),
ret_value, *ret_tcode}));
DataType r_api_type = tir::APIType(r_type);
*rvalue = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(
ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()),
8);
llvm::Value* load_ptr = builder_->CreatePointerCast(
ret_value, DTypeToLLVMType(r_api_type)->getPointerTo());
#if TVM_LLVM_VERSION >= 110
*rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8));
#else
*rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
#endif
*rvalue = CreateCast(r_api_type, r_type, *rvalue);
return end_block;
}
Expand Down Expand Up @@ -732,7 +752,11 @@ llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) {
// traced value.
BasicBlock *continue_block =
BasicBlock::Create(*ctx_, "continue_block", function_);
#if TVM_LLVM_VERSION >= 110
llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8));
#else
llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8);
#endif
// Check the ret_type_code and create cmp instruction.
llvm::Value *cmp = builder_->CreateICmpNE(
ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr));
Expand Down
31 changes: 31 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,12 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* load =
builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile);
#else
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
#endif
AddAliasInfo(load, op->buffer_var.get(), op->index, t);
return load;
} else {
Expand All @@ -996,7 +1001,12 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
t.element_of(), buffer, MakeValue(ramp->base));
ptr = builder_->CreatePointerCast(
ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* load = builder_->CreateAlignedLoad(
ptr, llvm::Align(alignment), is_volatile);
#else
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
#endif
AddAliasInfo(load, op->buffer_var.get(), op->index, t);
return load;
}
Expand All @@ -1007,8 +1017,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t));
auto f = [&](int i, llvm::Value* index) {
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* load = builder_->CreateAlignedLoad(
ptr, llvm::Align(basic_align), is_volatile);
#else
llvm::LoadInst* load = builder_->CreateAlignedLoad(
ptr, basic_align, is_volatile);
#endif
ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
AddAliasInfo(load, op->buffer_var.get(), PrimExpr(), t);
};
Expand Down Expand Up @@ -1077,7 +1092,12 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
#if TVM_LLVM_VERSION >= 110
llvm::StoreInst* store =
builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile);
#else
llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
#endif
AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.dtype());
return;
} else {
Expand All @@ -1092,7 +1112,12 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
llvm::Value* ptr = CreateBufferPtr(
t.element_of(), buffer, MakeValue(ramp->base));
ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
#if TVM_LLVM_VERSION >= 110
llvm::StoreInst* store =
builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile);
#else
llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
#endif
AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.dtype());
return;
}
Expand All @@ -1103,9 +1128,15 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
int basic_align = t.bits() / 8;
auto f = [&](int i, llvm::Value* index) {
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
#if TVM_LLVM_VERSION >= 110
llvm::StoreInst* store = builder_->CreateAlignedStore(
builder_->CreateExtractElement(value, i),
ptr, llvm::Align(basic_align), is_volatile);
#else
llvm::StoreInst* store = builder_->CreateAlignedStore(
builder_->CreateExtractElement(value, i),
ptr, basic_align, is_volatile);
#endif
AddAliasInfo(store, op->buffer_var.get(), PrimExpr(), op->value.dtype());
};
this->Scalarize(op->index, f);
Expand Down

0 comments on commit 36ce2e2

Please sign in to comment.