Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use convert-to-llvm pass for conversions to LLVMIR Dialect #1679

Open
wants to merge 14 commits into
base: target_attr
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace mlir {
/// Create a pass that performs dialect conversion to LLVM for all dialects
/// implementing `ConvertToLLVMPatternInterface`.
std::unique_ptr<Pass> createConvertToLLVMPass();
std::unique_ptr<Pass> createConvertToLLVMPass(unsigned indexBitwidth,
bool useBarePtrCallConv);

/// Register the extension that will load dependent dialects for LLVM
/// conversion. This is useful to implement a pass similar to "convert-to-llvm".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void configureGpuToROCDLConversionLegality(ConversionTarget &target);
/// is configurable.
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
createLowerGpuOpsToROCDLOpsPass(
const std::string &chipset = "gfx900",
const std::string &chipset = "infer",
unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
bool useBarePtrCallConv = false,
gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);
Expand Down
10 changes: 8 additions & 2 deletions external/llvm-project/mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
let options = [
ListOption<"filterDialects", "filter-dialects", "std::string",
"Test conversion patterns of only the specified dialects">,
Option<"useBarePtrCallConv", "use-bare-ptr-call-conv", "bool", "false", "Whether memrefs can be converted to bare ptr">,
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
"Bitwidth of the index type, 0 to use size of machine word">,

];
}

Expand Down Expand Up @@ -589,11 +594,12 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
"ROCDL::ROCDLDialect",
"cf::ControlFlowDialect",
"memref::MemRefDialect",
"ptr::PtrDialect",
];
let options = [
Option<"chipset", "chipset", "std::string",
/*default=*/"\"gfx000\"",
"Chipset that these operations will run on">,
/*default=*/"\"infer\"",
"Chipset that these operations will run on. By default it will infer target from attached Target Attribute on GPU Module">,
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
"Bitwidth of the index type, 0 to use size of machine word">,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def Ptr_SpecAttr : Ptr_Attr<"Spec", "spec"> {
- [Optional] index: bitwidth that should be used when performing index
computations for the type. Setting the field to `kOptionalSpecValue`, means
the field is optional.
- [Optional] llvmAddressSpace : Mapping from AddressSpace of ptr.PtrType's adddress Space to LLVM's address space.
Setting the field to 'kOptionalLLVMAddressSpaceValue`, means the field is optional.

Furthermore, the attribute will verify that all present values are divisible
by 8 (number of bits in a byte), and that `preferred` > `abi`.
Expand All @@ -43,26 +45,28 @@ def Ptr_SpecAttr : Ptr_Attr<"Spec", "spec"> {
```mlir
// Spec for a 64 bit ptr, with a required alignment of 64 bits, but with
// a preferred alignment of 128 bits and an index bitwidth of 64 bits.
#ptr.spec<size = 64, abi = 64, preferred = 128, index = 64>
#ptr.spec<size = 64, abi = 64, preferred = 128, index = 64, llvmAddressSpace = 0>
```
}];
let parameters = (ins
"uint32_t":$size,
"uint32_t":$abi,
"uint32_t":$preferred,
DefaultValuedParameter<"uint32_t", "kOptionalSpecValue">:$index
DefaultValuedParameter<"uint32_t", "kOptionalSpecValue">:$index,
DefaultValuedParameter<"uint32_t", "kOptionalLLVMAddressSpaceValue">:$llvmAddressSpace
);
let skipDefaultBuilders = 1;
let builders = [
AttrBuilder<(ins "uint32_t":$size, "uint32_t":$abi, "uint32_t":$preferred,
CArg<"uint32_t", "kOptionalSpecValue">:$index), [{
return $_get($_ctxt, size, abi, preferred, index);
CArg<"uint32_t", "kOptionalSpecValue">:$index, CArg<"uint32_t", "kOptionalLLVMAddressSpaceValue">:$llvmAddressSpace), [{
return $_get($_ctxt, size, abi, preferred, index, llvmAddressSpace);
}]>
];
let assemblyFormat = "`<` struct(params) `>`";
let extraClassDeclaration = [{
/// Constant for specifying a spec entry is optional.
static constexpr uint32_t kOptionalSpecValue = std::numeric_limits<uint32_t>::max();
static constexpr uint32_t kOptionalLLVMAddressSpaceValue = 0;
}];
let genVerifyDecl = 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMPass
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRPtrDialect
MLIRPass
MLIRRewrite
MLIRSupport
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -69,31 +75,102 @@ class ConvertToLLVMPass

public:
using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
ConvertToLLVMPass() = default;
ConvertToLLVMPass(unsigned indexBitwidth, bool useBarePtrCallConv) {
if (this->indexBitwidth.getNumOccurrences() == 0)
this->indexBitwidth = indexBitwidth;
if (this->useBarePtrCallConv.getNumOccurrences() == 0)
this->useBarePtrCallConv = useBarePtrCallConv;
}

void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<LLVM::LLVMDialect>();
registry.insert<ptr::PtrDialect>();
registry.addExtensions<LoadDependentDialectExtension>();
}

LogicalResult initialize(MLIRContext *context) final {
LogicalResult initialize(MLIRContext *context) final { return success(); }

void runOnOperation() final {
auto *op = getOperation();
auto *context = op->getContext();
StringRef dataLayout;
auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
op->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
if (dataLayoutAttr)
dataLayout = dataLayoutAttr.getValue();

if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
dataLayout, [this](const Twine &message) {
getOperation()->emitError() << message.str();
}))) {
signalPassFailure();
return;
}

