Skip to content

Commit

Permalink
[LLVM] Create TBAA information based on the unrelying buffer type (#6046
Browse files Browse the repository at this point in the history
)

Currently, the TBAA information is based on the access type, i.e.
the data type from the load or store instruction. When the same
memory area is accessed with different types, the corresponding
load/store instruction may end up not being aliased to each other.
This could lead to incorrect code being generated.

An example of when such a situation can occur is when two different
buffer_decl's are created for the same buffer:
  ba = buffer_decl(... dtype = 'int16' ...)
  bb = buffer_decl(data = ba.data, dtype = 'int32x32' ...)
Then instructions
  ba[x] = 0
  ... = bb[x]
may be reordered in the final code due to the alias info indicating
that they are not aliased.
  • Loading branch information
Krzysztof Parzyszek authored Jul 13, 2020
1 parent 5f4b9a9 commit af8af0c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
29 changes: 19 additions & 10 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,7 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const {
//
// This trick comes from Halide's CodeGen_LLVM
//
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, PrimExpr index,
DataType type) {
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, PrimExpr index) {
if (alias_var_set_.count(buffer) != 0) {
// Mark all possibly aliased pointer as same type.
llvm::MDNode* meta = md_tbaa_alias_set_;
Expand Down Expand Up @@ -393,11 +392,21 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, P
}
}
llvm::MDNode* meta = md_tbaa_root_;
std::ostringstream buffer_addr, buffer_type;
std::ostringstream buffer_addr;
buffer_addr << buffer;
meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
buffer_type << type.element_of();

// Extract the underlying type of the allocated buffer.
llvm::Type* buf_type = GetVarValue(buffer)->getType()->getScalarType();
if (buf_type->isPointerTy()) {
buf_type = buf_type->getPointerElementType();
}

std::string tmp;
llvm::raw_string_ostream buffer_type(tmp);
buffer_type << *buf_type;
meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);

// create a tree-shape access structure.
if (width != 0) {
for (int64_t w = 1024; w >= width; w /= 2) {
Expand Down Expand Up @@ -1030,7 +1039,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
#else
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
#endif
AddAliasInfo(load, op->buffer_var.get(), op->index, t);
AddAliasInfo(load, op->buffer_var.get(), op->index);
return load;
} else {
// vector load
Expand All @@ -1048,7 +1057,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
#else
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
#endif
AddAliasInfo(load, op->buffer_var.get(), op->index, t);
AddAliasInfo(load, op->buffer_var.get(), op->index);
return load;
}
}
Expand All @@ -1064,7 +1073,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, basic_align, is_volatile);
#endif
ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
AddAliasInfo(load, op->buffer_var.get(), PrimExpr(), t);
AddAliasInfo(load, op->buffer_var.get(), PrimExpr());
};
this->Scalarize(op->index, f);
return ret;
Expand Down Expand Up @@ -1148,7 +1157,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
#else
llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
#endif
AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.dtype());
AddAliasInfo(store, op->buffer_var.get(), op->index);
return;
} else {
// vector store
Expand All @@ -1166,7 +1175,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
#else
llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
#endif
AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.dtype());
AddAliasInfo(store, op->buffer_var.get(), op->index);
return;
}
}
Expand All @@ -1183,7 +1192,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i),
ptr, basic_align, is_volatile);
#endif
AddAliasInfo(store, op->buffer_var.get(), PrimExpr(), op->value.dtype());
AddAliasInfo(store, op->buffer_var.get(), PrimExpr());
};
this->Scalarize(op->index, f);
}
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
const Var& loop_var, const Stmt& body);
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index, DataType type);
void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index);
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
Expand Down

0 comments on commit af8af0c

Please sign in to comment.