const DataLayoutAnalysis &dataLayoutAnalysis =
getAnalysis<DataLayoutAnalysis>();
LowerToLLVMOptions options(context,
dataLayoutAnalysis.getAtOrAbove(op));
options.useBarePtrCallConv = useBarePtrCallConv;
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);
options.dataLayout = llvm::DataLayout(dataLayout);
if (useBarePtrCallConv) {
options.useBarePtrCallConv = true;
}

RewritePatternSet tempPatterns(context);
auto target = std::make_shared<ConversionTarget>(*context);
target->addLegalDialect<LLVM::LLVMDialect>();
auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
auto typeConverter = std::make_shared<LLVMTypeConverter>(context, options);

DenseMap<Attribute, uint64_t> addressSpaceMap;
if (DataLayoutOpInterface iface = dyn_cast<DataLayoutOpInterface>(op)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is another usage of getAtOrAbove()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... no, it isn't, but it could be, if DataLayout (the class) grew a method.

if (DataLayoutSpecInterface dlSpec = iface.getDataLayoutSpec()) {
for (DataLayoutEntryInterface entry : dlSpec.getEntries()) {
ptr::PtrType ptrKey = llvm::dyn_cast_or_null<mlir::ptr::PtrType>(
entry.getKey().get<mlir::Type>());
if (!ptrKey) {
continue;
}
Attribute addressSpace = ptrKey.getMemorySpace();
auto value =
cast<mlir::ptr::SpecAttr>(entry.getValue()).getLlvmAddressSpace();
addressSpaceMap.insert({addressSpace, value});
}
}
typeConverter->addTypeAttributeConversion(
[addressSpaceMap](BaseMemRefType type, Attribute memorySpaceAttr) {
unsigned llvmAddressSpace = 0;
if (addressSpaceMap.contains(memorySpaceAttr)) {
llvmAddressSpace = addressSpaceMap.at(memorySpaceAttr);
}
return IntegerAttr::get(
IntegerType::get(memorySpaceAttr.getContext(), 64),
llvmAddressSpace);
});
}

if (!filterDialects.empty()) {
// Test mode: Populate only patterns from the specified dialects. Produce
// an error if the dialect is not loaded or does not implement the
// interface.
for (std::string &dialectName : filterDialects) {
Dialect *dialect = context->getLoadedDialect(dialectName);
if (!dialect)
return emitError(UnknownLoc::get(context))
<< "dialect not loaded: " << dialectName << "\n";
if (!dialect) {
emitError(UnknownLoc::get(context))
<< "dialect not loaded: " << dialectName << "\n";
signalPassFailure();
}
auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface)
return emitError(UnknownLoc::get(context))
<< "dialect does not implement ConvertToLLVMPatternInterface: "
<< dialectName << "\n";
if (!iface) {
emitError(UnknownLoc::get(context))
<< "dialect does not implement ConvertToLLVMPatternInterface: "
<< dialectName << "\n";
signalPassFailure();
}

iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
tempPatterns);
}
Expand All @@ -110,15 +187,10 @@ class ConvertToLLVMPass
tempPatterns);
}
}

this->patterns =
std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
this->target = target;
this->typeConverter = typeConverter;
return success();
}

void runOnOperation() final {
if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
signalPassFailure();
}
Expand All @@ -134,3 +206,8 @@ void mlir::registerConvertToLLVMDependentDialectLoading(
std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
return std::make_unique<ConvertToLLVMPass>();
}

std::unique_ptr<Pass> mlir::createConvertToLLVMPass(unsigned indexBitwidth,
bool useBarePtrCallConv) {
return std::make_unique<ConvertToLLVMPass>(indexBitwidth, useBarePtrCallConv);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
MLIRGPUToGPURuntimeTransforms
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRPtrDialect
MLIRMemRefToLLVM
MLIRROCDLDialect
MLIRPass
Expand Down
Loading